#!/usr/bin/env python3

import contextlib
import os
import re
import sys

from urllib.request import urlopen

# Absolute path of the directory one level above the directory containing this
# script (realpath() resolves symlinks first).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Path of this script relative to PROJECT_ROOT, using forward slashes. It is
# embedded in the "Auto-generated by" banner printed by print_tables().
SELF_PATH = os.path.relpath(__file__, PROJECT_ROOT).replace('\\', '/')

# Maps an architecture name to its {syscall number: 'sys_<name>'} table.
# Populated by Main() and read by print_tables().
syscalls = {}
# Architectures for which print_tables() emits a table, in output order.
archs = ['arm64', 'arm32', 'x86', 'x86_64']


def print_tables():
  """Prints the generated C++ header to stdout.

  Reads the module-level `syscalls` dict for every entry of `archs`. The
  output consists of:
    - the banner, include guard, includes and namespace opening;
    - one kAllSyscalllNames string literal holding every distinct syscall
      name once, in sorted order, each followed by a NUL;
    - one SyscallTable_<arch> struct per arch whose offsets[] entry i is the
      byte offset, inside that string, of the name of syscall number i.
  Syscall numbers that are missing from a table use the empty name.

  Raises:
    Exception: if the combined names string is 0xffff bytes or longer.
  """
  header_lines = (
      '// DO NOT EDIT. Auto-generated by %s' % SELF_PATH,
      '#ifndef SRC_KERNEL_UTILS_SYSCALL_TABLE_GENERATED_H_',
      '#define SRC_KERNEL_UTILS_SYSCALL_TABLE_GENERATED_H_',
      '',
      '#include <stdint.h>',
      '',
      '#include "src/kernel_utils/syscall_table.h"',
      '',
      'namespace perfetto {',
      '',
  )
  for header_line in header_lines:
    print(header_line)

  # Gather the union of the names of all architectures. Gaps in the syscall
  # numbering contribute the empty name.
  all_names = set()
  for arch in archs:
    table = syscalls[arch]
    all_names.update(table.get(nr, '') for nr in range(max(table.keys()) + 1))

  # Emit every name once and remember where it starts. The names come from a
  # set, so each one is visited exactly once.
  print('constexpr char kAllSyscalllNames[] =')
  offset_of = {}
  next_offset = 0
  for name in sorted(all_names):
    print('  "%s\\0"' % name)
    offset_of[name] = next_offset
    next_offset += len(name) + 1  # +1 accounts for the NUL terminator.
  if next_offset >= 0xffff:
    raise Exception('SyscallTable::OffT must be increased to uint32_t')
  print(';\n')

  # Emit the per-architecture tables, indexed by syscall number.
  for arch in archs:
    table = syscalls[arch]
    print('struct SyscallTable_%s {' % arch)
    print('  static constexpr const char* names = kAllSyscalllNames;')
    print('  static constexpr SyscallTable::OffT offsets[] {')
    for nr in range(max(table.keys()) + 1):
      name = table.get(nr, '')
      print('%d, // %d: %s' % (offset_of[name], nr, name))
    print('  };')
    print('};\n')

  print('}  // namespace perfetto')
  print('#endif  // SRC_KERNEL_UTILS_SYSCALL_TABLE_GENERATED_H_')


# Parses a .tbl file (new format).
def parse_tlb(data):
  """Builds a {syscall number: 'sys_<name>'} dict from .tbl file contents.

  Blank lines and lines whose first non-blank character is '#' are ignored.
  On every other line the first whitespace-separated column is taken as the
  syscall number and the third one as its name. When a number appears on
  several lines the last one wins.

  Args:
    data: the text of the .tbl file.

  Returns:
    A dict mapping int syscall numbers to names prefixed with 'sys_'.
  """
  result = {}
  for entry in map(str.strip, data.splitlines()):
    if not entry or entry.startswith('#'):
      continue
    columns = entry.split()
    # The name is looked up before the number is converted, so a short line
    # fails on the missing column first.
    name = 'sys_' + columns[2]
    result[int(columns[0])] = name
  return result


