#!/usr/bin/env python3

import sys

I = 32 # lane width in bits (first command-line argument: 32 or 64)
vectorunit = 'avx2' # target SIMD unit (second argument: neon, sse42, or avx2)

# Validate with explicit checks rather than assert: python -O strips asserts,
# which would let a bad argument through to a confusing failure later on.
if len(sys.argv) > 1:
  try:
    I = int(sys.argv[1])
  except ValueError:
    sys.exit(f'invalid lane width {sys.argv[1]!r}: expected 32 or 64')
  if I not in (32,64):
    sys.exit(f'unsupported lane width {I}: expected 32 or 64')
if len(sys.argv) > 2:
  vectorunit = sys.argv[2]
  if vectorunit not in ('neon','sse42','avx2'):
    sys.exit(f'unsupported vector unit {vectorunit!r}: expected neon, sse42, or avx2')

# XXX: further porting requires various tweaks below
# e.g. arm32 does not have vclt_s64 even with neon
# e.g. sse2 does not have blendv_epi8

vectorbits = {'neon':128,'sse2':128,'sse42':128,'avx2':256}[vectorunit]
vectorwords = vectorbits//I # lanes per vector register

# Tuning tables per lane count; consumed outside this part of the file.
# NOTE(review): the (lo,hi) pairs look like size ranges for unrolled
# fixed-size networks and the scalars like power-of-2 sizes -- confirm.
# XXX: do more searches for optimal parameters here
if vectorwords == 8:
  unrolledsort = (8,16),(16,32),(32,64),(64,128)
  unrolledsortxor = 64,128
  unrolledVsort = (8,16),(16,32),(32,64)
  unrolledVsortxor = 32,64,128
  threestages_specialize = 32,
  pads = (161,191),(193,255) # (449,511) is also interesting
  favorxor = True
elif vectorwords == 4:
  if I == 32:
    unrolledsort = (8,16),(16,32)
  else:
    unrolledsort = (8,16),(16,32),(32,64)
  unrolledsortxor = 32,64
  unrolledVsort = (8,16),(16,32),(32,64)
  unrolledVsortxor = 16,32,64
  threestages_specialize = 16,
  pads = (81,95),(97,127) # (225,255) is also interesting
  favorxor = True
elif vectorwords == 2:
  unrolledsort = (8,16),(16,32),
  unrolledsortxor = 32,
  unrolledVsort = (8,16),(16,32),(32,64)
  unrolledVsortxor = 16,32,64
  threestages_specialize = 16,
  pads = ()
  favorxor = True
else:
  # unreachable with the validated arguments above; fail loudly if the tables grow
  raise Exception(f'no tuning parameters for {vectorwords} lanes per vector')

if vectorunit != 'avx2':
  pads = () # for simplicity in the absence of speed study

# names of the generated power-of-2 sorts, parameterized by the smallest size
sort2_min = min(unrolledsortxor)
sort2 = f'sort_2poweratleast{sort2_min}'
Vsort2_min = min(unrolledVsortxor)
Vsort2 = f'V_sort_2poweratleast{Vsort2_min}'

allowmaskload = False # fast on Intel, and on AMD starting with Zen 2, _but_ AMD documentation claims that it can fail
allowmaskstore = False # fast on Intel but not on AMD
usepartialmem = True # irrelevant unless allowmaskload or allowmaskstore
partiallimit = 64 # XXX: could allow 128 for non-mask versions without extra operations

assert vectorwords&(vectorwords-1) == 0 # internal invariant: power-of-2 lane count
intI = f'int{I}' # scalar C type name, e.g. int32
intIvec = f'{intI}x{vectorwords}' # vector type name, e.g. int32x8
int8vec = f'int8x{(I*vectorwords)//8}' # same register viewed as bytes
int32vec = f'int32x{(I*vectorwords)//32}' # same register viewed as 32-bit lanes