# Parses a #define __NR_xx 1234 old-style unistd.h header.
def parse_def(data):
  """Builds a {syscall number: 'sys_<name>'} dict from unistd.h contents.

  Only lines of the form `#define __NR<digits>_<name> <number>` (after
  stripping surrounding whitespace) are used; everything else is ignored,
  as is the __NR_syscalls sentinel. When a number is defined several times
  the last definition wins.

  Args:
    data: the text of the header.

  Returns:
    A dict mapping int syscall numbers to names prefixed with 'sys_'.
  """
  define_re = re.compile(r'^#define\s+__NR\d*?_(\w+)\s+(\d+)\s*$')
  table = {}
  for raw_line in data.splitlines():
    found = define_re.match(raw_line.strip())
    if found is None:
      continue
    name, number = found.groups()
    if name == 'syscalls':  # __NR_syscalls is just a sentinel.
      continue
    table[int(number)] = 'sys_' + name
  return table


def _fetch_text(url):
  """Downloads |url| and returns the response body decoded as text.

  The HTTP response is closed before returning.
  """
  with urlopen(url) as response:
    return response.read().decode()


def Main():
  """Downloads the kernel syscall definitions and regenerates the header.

  Fills the module-level `syscalls` dict for x86_64, x86, arm32 and arm64
  from the kernel sources at KSRC, writes the output of print_tables() to
  src/kernel_utils/syscall_table_generated.h under PROJECT_ROOT (through a
  temporary file that is then moved over the destination) and finally runs
  clang-format on it. The formatting step is best-effort: a missing or
  failing clang-format does not abort the script.

  Returns:
    None, so the script exits with status 0 when no exception is raised.
  """
  # Imported locally because only this function needs them.
  import shutil
  import subprocess

  KSRC = 'https://raw.githubusercontent.com/torvalds/linux/v6.7/'

  syscalls['x86_64'] = parse_tlb(
      _fetch_text(KSRC + 'arch/x86/entry/syscalls/syscall_64.tbl'))

  syscalls['x86'] = parse_tlb(
      _fetch_text(KSRC + 'arch/x86/entry/syscalls/syscall_32.tbl'))

  syscalls['arm32'] = parse_tlb(_fetch_text(KSRC + 'arch/arm/tools/syscall.tbl'))

  # From:
  # arch/arm64/include/asm/unistd.h
  #   -> arch/arm64/include/uapi/asm/unistd.h
  #     -> include/uapi/asm-generic/unistd.h
  syscalls['arm64'] = parse_def(
      _fetch_text(KSRC + 'include/uapi/asm-generic/unistd.h'))

  dst_file = os.path.join(PROJECT_ROOT, 'src', 'kernel_utils',
                          'syscall_table_generated.h')
  tmp_file = dst_file + '.tmp'

  print('Writing ', dst_file)
  # Generate into a temporary file first so that a failure in print_tables()
  # does not leave a truncated header at the destination.
  with open(tmp_file, 'w') as f:
    with contextlib.redirect_stdout(f):
      print_tables()
  os.replace(tmp_file, dst_file)

  print('Running clang-format (might fail if depot_tools isn\'t in the PATH)')
  # Resolve the executable explicitly and pass the arguments as a list rather
  # than a shell string, so a checkout path containing spaces or shell
  # metacharacters is handled correctly.
  clang_format = shutil.which('clang-format')
  if clang_format is None:
    print(
        'clang-format not found in the PATH, %s was left unformatted' %
        dst_file,
        file=sys.stderr)
    return
  try:
    subprocess.call([clang_format, '-i', dst_file])
  except OSError as e:
    print('Failed to run clang-format: %s' % e, file=sys.stderr)


# Script entry point: runs Main() and passes its return value to sys.exit().
if __name__ == '__main__':
  sys.exit(Main())