def preamble():
  '''Print the C prologue for the configured lane width I and vectorunit.

  Emits the auto-generated banner, the intrinsics header, scalar {intI}
  helpers, and the vector type aliases and wrapper macros (load/store,
  MINMAX, lane permutations, byte/word extracts, broadcasts, blends) that
  the rest of the generated sorting code refers to by generic names.
  '''
  print('/* WARNING: auto-generated (by autogen/sort); do not edit */')
  print('')
  if vectorunit == 'neon':
    print('#include <arm_neon.h>')
  else:
    print('#include <immintrin.h>')
  print('')

  # scalar type alias, its largest value, and scalar min/MINMAX via crypto_{intI}.h
  print(fr'''#include "{intI}_sort.h"
#define {intI} {intI}_t
#define {intI}_largest {hex((1<<(I-1))-1)}

#include "crypto_{intI}.h"
#define {intI}_min crypto_{intI}_min
#define {intI}_MINMAX(a,b) crypto_{intI}_minmax(&(a),&(b))

#define NOINLINE __attribute__((noinline))
''')

  if vectorunit == 'neon':
    # ===== NEON: map generic names onto arm_neon.h types and intrinsics
    print(f'''#include "crypto_int8.h"
#define int8 crypto_int8
#define int8_min crypto_int8_min
#define int8x16 int8x16_t

#include "crypto_uint8.h"
#define uint8 crypto_uint8
#define uint8x16 uint8x16_t
''')
    # with 64-bit lanes the 32-bit helpers come from a separate header
    if I != 32:
      print(f'''#include "crypto_int32.h"
#define int32 crypto_int32
#define int32_min crypto_int32_min
#define int32x4 int32x4_t
''')

    print(f'#define {intIvec} {intIvec}_t')
    print(f'#define u{intIvec} u{intIvec}_t')

    # XXX: should also look for ways to use ld2 etc
    print(fr'''#define {intIvec}_load vld1q_s{I}
#define {intIvec}_store vst1q_s{I}
#define {intIvec}_ifthenelse vbslq_s{I}
''')

    # 32-bit lanes: MINMAX via vector min/max
    if I == 32:
      print(fr'''#define {intIvec}_smaller_umask vcltq_s{I}
#define {intIvec}_min vminq_s{I}
#define {intIvec}_max vmaxq_s{I}
#define {intIvec}_MINMAX(a,b) \
do {{ \
  {intIvec} c = {intIvec}_min(a,b); \
  b = {intIvec}_max(a,b); \
  a = c; \
}} while(0)
''')

    # other lane widths: MINMAX via compare mask + bit select
    else:
      print(fr'''#define {int32vec}_smaller_umask vcltq_s32
#define {intIvec}_smaller_umask vcltq_s{I}

#define {intIvec}_MINMAX(a,b) \
do {{ \
  u{intIvec} t = {intIvec}_smaller_umask(a,b); \
  {intIvec} c = {intIvec}_ifthenelse(t,a,b); \
  b = {intIvec}_ifthenelse(t,b,a); \
  a = c; \
}} while(0)
''')

    # XXX: tweak varextract name to reflect differences in out-of-range handling
    # XXX: also use tbx for infty etc
    # byte-level table lookup, reinterpret casts, and broadcasts used by partial loads/stores
    print(f'''#define {int8vec}_load vld1q_s8
#define {int8vec}_varextract vqtbl1q_s8
#define {int8vec}_add vaddq_s8
#define {int8vec}_sub vsubq_s8
#define {int8vec}_broadcast vdupq_n_s8
#define u{int8vec}_load vld1q_u8
#define u{int8vec}_add vaddq_u8
#define u{int8vec}_sub vsubq_u8
#define u{int8vec}_broadcast vdupq_n_u8
#define {int8vec}_from_{intIvec} vreinterpretq_s8_s{I}
#define u{intIvec}_from_{intIvec} vreinterpretq_u{I}_s{I}
#define {intIvec}_from_u{intIvec} vreinterpretq_s{I}_u{I}
#define {intIvec}_from_{int8vec} vreinterpretq_s{I}_s8
#define {int32vec}_load vld1q_s32
#define {int32vec}_add vaddq_s32
#define {int32vec}_sub vsubq_s32
#define {intIvec}_broadcast vdupq_n_s{I}

static inline u{int8vec} u{int8vec}_set(uint8 x0,uint8 x1,uint8 x2,uint8 x3,uint8 x4,uint8 x5,uint8 x6,uint8 x7,uint8 x8,uint8 x9,uint8 x10,uint8 x11,uint8 x12,uint8 x13,uint8 x14,uint8 x15)
{{
  uint8 x[16] = {{x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15}};
  return u{int8vec}_load(x);
}}
''')

    # lane-wise set helpers, built by loading from a small stack array
    if I == 32:
      print(f'''static inline {int32vec} {int32vec}_set(int32 x0,int32 x1,int32 x2,int32 x3)
{{
  int32 x[4] = {{x0,x1,x2,x3}};
  return {int32vec}_load(x);
}}
''')

    if I == 64:
      print(f'''static inline {intIvec} {intIvec}_set(int{I} x0,int{I} x1)
{{
  int{I} x[2] = {{x0,x1}};
  return {intIvec}_load(x);
}}
''')

    # constant lane permutations; names list the source lane for each output lane
    if I == 32:
      # XXX: maybe better to use vswp for 2301
      print(f'''#define {intIvec}_1032 vrev64q_s32
#define {intIvec}_2301(v) vextq_s32(v,v,2)
#define {intIvec}_3210(v) {intIvec}_1032({intIvec}_2301(v))
#define {intIvec}_a0b0a2b2 vtrn1q_s32
#define {intIvec}_a1b1a3b3 vtrn2q_s32
#define {intIvec}_a0b0a1b1 vzip1q_s32
#define {intIvec}_a2b2a3b3 vzip2q_s32
#define {intIvec}_leftleft(a,b) vreinterpretq_s32_s64(vzip1q_s64(vreinterpretq_s64_s32(a),vreinterpretq_s64_s32(b)))
#define {intIvec}_rightright(a,b) vreinterpretq_s32_s64(vzip2q_s64(vreinterpretq_s64_s32(a),vreinterpretq_s64_s32(b)))
''')
    else:
      print(f'''#define {intIvec}_10(v) vextq_s64(v,v,1)
#define {intIvec}_leftleft vzip1q_s64
#define {intIvec}_rightright vzip2q_s64
''')

  else:
    # ===== x86 (sse42 uses _mm_ intrinsics, avx2 uses _mm256_)
    mmprefix = {'sse2':'_mm_','sse42':'_mm_','avx2':'_mm256_'}[vectorunit]

    # vector type, unaligned load/store, signed compare (all-ones where a<b), add/sub
    print(fr'''typedef __m{vectorbits}i {intIvec};
#define {intIvec}_load(z) {mmprefix}loadu_si{vectorbits}((__m{vectorbits}i *) (z))
#define {intIvec}_store(z,i) {mmprefix}storeu_si{vectorbits}((__m{vectorbits}i *) (z),(i))

#define {intIvec}_smaller_mask(a,b) {mmprefix}cmpgt_epi{I}(b,a)
#define {intIvec}_add {mmprefix}add_epi{I}
#define {intIvec}_sub {mmprefix}sub_epi{I}''')

    # byte blend: takes t where the top bit of the corresponding byte of c is set
    print(f'#define {int8vec}_iftopthenelse(c,t,e) {mmprefix}blendv_epi8(e,t,c)')

    # combine 64-bit halves of two 128-bit vectors
    if vectorunit == 'sse42':
      print(fr'''#define {intIvec}_leftleft(a,b) {mmprefix}unpacklo_epi64(a,b)
#define {intIvec}_rightright(a,b) {mmprefix}unpackhi_epi64(a,b)''')
      if I == 32:
        print(fr'''#define {intIvec}_a0b0a1b1(a,b) {mmprefix}unpacklo_epi32(a,b)
#define {intIvec}_a2b2a3b3(a,b) {mmprefix}unpackhi_epi32(a,b)''')

    # combine 128-bit halves of two 256-bit vectors
    if vectorunit == 'avx2':
      print(fr'''#define {intIvec}_leftleft(a,b) {mmprefix}permute2x128_si256(a,b,0x20)
#define {intIvec}_rightright(a,b) {mmprefix}permute2x128_si256(a,b,0x31)''')

    # 32-bit lanes: MINMAX via native min/max (macros defined further below)
    if I == 32:
      print(fr'''
#define {intIvec}_MINMAX(a,b) \
do {{ \
  {intIvec} c = {intIvec}_min(a,b); \
  b = {intIvec}_max(a,b); \
  a = c; \
}} while(0)
''')

    # 64-bit lanes: MINMAX via compare + byte blend (no min_epi64 before avx-512)
    if I == 64:
      print(fr'''
#define {intIvec}_MINMAX(a,b) \
do {{ \
  {intIvec} t = {intIvec}_smaller_mask(a,b); \
  {intIvec} c = {int8vec}_iftopthenelse(t,a,b); \
  b = {int8vec}_iftopthenelse(t,b,a); \
  a = c; \
}} while(0)
''')

    if I == 32:
      print(fr'''#define {intIvec}_min {mmprefix}min_epi{I}
#define {intIvec}_max {mmprefix}max_epi{I}
#define {intIvec}_set {mmprefix}setr_epi{I}
#define {intIvec}_broadcast {mmprefix}set1_epi{I}''')

    # sse42 with 32-bit lanes: byte shuffles for partial loads/stores, constant permutations, blends
    if (I,vectorwords) == (32,4):
      print(fr'''#define {int8vec}_add {mmprefix}add_epi8
#define {int8vec}_sub {mmprefix}sub_epi8
#define {int8vec}_set {mmprefix}setr_epi8
#define {int8vec}_broadcast {mmprefix}set1_epi8
#define {int8vec}_varextract {mmprefix}shuffle_epi8
#define {int32vec}_add {mmprefix}add_epi32
#define {int32vec}_sub {mmprefix}sub_epi32
#define {int32vec}_set {mmprefix}setr_epi32
#define {int32vec}_broadcast {mmprefix}set1_epi32
#define {intIvec}_extract(v,p0,p1,p2,p3) {mmprefix}shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_ab0ab1ab2ab3(a,b,p0,p1,p2,p3) {mmprefix}castps_si128({mmprefix}blend_ps({mmprefix}castsi128_ps(a),{mmprefix}castsi128_ps(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3)))
#define {intIvec}_1032(v) {intIvec}_extract(v,1,0,3,2)
#define {intIvec}_2301(v) {intIvec}_extract(v,2,3,0,1)
#define {intIvec}_3210(v) {intIvec}_extract(v,3,2,1,0)

#include "crypto_int8.h"
#define int8_min crypto_int8_min
''')

    # sse42 with 64-bit lanes: 64-bit lane extracts are built from 32-bit shuffles
    if (I,vectorwords) == (64,2):
      print(fr'''#define {int8vec}_add {mmprefix}add_epi8
#define {int8vec}_sub {mmprefix}sub_epi8
#define {int8vec}_set {mmprefix}setr_epi8
#define {int8vec}_broadcast {mmprefix}set1_epi8
#define {int8vec}_varextract {mmprefix}shuffle_epi8
#define {int32vec}_add {mmprefix}add_epi32
#define {int32vec}_sub {mmprefix}sub_epi32
#define {int32vec}_set {mmprefix}setr_epi32
#define {int32vec}_broadcast {mmprefix}set1_epi32
#define {intIvec}_extract(v,p0,p1) {mmprefix}shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0)))
#define {intIvec}_set(a,b) {mmprefix}set_epi64x(b,a)
#define {intIvec}_broadcast {mmprefix}set1_epi64x
#define {intIvec}_10(v) {intIvec}_extract(v,1,0)

#include "crypto_int8.h"
#define int8_min crypto_int8_min
#include "crypto_int32.h"
#define int32_min crypto_int32_min
''')

    # avx2 with 32-bit lanes: cross-lane variable permute, per-128-bit-half shuffles, constant blends
    if (I,vectorwords) == (32,8):
      print(fr'''#define {intIvec}_varextract {mmprefix}permutevar8x32_epi32
#define {intIvec}_extract(v,p0,p1,p2,p3,p4,p5,p6,p7) {intIvec}_varextract(v,{mmprefix}setr_epi{I}(p0,p1,p2,p3,p4,p5,p6,p7))
#define {intIvec}_constextract_eachside(v,p0,p1,p2,p3) {mmprefix}shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_aabb_eachside(a,b,p0,p1,p2,p3) {mmprefix}castps_si256({mmprefix}shuffle_ps({mmprefix}castsi256_ps(a),{mmprefix}castsi256_ps(b),_MM_SHUFFLE(p3,p2,p1,p0)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,c4,c5,c6,c7,t,e) {mmprefix}blend_epi{I}(e,t,(c0)|((c1)<<1)|((c2)<<2)|((c3)<<3)|((c4)<<4)|((c5)<<5)|((c6)<<6)|((c7)<<7))
''')

    # avx2 with 64-bit lanes: each 64-bit lane is handled as a pair of 32-bit lanes where needed
    if (I,vectorwords) == (64,4):
      print(fr'''#define {int32vec}_add {mmprefix}add_epi32
#define {int32vec}_sub {mmprefix}sub_epi32
#define {int32vec}_set {mmprefix}setr_epi32
#define {int32vec}_broadcast {mmprefix}set1_epi32
#define {int32vec}_varextract {mmprefix}permutevar8x32_epi32
#define {intIvec}_set {mmprefix}setr_epi64x
#define {intIvec}_broadcast {mmprefix}set1_epi64x
#define {intIvec}_extract(v,p0,p1,p2,p3) {mmprefix}permute4x64_epi64(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_eachside(v,p0,p1) {mmprefix}shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0)))
#define {intIvec}_constextract_a01b01a23b23(a,b,p0,p1,p2,p3) {mmprefix}castpd_si256({mmprefix}shuffle_pd({mmprefix}castsi256_pd(a),{mmprefix}castsi256_pd(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,t,e) {mmprefix}blend_epi32(e,t,(c0)|((c0)<<1)|((c1)<<2)|((c1)<<3)|((c2)<<4)|((c2)<<5)|((c3)<<6)|((c3)<<7))
#define {intIvec}_1032(v) {intIvec}_extract(v,1,0,3,2)
#define {intIvec}_2301(v) {intIvec}_extract(v,2,3,0,1)
#define {intIvec}_3210(v) {intIvec}_extract(v,3,2,1,0)

#include "crypto_int32.h"
#define int32_min crypto_int32_min
''')

    # XXX: can skip some of the macros above if allowmaskload and allowmaskstore
    # Mask table: 64 entries of -1 then 64 of 0, with partialmem pointing at the
    # first 0. Loading a vector from &partialmem[k-n] gives -1 in lane i exactly
    # when i < n-k (this is how vectorsortingnetwork.partialmask uses it).
    if usepartialmem and (allowmaskload or allowmaskstore):
      print(fr'''#define partialmem (partialmem_storage+64)
static const {intI} partialmem_storage[] __attribute__((aligned(128))) = {{
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
}} ;

#define {intIvec}_partialload(p,z) {mmprefix}maskload_epi{I}((void *) (z),p)
#define {intIvec}_partialstore(p,z,i) {mmprefix}maskstore_epi{I}((void *) (z),p,(i))
''')

# ===== serial fixed-size networks

def smallsize():
  '''Print the C function {intI}_sort_3through7({intI} *x,long long n).

  The generated function sorts x[0..n-1] in place for 3 <= n <= 7. It loads
  the elements into scalars x0..x6 and applies a fixed comparator network of
  {intI}_MINMAX steps: a 4-element network when n >= 4, extended for n = 5,
  6, or 7, and a separate 3-element network otherwise.
  '''
  lines = []

  def emit(depth,text):
    # one C line, indented two spaces per nesting level
    lines.append('  '*depth+text)

  def loads(depth,indices):
    for i in indices:
      emit(depth,f'{intI} x{i} = x[{i}];')

  def stores(depth,indices):
    for i in indices:
      emit(depth,f'x[{i}] = x{i};')

  def network(depth,pairs):
    for a,b in pairs:
      emit(depth,f'{intI}_MINMAX(x{a},x{b});')

  emit(0,'NOINLINE')
  emit(0,f'static void {intI}_sort_3through7({intI} *x,long long n)')
  emit(0,'{')
  emit(1,'if (n >= 4) {')
  loads(2,range(4))
  network(2,[(0,1),(2,3),(0,2),(1,3),(1,2)])
  emit(2,'if (n >= 5) {')
  emit(3,'if (n == 5) {')
  loads(4,[4])
  network(4,[(0,4),(2,4),(1,2),(3,4)])
  stores(4,[4])
  emit(3,'} else {')
  loads(4,[4,5])
  network(4,[(4,5)])
  emit(4,'if (n == 6) {')
  network(5,[(0,4),(2,4),(1,5),(3,5)])
  emit(4,'} else {')
  loads(5,[6])
  network(5,[(4,6),(5,6),(0,4),(2,6),(2,4),(1,5),(3,5),(5,6)])
  stores(5,[6])
  emit(4,'}')
  network(4,[(1,2),(3,4)])
  stores(4,[4,5])
  emit(3,'}')
  emit(2,'}')
  stores(2,range(4))
  emit(1,'} else {')
  loads(2,range(3))
  network(2,[(0,1),(0,2),(1,2)])
  stores(2,range(3))
  emit(1,'}')
  emit(0,'}')

  # trailing newline inside the text plus print's own newline: a blank line follows
  print('\n'.join(lines)+'\n')

# ===== vectorized fixed-size networks

class vectorsortingnetwork:
  '''Builder for one straight-line C block implementing a vectorized sorting network.

  The network's nodes are spread over len(layout) vector registers of
  vectorwords lanes each. self.layout[r] lists the node held in each lane of
  arch register r, and self.phys[r] names the C variable currently holding
  that register. Methods append operations to self.operations; print()
  emits them, preceded by declarations of every variable used.

  Relies on the module-level configuration (I, vectorunit, vectorbits,
  vectorwords, intI, intIvec, int8vec, int32vec, allowmaskload,
  allowmaskstore, usepartialmem) and on the C macros emitted by preamble().
  '''

  class _NoMatch(Exception):
    '''Internal signal that a shuffle2 instruction pattern does not apply.'''

  def __init__(self,layout,xor=True):
    '''Start a network whose registers x0,x1,... initially hold the given layout.

    layout: one sequence of node indices per vector; each must have exactly
    vectorwords entries, and the number of vectors must be even.
    xor: if true, declare vecxor as a broadcast of the C variable xor, which
    load() and store() then XOR into every vector.
    '''
    layout = [list(x) for x in layout]
    for x in layout: assert len(x) == vectorwords
    assert len(layout)%2 == 0 # XXX: drop this restriction?
    self.layout = layout
    self.phys = [f'x{r}' for r in range(len(layout))]
    # arch register r is stored in self.phys[r]
    self.operations = []
    self.usedregs = set(self.phys)
    self.usedregssmallindex = set()
    self.usexor = xor
    self.useinfty = False
    if xor:
      self.assign('vecxor',f'{intIvec}_broadcast(xor)')

  def print(self):
    '''Print the recorded operations as a braced C block, declarations first.'''
    print('{')
    if len(self.usedregssmallindex) > 0:
      print(f'  int32_t {",".join(sorted(self.usedregssmallindex))};')
    if len(self.usedregs) > 0:
      print(f'  {intIvec} {",".join(sorted(self.usedregs))};')
    for op in self.operations:
      if op[0] == 'assign':
        phys,result,newlayout = op[1:]
        if newlayout is None:
          print(f'  {phys} = {result};')
        else:
          # annotate with the node held in each lane after the assignment
          print(f'  {phys} = {result}; // {" ".join(map(str,newlayout))}')
      elif op[0] == 'assignsmallindex':
        phys,result = op[1:]
        print(f'  {phys} = {result};')
      elif op[0] == 'store':
        result, = op[1:]
        print(f'  {result};')
      elif op[0] == 'comment':
        c, = op[1:]
        print(f'  // {c}')
      else:
        raise Exception(f'unrecognized operation {op}')
    print('}')
    print('')

  def allocate(self,r):
    '''Assign a new free physical register to arch register r. OK for caller to also use old register until calling allocate again.'''
    # toggles the variable name between x{r} and y{r}
    self.phys[r] = {'x':'y','y':'x'}[self.phys[r][0]]+self.phys[r][1:]
    # XXX: for generating low-level asm would instead want to cycle between fewer regs
    # but, for generating C (or qhasm), trusting register allocator makes generated assignments a bit more readable
    # (at least for V_sort)

  def comment(self,c):
    '''Record a C comment line.'''
    self.operations.append(('comment',c))

  def assign(self,phys,result,newlayout=None):
    '''Record phys = result for a vector variable; newlayout, if given, is printed as a lane annotation.'''
    self.usedregs.add(phys)
    self.operations.append(('assign',phys,result,newlayout))

  def assignsmallindex(self,phys,result):
    '''Record phys = result for an int32_t index variable.'''
    self.usedregssmallindex.add(phys)
    self.operations.append(('assignsmallindex',phys,result))

  def createinfty(self):
    '''Declare infty (a broadcast of {intI}_largest) once, on first use.'''
    if self.useinfty: return
    self.useinfty = True
    self.assign('infty',f'{intIvec}_broadcast({intI}_largest)')

  def partialmask(self,offset):
    '''Return a C expression for a mask that is all-ones in lane i iff offset+i < n.'''
    if usepartialmem:
      return f'{intIvec}_load(&partialmem[{offset}-n])'
    # XXX: also do version with offset in constants rather than n?
    rangevectorwords = ','.join(map(str,range(vectorwords)))
    return f'{intIvec}_smaller_mask({intIvec}_set({rangevectorwords}),{intIvec}_broadcast(n-{offset}))'

  def load(self,r,partial=False):
    '''Record a load of vector r from x+vectorwords*r (XORed with vecxor if enabled).

    If partial, only elements with index below n are valid; every other lane
    is replaced by infty, the largest {intI} value.
    '''
    rphys = self.phys[r]
    xor = 'vecxor^' if self.usexor else ''
    if not partial:
      self.assign(rphys,f'{xor}{intIvec}_load(x+{vectorwords*r})',self.layout[r])
      return
    # uint instead of int would slightly streamline infty usage
    self.createinfty()
    if allowmaskload:
      self.assign(f'partial{r}',f'{self.partialmask(r*vectorwords)}')
      self.assign(rphys,f'{xor}{intIvec}_partialload(partial{r},x+{vectorwords*r})',self.layout[r])
      self.assign(rphys,f'{int8vec}_iftopthenelse(partial{r},{rphys},infty)')
      return

    # Without mask loads: load the full vector that ends at pos{r} <= n (so the
    # read stays below n), then rotate lanes so lane i holds x[vectorwords*r+i].
    self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorwords},n)')
    xdata = f'{intIvec}_load(x+pos{r}-{vectorwords})'

    if vectorbits == 128:
      if vectorunit == 'neon':
        mplus = ','.join(map(str,range(0,vectorbits//8)))
        rotated = f'{intIvec}_from_{int8vec}({int8vec}_varextract({int8vec}_from_{intIvec}({xdata}),u{int8vec}_add(u{int8vec}_set({mplus}),u{int8vec}_broadcast({I//8}*((-pos{r})&{(vectorbits//I-1)})))))'
      else:
        mplus = ','.join(map(str,range(vectorbits//8,2*vectorbits//8)))
        rotated = f'{int8vec}_varextract({xdata},{int8vec}_sub({int8vec}_set({mplus}),{int8vec}_broadcast({I//8}*(pos{r}&{(vectorbits//I-1)}))))'
      # control lanes are true exactly where the element index is below pos{r}
      mplus = ','.join(map(str,range(r*vectorwords,(r+1)*vectorwords)))
      if vectorunit == 'neon':
        control = f'{intIvec}_smaller_umask({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))'
      else:
        control = f'{intIvec}_smaller_mask({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))'
    elif (I,vectorwords) == (32,8):
      # diff = index - pos: small, so negative (every byte's top bit set) exactly
      # where index < pos; it serves both as permute indices and as blend control
      mplus = ','.join(map(str,range(r*vectorwords,(r+1)*vectorwords)))
      self.assign(f'diff{r}',f'{intIvec}_sub({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))')
      rotated = f'{intIvec}_varextract({xdata},diff{r})'
      control = f'diff{r}'
    elif (I,vectorwords) == (64,4):
      # same idea at 32-bit granularity: each 64-bit lane is two 32-bit lanes
      mplus = ','.join(map(str,range(2*r*vectorwords,2*(r+1)*vectorwords)))
      self.assign(f'diff{r}',f'{int32vec}_sub({int32vec}_set({mplus}),{int32vec}_broadcast(2*pos{r}))')
      rotated = f'{int32vec}_varextract({xdata},diff{r})'
      control = f'diff{r}'
    else:
      raise Exception('unhandled partial load')

    if vectorunit == 'neon':
      self.assign(rphys,f'{intIvec}_ifthenelse({control},{rotated},infty)',self.layout[r])
    else:
      self.assign(rphys,f'{int8vec}_iftopthenelse({control},{rotated},infty)',self.layout[r])

  # warning: must do store in reverse order
  def store(self,r,partial=False):
    '''Record a store of vector r to x+vectorwords*r (XORed with vecxor if enabled).

    If partial, only elements with index below n are written. Without mask
    stores, this writes a full vector ending at pos{r}, which may overwrite
    the tail of vector r-1's range; so partial stores must be issued in
    reverse order of r.
    '''
    rphys = self.phys[r]
    xor = 'vecxor^' if self.usexor else ''
    if not partial:
      self.operations.append(('store',f'{intIvec}_store(x+{vectorwords*r},{xor}{rphys})'))
      return
    if allowmaskstore:
      if not allowmaskload:
        self.assign(f'partial{r}',f'{self.partialmask(r*vectorwords)}')
      self.operations.append(('store',f'{intIvec}_partialstore(partial{r},x+{vectorwords*r},{xor}{rphys})'))
      return
    if allowmaskload:
      # the mask-load path of load() did not define pos{r}
      self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorwords},n)')

    # this is why store has to be in reverse order
    if vectorbits == 128:
      if vectorunit == 'neon':
        mplus = ','.join(map(str,range(0,vectorbits//8)))
        storeshift = f'u{int8vec}_sub(u{int8vec}_set({mplus}),u{int8vec}_broadcast({I//8}*((-pos{r})&{(vectorbits//I-1)})))'
        xdata = f'{intIvec}_from_{int8vec}({int8vec}_varextract({int8vec}_from_{intIvec}({xor}{rphys}),{storeshift}))'
      else:
        mplus = ','.join(map(str,range(vectorbits//8,2*vectorbits//8)))
        storeshift = f'{int8vec}_add({int8vec}_set({mplus}),{int8vec}_broadcast({I//8}*(pos{r}&{(vectorbits//I-1)})))'
        xdata = f'{int8vec}_varextract({xor}{rphys},{storeshift})'
    elif (I,vectorwords) == (32,8):
      mplus = ','.join(map(str,range(vectorwords)))
      xdata = f'{intIvec}_varextract({xor}{rphys},{intIvec}_add({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r})))'
    elif (I,vectorwords) == (64,4):
      mplus = ','.join(map(str,range(2*vectorwords)))
      xdata = f'{int32vec}_varextract({xor}{rphys},{int32vec}_add({int32vec}_set({mplus}),{int32vec}_broadcast(2*pos{r})))'
    else:
      raise Exception('unhandled partial store')

    self.operations.append(('store',f'{intIvec}_store(x+pos{r}-{vectorwords},{xdata})'))

  def vecswap(self,r,s):
    '''Swap arch registers r and s by renaming only; emits no code.'''
    assert r != s
    self.phys[r],self.phys[s] = self.phys[s],self.phys[r]
    self.layout[r],self.layout[s] = self.layout[s],self.layout[r]

  def minmax(self,r,s):
    '''Record lane-wise compare-exchange: r receives the minima, s the maxima.'''
    assert r != s
    rphys = self.phys[r]
    sphys = self.phys[s]
    newlayout0 = [min(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
    newlayout1 = [max(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
    self.layout[r],self.layout[s] = newlayout0,newlayout1
    self.allocate(r)
    rphysnew = self.phys[r]
    if I == 32:
      self.assign(rphysnew,f'{intIvec}_min({rphys},{sphys})',newlayout0)
      self.assign(sphys,f'{intIvec}_max({rphys},{sphys})',newlayout1)
    else:
      if vectorunit == 'neon':
        self.assign('t',f'{intIvec}_from_u{intIvec}({intIvec}_smaller_umask({rphys},{sphys}))')
        self.assign(rphysnew,f'{intIvec}_ifthenelse(u{intIvec}_from_{intIvec}(t),{rphys},{sphys})',newlayout0)
        self.assign(sphys,f'{intIvec}_ifthenelse(u{intIvec}_from_{intIvec}(t),{sphys},{rphys})',newlayout1)
      else:
        # XXX: avx-512 has min_epi64 but avx2 does not
        self.assign('t',f'{intIvec}_smaller_mask({rphys},{sphys})')
        self.assign(rphysnew,f'{int8vec}_iftopthenelse(t,{rphys},{sphys})',newlayout0)
        self.assign(sphys,f'{int8vec}_iftopthenelse(t,{sphys},{rphys})',newlayout1)

  def shuffle1(self,r,L):
    r'''Rearrange layout of vector r to match L.'''
    L = list(L)
    oldL = self.layout[r]
    perm = [oldL.index(a) for a in L]
    rphys = self.phys[r]
    if vectorwords == 2 and perm == [1,0]:
      self.assign(rphys,f'{intIvec}_10({rphys})',L)
    elif vectorwords == 4 and perm == [1,0,3,2]:
      self.assign(rphys,f'{intIvec}_1032({rphys})',L)
    elif vectorwords == 4 and perm == [2,3,0,1]:
      self.assign(rphys,f'{intIvec}_2301({rphys})',L)
    elif vectorwords == 4 and perm == [3,2,1,0]:
      self.assign(rphys,f'{intIvec}_3210({rphys})',L)
    elif (I,vectorwords) == (32,8) and perm[4:] == [perm[0]+4,perm[1]+4,perm[2]+4,perm[3]+4]:
      self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
    elif (I,vectorwords) == (64,4) and perm[2:] == [perm[0]+2,perm[1]+2]:
      self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]})',L)
    elif vectorwords == 8 and vectorunit != 'neon':
      self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]},{perm[4]},{perm[5]},{perm[6]},{perm[7]})',L)
    elif vectorwords == 4 and vectorunit != 'neon':
      self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
    elif vectorwords == 2 and vectorunit != 'neon':
      self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]})',L)
    else:
      raise Exception(f'unhandled permutation from {oldL} to {L}')
    self.layout[r] = L

  @staticmethod
  def _require(condition):
    '''Abandon the current shuffle2 pattern unless condition holds.

    Used instead of assert so that pattern selection still works under
    python -O, where assert statements are removed.
    '''
    if not condition:
      raise vectorsortingnetwork._NoMatch()

  def shuffle2(self,r,s,L,M,exact=False):
    '''Move the nodes of vectors r and s into layouts L (vector r) and M (vector s).

    Candidate instruction patterns are tried in order and the first that
    applies is emitted. When exact is false, some patterns may instead leave
    a lane permutation of L and/or M; self.layout always records the layout
    actually produced.

    Raises Exception if no pattern applies.
    '''
    require = self._require
    # a pattern is rejected by a failed require() or by list.index() not
    # finding a node (ValueError); either way, move on to the next pattern
    nomatch = (self._NoMatch,ValueError)
    oldL = self.layout[r]
    oldM = self.layout[s]
    rphys = self.phys[r]
    sphys = self.phys[s]

    # (64,4): one constextract_a01b01a23b23 per output vector
    try:
      require((I,vectorwords) == (64,4))
      p0 = oldL[:2].index(L[0])
      p1 = oldM[:2].index(L[1])
      p2 = oldL[2:].index(L[2])
      p3 = oldM[2:].index(L[3])
      q0 = oldL[:2].index(M[0])
      q1 = oldM[:2].index(M[1])
      q2 = oldL[2:].index(M[2])
      q3 = oldM[2:].index(M[3])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
      self.assign(sphys,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
      return
    except nomatch:
      pass

    # (32,8): lanes a,a,b,b from the two inputs, same selection in both 128-bit halves
    try:
      require((I,vectorwords) == (32,8))
      p0 = oldL[:4].index(L[0])
      p1 = oldL[:4].index(L[1])
      p2 = oldM[:4].index(L[2])
      p3 = oldM[:4].index(L[3])
      require(p0 == oldL[4:].index(L[4]))
      require(p1 == oldL[4:].index(L[5]))
      require(p2 == oldM[4:].index(L[6]))
      require(p3 == oldM[4:].index(L[7]))
      q0 = oldL[:4].index(M[0])
      q1 = oldL[:4].index(M[1])
      q2 = oldM[:4].index(M[2])
      q3 = oldM[:4].index(M[3])
      require(q0 == oldL[4:].index(M[4]))
      require(q1 == oldL[4:].index(M[5]))
      require(q2 == oldM[4:].index(M[6]))
      require(q3 == oldM[4:].index(M[7]))
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
      self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
      return
    except nomatch:
      pass

    # (32,8), inexact: as above but lanes 1,2 (and 5,6) of each output come out swapped
    try:
      require(not exact)
      require((I,vectorwords) == (32,8))
      p0 = oldL[:4].index(L[0])
      p1 = oldL[:4].index(L[2])
      p2 = oldM[:4].index(L[1])
      p3 = oldM[:4].index(L[3])
      require(p0 == oldL[4:].index(L[4]))
      require(p1 == oldL[4:].index(L[6]))
      require(p2 == oldM[4:].index(L[5]))
      require(p3 == oldM[4:].index(L[7]))
      q0 = oldL[:4].index(M[0])
      q1 = oldL[:4].index(M[2])
      q2 = oldM[:4].index(M[1])
      q3 = oldM[:4].index(M[3])
      require(q0 == oldL[4:].index(M[4]))
      require(q1 == oldL[4:].index(M[6]))
      require(q2 == oldM[4:].index(M[5]))
      require(q3 == oldM[4:].index(M[7]))
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = [L[0],L[2],L[1],L[3],L[4],L[6],L[5],L[7]]
      self.layout[s] = [M[0],M[2],M[1],M[3],M[4],M[6],M[5],M[7]]
      self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',self.layout[r])
      self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',self.layout[s])
      return
    except nomatch:
      pass

    # exchange halves: L = low halves of both inputs, M = high halves
    try:
      half = vectorwords//2
      require(oldL[:half] == L[:half])
      require(oldL[half:] == M[:half])
      require(oldM[:half] == L[half:])
      require(oldM[half:] == M[half:])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
      self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
      return
    except nomatch:
      pass

    # 4 lanes: interleave low/high lane pairs, then swap halves
    # NOTE(review): preamble() emits _a0b0a1b1/_a2b2a3b3 only for I == 32
    # (neon and sse42); presumably this pattern never matches for (64,4) -- confirm
    try:
      require(vectorwords == 4)
      require(oldL == [L[2],L[0],M[2],M[0]])
      require(oldM == [L[3],L[1],M[3],M[1]])
      self.assign('t',f'{intIvec}_a0b0a1b1({rphys},{sphys})',[L[2],L[3],L[0],L[1]])
      self.assign('u',f'{intIvec}_a2b2a3b3({rphys},{sphys})',[M[2],M[3],M[0],M[1]])
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphys,f'{intIvec}_2301(t)',L)
      self.assign(sphys,f'{intIvec}_2301(u)',M)
      return
    except nomatch:
      pass

    # neon, 4 lanes: transpose 2x2 blocks (trn1/trn2)
    try:
      require(vectorunit == 'neon')
      require(vectorwords == 4)
      require(oldL == [L[0],M[0],L[2],M[2]])
      require(oldM == [L[1],M[1],L[3],M[3]])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_a0b0a2b2({rphys},{sphys})',L)
      self.assign(sphys,f'{intIvec}_a1b1a3b3({rphys},{sphys})',M)
      return
    except nomatch:
      pass

    # neon, 4 lanes: interleave (zip1/zip2)
    try:
      require(vectorunit == 'neon')
      require(vectorwords == 4)
      require(oldL == [L[0],L[2],M[0],M[2]])
      require(oldM == [L[1],L[3],M[1],M[3]])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_a0b0a1b1({rphys},{sphys})',L)
      self.assign(sphys,f'{intIvec}_a2b2a3b3({rphys},{sphys})',M)
      return
    except nomatch:
      pass

    # x86, 4 lanes, inexact: swap pairs in r, then two blends (r ends as a permutation of L)
    # NOTE(review): preamble() emits _constextract_ab0ab1ab2ab3 only for (32,4);
    # presumably this pattern never matches for (64,4) -- confirm
    try:
      require(vectorunit != 'neon')
      require(not exact)
      require(vectorwords == 4)
      # XXX: generalize this
      require(oldL == [L[0],M[0],L[2],M[2]])
      require(oldM == [L[1],M[1],L[3],M[3]])
      self.assign('t',f'{intIvec}_1032({rphys})',[M[0],L[0],M[2],L[2]])
      self.layout[r] = [L[1],L[0],L[3],L[2]]
      self.layout[s] = M
      self.assign(rphys,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},1,0,1,0)',self.layout[r])
      self.assign(sphys,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},0,1,0,1)',self.layout[s])
      return
    except nomatch:
      pass

    # x86, 4 lanes: swap pairs in both inputs, then two blends (same macro caveat as above)
    try:
      require(vectorunit != 'neon')
      require(vectorwords == 4)
      # XXX: generalize this
      require(oldL == [M[0],L[0],M[2],L[2]])
      require(oldM == [M[1],L[1],M[3],L[3]])
      self.assign('t',f'{intIvec}_1032({rphys})',[L[0],M[0],L[2],M[2]])
      # 1032 of oldM = [M[1],L[1],M[3],L[3]]
      self.assign('u',f'{intIvec}_1032({sphys})',[L[1],M[1],L[3],M[3]])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},0,1,0,1)',L)
      self.assign(sphys,f'{intIvec}_constextract_ab0ab1ab2ab3(u,{rphys},1,0,1,0)',M)
      return
    except nomatch:
      pass

    # 256-bit: per-half shuffle of each input plus a constant blend; with
    # bigflip, first target the half-exchanged layouts and then exchange halves
    for bigflip in False,True:
      try:
        require(vectorbits == 256)
        if bigflip:
          half = vectorwords//2
          z0 = L[:half]+M[:half]
          z1 = L[half:]+M[half:]
        else:
          z0 = L
          z1 = M
        blend = []
        shuf0 = [None]*vectorwords
        shuf1 = [None]*vectorwords
        for i in range(vectorwords):
          if oldL[i] == z0[i]:
            blend.append(0)
            shuf0[i] = oldL.index(z1[i])
            shuf1[i] = 0
            if i < vectorwords//2:
              require(shuf0[i] < vectorwords//2)
            else:
              require(shuf0[i] == shuf0[i-vectorwords//2]+vectorwords//2)
          else:
            blend.append(1)
            require(oldM[i] == z1[i])
            shuf0[i] = 0
            shuf1[i] = oldM.index(z0[i])
            if i < vectorwords//2:
              require(shuf1[i] < vectorwords//2)
            else:
              require(shuf1[i] == shuf1[i-vectorwords//2]+vectorwords//2)
        blend = ','.join(map(str,blend))
        s0 = ','.join(map(str,shuf0[:vectorwords//2]))
        s1 = ','.join(map(str,shuf1[:vectorwords//2]))
        # XXX: encapsulate temporaries better
        self.assign('u',f'{intIvec}_constextract_eachside({sphys},{s1})')
        self.assign('t',f'{intIvec}_constextract_eachside({rphys},{s0})')
        self.assign(rphys,f'{intIvec}_ifconstthenelse({blend},u,{rphys})',L)
        self.assign(sphys,f'{intIvec}_ifconstthenelse({blend},{sphys},t)',M)
        if bigflip:
          self.allocate(r)
          rphysnew = self.phys[r]
          self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
          self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
        self.layout[r] = L
        self.layout[s] = M
        return
      except nomatch:
        pass

    raise Exception(f'unhandled permutation from {oldL},{oldM} to {L},{M}')

  def rearrange_onestep(self,comparators):
    '''Make one layout change so that, for some even k, lane i of vector k faces its comparator partner in lane i of vector k+1.

    comparators maps each node to its partner (a node mapped to itself is
    not compared). Returns True if a change was recorded, False if every
    vector pair is already arranged.
    '''
    numvectors = len(self.layout)

    for k in range(0,numvectors,2):
      # pairs holding only uncompared nodes need no arrangement
      if all(comparators[a] == a for a in self.layout[k]):
        if all(comparators[a] == a for a in self.layout[k+1]):
          continue

      # partners live in another vector: bring that vector next to k
      collected = set(self.layout[k]+self.layout[k+1])
      if not all(comparators[a] in collected for a in collected):
        for j in range(numvectors):
          if all(comparators[a] in self.layout[j] for a in self.layout[k]):
            self.vecswap(j,k+1)
            return True

      # some partners sit inside vector k itself: redistribute across k and k+1
      if any(comparators[a] in self.layout[k] for a in self.layout[k]):
        newlayout0 = list(self.layout[k])
        newlayout1 = list(self.layout[k+1])
        for i in range(vectorwords):
          a = newlayout0[i]
          b = comparators[a]
          if b == newlayout1[i]: continue
          if b in newlayout0:
            j = newlayout0.index(b)
            assert j > i
            newlayout1[i],newlayout0[j] = newlayout0[j],newlayout1[i]
          else:
            j = newlayout1.index(b)
            assert j > i
            newlayout1[i],newlayout1[j] = newlayout1[j],newlayout1[i]
        self.shuffle2(k,k+1,newlayout0,newlayout1)
        return True

      # partners are all in k+1 but in the wrong lanes: permute k+1
      if [comparators[a] for a in self.layout[k]] != self.layout[k+1]:
        newlayout = [comparators[a] for a in self.layout[k]]
        self.shuffle1(k+1,newlayout)
        return True

    return False

  def rearrange(self,comparators):
    '''Apply rearrange_onestep until no further change; comparators is a mapping or iterable of (node,partner) pairs.'''
    comparators = dict(comparators)
    while self.rearrange_onestep(comparators): pass

def fixedsize(nlow,nhigh,xor=False,V=False):
  """Print a C function implementing a vectorized sorting network of size nhigh.

  nlow, nhigh: sizes handled, nlow <= n <= nhigh. Both must be multiples of
    vectorwords, with nhigh-nlow <= partiallimit and nhigh >= 2*vectorwords.
    When nlow < nhigh, the C function takes a runtime length `long long n`,
    and every vector starting at an index >= nlow is loaded and stored with
    partial=True.
  xor: append `_xor` to the name, add an `xor` argument, and build the
    network with vectorsortingnetwork(...,xor=True).
  V: name the function {intI}_V_sort_... . Start from a layout whose first
    half of vectors holds indices 0..nhigh/2-1 in reversed order, and emit
    only the stages with lgsubsort == lgn (the final merge).

  Raises AssertionError on invalid sizes, and Exception when
  nhigh < 2*vectorwords.
  """
  nlow = int(nlow)
  nhigh = int(nhigh)
  assert 0 <= nlow
  assert nlow <= nhigh
  assert nhigh-partiallimit <= nlow
  assert nhigh%vectorwords == 0
  assert nlow%vectorwords == 0

  # lgn: smallest integer >= 4 with 2**lgn >= nhigh
  lgn = 4
  while 2**lgn < nhigh: lgn += 1
  if nhigh < 2*vectorwords:
    raise Exception(f'unable to handle sizes below {2*vectorwords}')

  if V:
    funname = f'{intI}_V_sort_'
  else:
    funname = f'{intI}_sort_'
  funargs = f'{intI} *x'
  if nlow < nhigh:
    funname += f'{nlow}through'
    funargs += ',long long n'
  funname += f'{nhigh}'
  if xor:
    funname += '_xor'
    funargs += f',{intI} xor'
  print('NOINLINE')
  print(f'static void {funname}({funargs})')

  # ===== decide on initial layout of nodes
  # layout[k] lists, lane by lane, the element indices held in vector k

  numvectors = nhigh//vectorwords
  layout = {}
  if V:
    # first half: indices 0..nhigh/2-1 in reversed order; second half: in order
    for k in range(numvectors//2):
      layout[k] = list(reversed(range(vectorwords*(numvectors//2-1-k),vectorwords*(numvectors//2-k))))
    for k in range(numvectors//2,numvectors):
      layout[k] = list(range(vectorwords*k,vectorwords*(k+1)))
  else:
    # transposed: vector `offset` holds indices offset, offset+numvectors, ...
    for k in range(0,numvectors,nhigh//vectorwords):
      for offset in range(nhigh//vectorwords):
        layout[k+offset] = list(range(vectorwords*k+offset,vectorwords*(k+nhigh//vectorwords),nhigh//vectorwords))
  layout = [layout[k] for k in range(len(layout))]

  # ===== build network

  S = vectorsortingnetwork(layout,xor=xor)

  for k in range(numvectors):
    S.load(k,partial=k*vectorwords>=nlow)

  for lgsubsort in range(1,lgn+1):
    # V networks emit only the final merge
    if V and lgsubsort < lgn: continue
    for stage in reversed(range(lgsubsort)):
      # comparators[a] is the index compared with a in this stage. The
      # explicit cases cover small stages once nhigh is large enough.
      # Otherwise the generic rule applies: the first stage of a merge pairs
      # a with a^((2<<stage)-1), and later stages pair a with a^(1<<stage).
      if nhigh >= 2*vectorwords and (lgsubsort,stage) == (1,0):
        comparators = {a:a^1 for a in range(nhigh)}
      elif nhigh >= 4*vectorwords and (lgsubsort,stage) == (2,1):
        comparators = {a:a^2 for a in range(nhigh)}
      elif nhigh >= 4*vectorwords and (lgsubsort,stage) == (2,0):
        comparators = {a:a^(3*(1&((a>>0)^(a>>1)))) for a in range(nhigh)}
      elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,2):
        comparators = {a:a^4 for a in range(nhigh)}
      elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,1):
        comparators = {a:a+[0,0,2,2,-2,-2,0,0][a%8] for a in range(nhigh)}
      elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,0):
        comparators = {a:a+[0,1,-1,1,-1,1,-1,0][a%8] for a in range(nhigh)}
      elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,3):
        comparators = {a:a^8 for a in range(nhigh)}
      elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,2):
        comparators = {a:a+[0,0,0,0,4,4,4,4,-4,-4,-4,-4,0,0,0,0][a%16] for a in range(nhigh)}
      elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,1):
        comparators = {a:a+[0,0,2,2,-2,-2,2,2,-2,-2,2,2,-2,-2,0,0][a%16] for a in range(nhigh)}
      elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,0):
        comparators = {a:a+[0,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,0][a%16] for a in range(nhigh)}
      else:
        if stage == lgsubsort-1:
          stagemask = (2<<stage)-1
        else:
          stagemask = 1<<stage
        comparators = {a:a^stagemask for a in range(nhigh)}

      # 8-lane, non-V networks only: before the stage with 2<<stage == nhigh,
      # permute each even-numbered vector so lanes 2,3 and 4,5 trade places.
      # The comprehension uses its own index `w` so it does not shadow the
      # enclosing loop's `k`, which shuffle1 still needs.
      if vectorwords == 8 and 2<<stage == nhigh and not V:
        for k in range(0,numvectors,2):
          p = S.layout[k]
          newlayout = [p[w^((((w>>2)^(w>>1))&1)*6)] for w in range(vectorwords)] # XXX
          S.shuffle1(k,newlayout)

      strcomparators = ' '.join(f'{a}:{comparators[a]}' for a in range(nhigh) if comparators[a] > a)
      S.comment(f'stage ({lgsubsort},{stage}) {strcomparators}')

      S.rearrange(comparators)

      for k in range(0,numvectors,2):
        # skip vector pairs whose elements are all uncompared in this stage
        if all(comparators[a] == a for a in S.layout[k]):
          if all(comparators[a] == a for a in S.layout[k+1]):
            continue

        S.minmax(k,k+1)

  # move the vectors holding indices k*vectorwords and k*vectorwords+1 into slots k and k+1
  for k in range(0,numvectors,2):
    for offset in 0,1:
      for i in range(numvectors):
        if k*vectorwords+offset in S.layout[i]:
          if i != k+offset:
            S.vecswap(i,k+offset)
          break

  # shuffle each pair so that vectors k and k+1 hold consecutive indices in lane order
  for k in range(0,numvectors,2):
    y0 = list(range(k*vectorwords,(k+1)*vectorwords))
    y1 = list(range((k+1)*vectorwords,(k+2)*vectorwords))
    S.shuffle2(k,k+1,y0,y1,exact=True)

  # without mask stores, partial sizes are stored in reverse vector order
  # NOTE(review): the reason for the reversed order lives in store(); confirm there
  for k in range(numvectors) if allowmaskstore or nlow == nhigh else reversed(range(numvectors)):
    S.store(k,partial=k*vectorwords>=nlow)

  S.print()

# ===== V_sort

def threestages(k,down=False,p=None,atleast=None):
  """Print a C function applying three compare-exchange stages to k strided blocks.

  For each index i < n, the generated code loads x[i], x[p+i], ...,
  x[(k-1)*p+i]. It applies the compare-exchange pairs at distance 4, then 2,
  then 1 among those k values (pairs touching an index >= k are omitted),
  and stores the values back. A loop over vectorwords lanes runs first; a
  scalar loop mops up the remainder.

  k: number of blocks, 4..8.
  down: swap the operand order of every MINMAX (the name says _down, not _up).
  p: compile-time stride; the C function takes only x and sets n = p.
  atleast: used when p is None; the C function takes x and p, sets n = p,
    and the name carries `_atleast{atleast}`.
  With neither, the C function takes x, p and n.
  The scalar loop is omitted when p (or atleast) is a multiple of vectorwords.
  """
  assert k in (4,5,6,7,8)

  direction = 'down' if down else 'up'
  name = f'{intI}_threestages_{k}_{direction}'
  if p is not None:
    header = f'{name}_{p}({intI} *x)'
  elif atleast is not None:
    header = f'{name}_atleast{atleast}({intI} *x,long long p)'
  else:
    header = f'{name}({intI} *x,long long p,long long n)'

  print('NOINLINE')
  print(f'static void {header}')
  print('{')
  print('  long long i;')
  if p is not None:
    print(f'  long long p = {p};')
  if p is not None or atleast is not None:
    print('  long long n = p;')

  # A stride known to be a multiple of the vector width leaves no scalar tail.
  # The vector pass must come first: the scalar loop continues from its i.
  tail_free = ((p is not None and p%vectorwords == 0)
               or (atleast is not None and atleast%vectorwords == 0))
  passes = (True,) if tail_free else (True,False)

  exchanges = ((0,4),(1,5),(2,6),(3,7),
               (0,2),(1,3),(4,6),(5,7),
               (0,1),(2,3),(4,5),(6,7))
  offsets = ['i' if b == 0 else 'p+i' if b == 1 else f'{b}*p+i' for b in range(k)]

  for vectorized in passes:
    ty = intIvec if vectorized else intI

    if vectorized:
      print(f'  for (i = 0;i+{vectorwords} <= n;i += {vectorwords}) {{')
      for b,off in enumerate(offsets):
        print(f'    {ty} x{b} = {ty}_load(&x[{off}]);')
    else:
      print('  for (;i < n;++i) {')
      for b,off in enumerate(offsets):
        print(f'    {ty} x{b} = x[{off}];')

    for lo,hi in exchanges:
      if hi >= k:
        continue
      first,second = (hi,lo) if down else (lo,hi)
      print(f'    {ty}_MINMAX(x{first},x{second});')

    for b,off in enumerate(offsets):
      if vectorized:
        print(f'    {ty}_store(&x[{off}],x{b});')
      else:
        print(f'    x[{off}] = x{b};')

    print('  }')

  print('}')
  print('')

def V_sort():
  """Print the C code for the merge ("V") part of the sort.

  Emits, in order:
  * the unrolled V networks: fixedsize(nlow,nhigh,V=True) for each range in
    unrolledVsort, and fixedsize(n,n,xor=True,V=True) for each size in
    unrolledVsortxor;
  * the threestages kernels used by the code below: generic 8..4 up,
    constant-stride 8 up/down for each p in threestages_specialize, an
    8-down "atleast" variant, and 6-down;
  * {intI}_{Vsort2}_xor(x,n,xor): uses an unrolled network when n is in
    unrolledVsortxor, or, for n == 8*p, a specialized threestages kernel
    followed by eight unrolled networks. Otherwise it applies eight-block
    threestages with stride n>>3 and recurses on each eighth;
  * {intI}_V_sort(x,q,n): the general entry point; the generated C comments
    state its assumptions on q and n.

  Raises AssertionError if the configuration globals do not match what the
  generated dispatch relies on (see the asserts below).
  """
  # unrolled V networks: ranges (runtime n) and exact sizes (xor variants)
  for nlow,nhigh in unrolledVsort:
    fixedsize(nlow,nhigh,V=True)
  for n in unrolledVsortxor:
    fixedsize(n,n,xor=True,V=True)

  # generic kernels with runtime p and n, for 8 down to 4 blocks
  threestages(8)
  threestages(7)
  threestages(6)
  threestages(5)
  threestages(4)

  # constant-stride kernels. The specialized strides must be Vsort2_min,
  # 2*Vsort2_min, ... in order; threestages_min ends up doubled once per
  # specialization and names the "atleast" kernel used by the general path.
  threestages_min = Vsort2_min
  for p in threestages_specialize:
    assert p == threestages_min
    threestages_min *= 2
    threestages(8,p=p)
    threestages(8,down=True,p=p)
  threestages(8,down=True,atleast=threestages_min)

  threestages(6,down=True)

  print(f'''// XXX: currently xor must be 0 or -1
NOINLINE
static void {intI}_{Vsort2}_xor({intI} *x,long long n,{intI} xor)
{{''')

  for n in unrolledVsortxor:
    print(f'  if (n == {n}) {{ {intI}_V_sort_{n}_xor(x,xor); return; }}')

  assert unrolledVsortxor[:3] == (Vsort2_min,Vsort2_min*2,Vsort2_min*4)
  # so n is at least Vsort2_min*8, justifying the following recursive calls

  for p in threestages_specialize:
    assert p in unrolledVsortxor
    print(f'''  if (n == {p*8}) {{
    if (xor)
      {intI}_threestages_8_down_{p}(x);
    else
      {intI}_threestages_8_up_{p}(x);
    for (long long i = 0;i < 8;++i)
      {intI}_V_sort_{p}_xor(x+{p}*i,xor);
    return;
  }}''')

  print(f'''  if (xor)
    {intI}_threestages_8_down_atleast{threestages_min}(x,n>>3);
  else
    {intI}_threestages_8_up(x,n>>3,n>>3);
  for (long long i = 0;i < 8;++i)
    {intI}_{Vsort2}_xor(x+(n>>3)*i,n>>3,xor);
}}

/* q is power of 2; want only merge stages q,q/2,q/4,...,1 */
// XXX: assuming 8 <= q < n <= 2q; q is a power of 2
NOINLINE
static void {intI}_V_sort({intI} *x,long long q,long long n)
{{''')

  # size dispatch. With favorxor, the power-of-two shortcut is placed just
  # before the range whose upper end is Vsort2_min; otherwise it comes after
  # all unrolled ranges.
  assert any(nhigh == Vsort2_min for nlow,nhigh in unrolledVsort)
  for nlow,nhigh in unrolledVsort:
    if nhigh == Vsort2_min and favorxor:
      print(f'  if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}')
    print(f'  if (n <= {nhigh}) {{ {intI}_V_sort_{nlow}through{nhigh}(x,n); return; }}')

  if not favorxor:
    print(f'  if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}')

  # the generated code below relies on every n <= 64 being handled above
  assert any(nhigh == 64 for nlow,nhigh in unrolledVsort)

  print(f'''
  // 64 <= q < n < 2q
  q >>= 2;
  // 64 <= 4q < n < 8q

  if (7*q < n) {{
    {intI}_threestages_8_up(x,q,n-7*q);
    {intI}_threestages_7_up(x+n-7*q,q,8*q-n);
  }} else if (6*q < n) {{
    {intI}_threestages_7_up(x,q,n-6*q);
    {intI}_threestages_6_up(x+n-6*q,q,7*q-n);
  }} else if (5*q < n) {{
    {intI}_threestages_6_up(x,q,n-5*q);
    {intI}_threestages_5_up(x+n-5*q,q,6*q-n);
  }} else {{
    {intI}_threestages_5_up(x,q,n-4*q);
    {intI}_threestages_4_up(x+n-4*q,q,5*q-n);
  }}

  // now want to handle each batch of q entries separately

  {intI}_V_sort(x,q>>1,q);
  {intI}_V_sort(x+q,q>>1,q);
  {intI}_V_sort(x+2*q,q>>1,q);
  {intI}_V_sort(x+3*q,q>>1,q);
  x += 4*q;
  n -= 4*q;
  while (n >= q) {{
    {intI}_V_sort(x,q>>1,q);
    x += q;
    n -= q;
  }}

  // have n entries left in last batch, with 0 <= n < q
  if (n <= 1) return;
  while (q >= n) q >>= 1; // empty merge stage
  // now 1 <= q < n <= 2q
  if (q >= 8) {{ {intI}_V_sort(x,q,n); return; }}

  if (n == 8) {{
    {intI}_MINMAX(x[0],x[4]);
    {intI}_MINMAX(x[1],x[5]);
    {intI}_MINMAX(x[2],x[6]);
    {intI}_MINMAX(x[3],x[7]);
    {intI}_MINMAX(x[0],x[2]);
    {intI}_MINMAX(x[1],x[3]);
    {intI}_MINMAX(x[0],x[1]);
    {intI}_MINMAX(x[2],x[3]);
    {intI}_MINMAX(x[4],x[6]);
    {intI}_MINMAX(x[5],x[7]);
    {intI}_MINMAX(x[4],x[5]);
    {intI}_MINMAX(x[6],x[7]);
    return;
  }}
  if (4 <= n) {{
    for (long long i = 0;i < n-4;++i)
      {intI}_MINMAX(x[i],x[4+i]);
    {intI}_MINMAX(x[0],x[2]);
    {intI}_MINMAX(x[1],x[3]);
    {intI}_MINMAX(x[0],x[1]);
    {intI}_MINMAX(x[2],x[3]);
    n -= 4;
    x += 4;
  }}
  if (3 <= n)
    {intI}_MINMAX(x[0],x[2]);
  if (2 <= n)
    {intI}_MINMAX(x[0],x[1]);
}}
''')

# ===== main sort

def main_sort_prep():
  """Print the sorters that main_sort's generated code dispatches to.

  Runs smallsize() (defined elsewhere; presumably it prints the tiny-size
  sorter), then one fixedsize network per (nlow,nhigh) range in
  unrolledsort, then one xor-variant network per size in unrolledsortxor.
  """
  smallsize()
  for bounds in unrolledsort:
    lo,hi = bounds
    fixedsize(lo,hi)
  for size in unrolledsortxor:
    fixedsize(size,size,xor=True)

def main_sort():
  """Print the top-level C sorting code.

  Emits two C functions:
  * {intI}_{sort2}_xor(x,n,xor): uses an unrolled network when n is in
    unrolledsortxor. Otherwise it handles the two halves recursively (the
    first with ~xor, the second with xor) and merges them with
    {intI}_{Vsort2}_xor.
  * {intI}_sort(x,n), the public entry point. It dispatches on n: tiny
    sizes, the unrolled ranges in unrolledsort, powers of two via
    {intI}_{sort2}_xor, the padded ranges in pads, and finally a
    split/sort/merge strategy that ends in {intI}_V_sort.

  Raises AssertionError if the configuration globals do not match what the
  generated dispatch relies on.
  """
  print('// XXX: currently xor must be 0 or -1')
  print('NOINLINE')
  print(f'static void {intI}_{sort2}_xor({intI} *x,long long n,{intI} xor)')
  print('{')

  for n in unrolledsortxor:
    print(f'  if (n == {n}) {{ {intI}_sort_{n}_xor(x,xor); return; }}')

  # otherwise: recurse on both halves with complementary xor, then merge
  print(f'  {intI}_{sort2}_xor(x,n>>1,~xor);')
  print(f'  {intI}_{sort2}_xor(x+(n>>1),n>>1,xor);')
  print(f'  {intI}_{Vsort2}_xor(x,n,xor);')
  print('}')

  print(fr'''
void {intI}_sort({intI} *x,long long n)
{{ long long q;
  if (n <= 1) return;
  if (n == 2) {{ {intI}_MINMAX(x[0],x[1]); return; }}
  if (n <= 7) {{ {intI}_sort_3through7(x,n); return; }}''')
  # XXX: n cutoff here should be another variable to optimize

  nmin = 8
  # invariant: n in program is at least nmin

  assert any(nhigh == sort2_min for nlow,nhigh in unrolledsort)
  for nlow,nhigh in unrolledsort:
    if nhigh == sort2_min and favorxor:
      print(f'  if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}')
    if nlow <= nmin:
      # n >= nmin >= nlow already holds here, so only the upper bound is tested;
      # afterwards the program knows n > nhigh
      print(f'  if (n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
      while nlow <= nmin and nmin <= nhigh: nmin += 1
    else:
      print(f'  if ({nlow} <= n && n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')

  if not favorxor:
    print(f'  if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}')

  # qmin: smallest power of two with 2*qmin >= nmin; the starting q for the
  # generated "q < n - q" doubling loop
  qmin = 1
  while qmin < nmin-qmin: qmin += qmin
  assert sort2_min <= qmin

  # padded ranges: copy x into a buffer of padhighup entries, fill entries
  # from n onward with {intI}_largest, sort the whole buffer, copy n back
  for padlow,padhigh in pads:
    padlowdown = padlow
    while padlowdown%vectorwords: padlowdown -= 1
    padhighup = padhigh
    while padhighup%vectorwords: padhighup += 1

    print(fr'''  if ({padlow} <= n && n <= {padhigh}) {{
    {intI} buf[{padhighup}];
    for (long long i = {padlowdown};i < {padhighup};++i) buf[i] = {intI}_largest;
    for (long long i = 0;i < n;++i) buf[i] = x[i];
    {intI}_sort(buf,{padhighup});
    for (long long i = 0;i < n;++i) x[i] = buf[i];
    return;
  }}''')

  assert sort2_min%2 == 0
  assert sort2_min//2 >= Vsort2_min

  print(fr'''
  q = {qmin};
  while (q < n - q) q += q;
  // {qmin} <= q < n < 2q

  if ({sort2_min*16} <= n && n <= (7*q)>>2) {{
    long long m = (3*q)>>2; // strategy: sort m, sort n-m, merge
    long long r = q>>3; // at least {sort2_min} since q is at least {sort2_min*8}
    {intI}_{sort2}_xor(x,4*r,0);
    {intI}_{sort2}_xor(x+4*r,r,0);
    {intI}_{sort2}_xor(x+5*r,r,-1);
    {intI}_{Vsort2}_xor(x+4*r,2*r,-1);
    {intI}_threestages_6_down(x,r,r);
    for (long long i = 0;i < 6;++i)
      {intI}_{Vsort2}_xor(x+i*r,r,-1);
    {intI}_sort(x+m,n-m);
  }} else if ({sort2_min*2} <= q && n == (3*q)>>1) {{
    // strategy: sort q, sort q/2, merge
    long long r = q>>2; // at least {sort2_min//2} since q is at least {sort2_min*2}
    {intI}_{sort2}_xor(x,4*r,-1);
    {intI}_{sort2}_xor(x+4*r,2*r,0);
    {intI}_threestages_6_up(x,r,r);
    for (long long i = 0;i < 6;++i)
      {intI}_{Vsort2}_xor(x+i*r,r,0);
    return;
  }} else {{
    {intI}_{sort2}_xor(x,q,-1);
    {intI}_sort(x+q,n-q);
  }}

  {intI}_V_sort(x,q,n);
}}''')

# ===== driver

# Generate the whole C translation unit on stdout. The order matters: each
# helper is printed before the generated code that calls it.
for _emit_section in (preamble,main_sort_prep,V_sort,main_sort):
  _emit_section()
