#!/usr/bin/env python3

# Code generator: emits C source for an AVX2-vectorized sorting routine.
# Usage: sort <bits>, where <bits> is the element width (32 or 64).

import sys

I = int(sys.argv[1])  # element width in bits
assert I in (32,64)

# Tuning tables, selected per element width:
#   unrolledsort          (nlow,nhigh) size ranges served by unrolled networks
#   unrolledsortxor       power-of-2 sizes with an unrolled _xor variant
#   unrolledVsort         same ranges, for the merge-only (V_sort) networks
#   unrolledVsortxor      power-of-2 sizes with an unrolled V_sort _xor variant
#   threestages_specialize  strides that get constant-stride threestages code
#   pads                  (low,high) n ranges handled by padding up to a
#                         vector-aligned size and recursing
#   favorxor              emit the power-of-2 (xor) fast-path test before the
#                         range tests in the generated dispatchers
# XXX: do more searches for optimal parameters here
if I == 32:
  unrolledsort = (8,16),(16,32),(32,64),(64,128)
  unrolledsortxor = 64,128
  unrolledVsort = (8,16),(16,32),(32,64)
  unrolledVsortxor = 32,64,128
  threestages_specialize = 32,
  pads = (161,191),(193,255) # (449,511) is also interesting
  favorxor = True
else:
  unrolledsort = (8,16),(16,32),(32,64)
  unrolledsortxor = 32,64
  unrolledVsort = (8,16),(16,32),(32,64)
  unrolledVsortxor = 16,32,64
  threestages_specialize = 16,
  pads = (81,95),(97,127) # (225,255) is also interesting
  favorxor = True

# names of the recursive power-of-2 helpers emitted later
sort2_min = min(unrolledsortxor)
sort2 = f'sort_2poweratleast{sort2_min}'
Vsort2_min = min(unrolledVsortxor)
Vsort2 = f'V_sort_2poweratleast{Vsort2_min}'

vectorlen = 256//I  # lanes per 256-bit AVX2 vector

allowmaskload = False # fast on Intel, and on AMD starting with Zen 2, _but_ AMD documentation claims that it can fail
allowmaskstore = False # fast on Intel but not on AMD
usepartialmem = True # irrelevant unless allowmaskload or allowmaskstore
partiallimit = 64 # XXX: could allow 128 for non-mask versions without extra operations

assert vectorlen&(vectorlen-1) == 0
# C identifiers used throughout the generated code
intI = f'int{I}'
intIvec = f'{intI}x{vectorlen}'
int8vec = f'int8x{(I*vectorlen)//8}'

def preamble():
  '''Emit the C file prologue: includes, scalar min/max macros, the AVX2
  vector type with load/store/compare/permute/blend wrappers, and
  (optionally) the mask table used for partial loads/stores.'''
  print(fr'''/* WARNING: auto-generated (by autogen/sort); do not edit */

#include <immintrin.h>

#include "{intI}_sort.h"
#define {intI} {intI}_t
#define {intI}_largest {hex((1<<(I-1))-1)}

#include "crypto_{intI}.h"
#define {intI}_min crypto_{intI}_min
#define {intI}_MINMAX(a,b) crypto_{intI}_minmax(&(a),&(b))

#define NOINLINE __attribute__((noinline))

typedef __m256i {intIvec};
#define {intIvec}_load(z) _mm256_loadu_si256((__m256i *) (z))
#define {intIvec}_store(z,i) _mm256_storeu_si256((__m256i *) (z),(i))

#define {intIvec}_smaller_mask(a,b) _mm256_cmpgt_epi{I}(b,a)
#define {intIvec}_add _mm256_add_epi{I}
#define {intIvec}_sub _mm256_sub_epi{I}
#define {int8vec}_iftopthenelse(c,t,e) _mm256_blendv_epi8(e,t,c)
#define {intIvec}_leftleft(a,b) _mm256_permute2x128_si256(a,b,0x20)
#define {intIvec}_rightright(a,b) _mm256_permute2x128_si256(a,b,0x31)
''')

  # 8-lane (32-bit element) case: MINMAX via native min/max intrinsics
  if vectorlen == 8:
    print(fr'''#define {intIvec}_MINMAX(a,b) \
do {{ \
  {intIvec} c = {intIvec}_min(a,b); \
  b = {intIvec}_max(a,b); \
  a = c; \
}} while(0)

#define {intIvec}_min _mm256_min_epi{I}
#define {intIvec}_max _mm256_max_epi{I}
#define {intIvec}_set _mm256_setr_epi{I}
#define {intIvec}_broadcast _mm256_set1_epi{I}
#define {intIvec}_varextract _mm256_permutevar8x32_epi32
#define {intIvec}_extract(v,p0,p1,p2,p3,p4,p5,p6,p7) {intIvec}_varextract(v,_mm256_setr_epi{I}(p0,p1,p2,p3,p4,p5,p6,p7))
#define {intIvec}_constextract_eachside(v,p0,p1,p2,p3) _mm256_shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_aabb_eachside(a,b,p0,p1,p2,p3) _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b),_MM_SHUFFLE(p3,p2,p1,p0)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,c4,c5,c6,c7,t,e) _mm256_blend_epi{I}(e,t,(c0)|((c1)<<1)|((c2)<<2)|((c3)<<3)|((c4)<<4)|((c5)<<5)|((c6)<<6)|((c7)<<7))
''')

  # 4-lane (64-bit element) case: AVX2 has no min/max epi64 (see minmax()
  # in vectorsortingnetwork), so MINMAX goes through compare + blend
  if vectorlen == 4:
    print(fr'''#define {intIvec}_MINMAX(a,b) \
do {{ \
  {intIvec} t = {intIvec}_smaller_mask(a,b); \
  {intIvec} c = {int8vec}_iftopthenelse(t,a,b); \
  b = {int8vec}_iftopthenelse(t,b,a); \
  a = c; \
}} while(0)

#define int32x8_add _mm256_add_epi32
#define int32x8_sub _mm256_sub_epi32
#define int32x8_set _mm256_setr_epi32
#define int32x8_broadcast _mm256_set1_epi32
#define int32x8_varextract _mm256_permutevar8x32_epi32
#define {intIvec}_set _mm256_setr_epi64x
#define {intIvec}_broadcast _mm256_set1_epi64x
#define {intIvec}_extract(v,p0,p1,p2,p3) _mm256_permute4x64_epi64(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_eachside(v,p0,p1) _mm256_shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0)))
#define {intIvec}_constextract_a01b01a23b23(a,b,p0,p1,p2,p3) _mm256_castpd_si256(_mm256_shuffle_pd(_mm256_castsi256_pd(a),_mm256_castsi256_pd(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,t,e) _mm256_blend_epi32(e,t,(c0)|((c0)<<1)|((c1)<<2)|((c1)<<3)|((c2)<<4)|((c2)<<5)|((c3)<<6)|((c3)<<7))

#include "crypto_int32.h"
#define int32_min crypto_int32_min
''')

  # XXX: can skip some of the macros above if allowmaskload and allowmaskstore
  # Sliding window of -1/0 lane masks, indexed by (offset-n); see
  # vectorsortingnetwork.partialmask()
  if usepartialmem and (allowmaskload or allowmaskstore):
    print(fr'''#define partialmem (partialmem_storage+64)
static const {intI} partialmem_storage[] __attribute__((aligned(128))) = {{
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
}} ;

#define {intIvec}_partialload(p,z) _mm256_maskload_epi{I}((void *) (z),p)
#define {intIvec}_partialstore(p,z,i) _mm256_maskstore_epi{I}((void *) (z),p,(i))
''')

# ===== serial fixed-size networks

def smallsize():
  '''Emit int{I}_sort_3through7: a branch-per-size scalar sorting network
  covering 3 <= n <= 7 elements.'''
  print(fr'''NOINLINE
static void {intI}_sort_3through7({intI} *x,long long n)
{{
  if (n >= 4) {{
    {intI} x0 = x[0];
    {intI} x1 = x[1];
    {intI} x2 = x[2];
    {intI} x3 = x[3];
    {intI}_MINMAX(x0,x1);
    {intI}_MINMAX(x2,x3);
    {intI}_MINMAX(x0,x2);
    {intI}_MINMAX(x1,x3);
    {intI}_MINMAX(x1,x2);
    if (n >= 5) {{
      if (n == 5) {{
        {intI} x4 = x[4];
        {intI}_MINMAX(x0,x4);
        {intI}_MINMAX(x2,x4);
        {intI}_MINMAX(x1,x2);
        {intI}_MINMAX(x3,x4);
        x[4] = x4;
      }} else {{
        {intI} x4 = x[4];
        {intI} x5 = x[5];
        {intI}_MINMAX(x4,x5);
        if (n == 6) {{
          {intI}_MINMAX(x0,x4);
          {intI}_MINMAX(x2,x4);
          {intI}_MINMAX(x1,x5);
          {intI}_MINMAX(x3,x5);
        }} else {{
          {intI} x6 = x[6];
          {intI}_MINMAX(x4,x6);
          {intI}_MINMAX(x5,x6);
          {intI}_MINMAX(x0,x4);
          {intI}_MINMAX(x2,x6);
          {intI}_MINMAX(x2,x4);
          {intI}_MINMAX(x1,x5);
          {intI}_MINMAX(x3,x5);
          {intI}_MINMAX(x5,x6);
          x[6] = x6;
        }}
        {intI}_MINMAX(x1,x2);
        {intI}_MINMAX(x3,x4);
        x[4] = x4;
        x[5] = x5;
      }}
    }}
    x[0] = x0;
    x[1] = x1;
    x[2] = x2;
    x[3] = x3;
  }} else {{
    {intI} x0 = x[0];
    {intI} x1 = x[1];
    {intI} x2 = x[2];
    {intI}_MINMAX(x0,x1);
    {intI}_MINMAX(x0,x2);
    {intI}_MINMAX(x1,x2);
    x[0] = x0;
    x[1] = x1;
    x[2] = x2;
  }}
}}
''')

# ===== vectorized fixed-size networks

class vectorsortingnetwork:
  '''Symbolic builder for a vectorized sorting network.

  For each architectural vector register r, self.layout[r] lists which
  array positions ("nodes") its lanes currently hold, and self.phys[r]
  names the C variable currently holding it.  Methods record C
  statements into self.operations; print() emits them as one C block.'''

  def __init__(self,layout,xor=True):
    # layout: one lane-assignment list per vector register
    layout = [list(x) for x in layout]
    for x in layout: assert len(x) == vectorlen
    assert len(layout)%2 == 0 # XXX: drop this restriction?
    self.layout = layout
    self.phys = [f'x{r}' for r in range(len(layout))]
    # arch register r is stored in self.phys[r]
    self.operations = []
    self.usedregs = set(self.phys)
    self.usedregssmallindex = set()
    self.usexor = xor
    self.useinfty = False
    if xor:
      self.assign('vecxor',f'{intIvec}_broadcast(xor)')

  def print(self):
    '''Emit the accumulated operations as one C block with declarations.'''
    print('{')
    if len(self.usedregssmallindex) > 0:
      print(f'  int32_t {",".join(sorted(self.usedregssmallindex))};')
    if len(self.usedregs) > 0:
      print(f'  {intIvec} {",".join(sorted(self.usedregs))};')
    for op in self.operations:
      if op[0] == 'assign':
        phys,result,newlayout = op[1:]
        if newlayout is None:
          print(f'  {phys} = {result};')
        else:
          # echo the resulting lane layout as a C comment
          print(f'  {phys} = {result}; // {" ".join(map(str,newlayout))}')
      elif op[0] == 'assignsmallindex':
        phys,result = op[1:]
        print(f'  {phys} = {result};')
      elif op[0] == 'store':
        result, = op[1:]
        print(f'  {result};')
      elif op[0] == 'comment':
        c, = op[1:]
        print(f'  // {c}')
      else:
        raise Exception(f'unrecognized operation {op}')
    print('}')
    print('')

  def allocate(self,r):
    '''Assign a new free physical register to arch register r. OK for caller to also use old register until calling allocate again.'''
    # alternate between the x<r> and y<r> names
    self.phys[r] = {'x':'y','y':'x'}[self.phys[r][0]]+self.phys[r][1:]
    # XXX: for generating low-level asm would instead want to cycle between fewer regs
    # but, for generating C (or qhasm), trusting register allocator makes generated assignments a bit more readable
    # (at least for V_sort)

  def comment(self,c):
    '''Record a C // comment.'''
    self.operations.append(('comment',c))

  def assign(self,phys,result,newlayout=None):
    '''Record "phys = result;". newlayout, if given, is echoed as a comment.'''
    self.usedregs.add(phys)
    self.operations.append(('assign',phys,result,newlayout))

  def assignsmallindex(self,phys,result):
    '''Record an int32_t scalar assignment (partial load/store indices).'''
    self.usedregssmallindex.add(phys)
    self.operations.append(('assignsmallindex',phys,result))

  def createinfty(self):
    '''Ensure a register "infty" broadcasting the largest element value
    (used as padding for lanes beyond n).'''
    if self.useinfty: return
    self.useinfty = True
    self.assign('infty',f'{intIvec}_broadcast({intI}_largest)')

  def partialmask(self,offset):
    '''Return a C expression for the lane mask of a partial vector at
    offset: lane i is all-ones iff offset+i < n.'''
    if usepartialmem:
      return f'{intIvec}_load(&partialmem[{offset}-n])'
    # XXX: also do version with offset in constants rather than n?
    rangevectorlen = ','.join(map(str,range(vectorlen)))
    return f'{intIvec}_smaller_mask({intIvec}_set({rangevectorlen}),{intIvec}_broadcast(n-{offset}))'

  def load(self,r,partial=False):
    '''Record a load of arch register r from x.  A partial load pads the
    lanes at or beyond n with infty, either via maskload or (when
    maskload is disallowed) via an unaligned load ending at position
    min((r+1)*vectorlen,n) followed by a rotate-and-blend.'''
    rphys = self.phys[r]
    xor = 'vecxor^' if self.usexor else ''
    # uint instead of int would slightly streamline infty usage
    if partial:
      self.createinfty()
      if allowmaskload:
        self.assign(f'partial{r}',f'{self.partialmask(r*vectorlen)}')
        self.assign(rphys,f'{xor}{intIvec}_partialload(partial{r},x+{vectorlen*r})',self.layout[r])
        self.assign(rphys,f'{int8vec}_iftopthenelse(partial{r},{rphys},infty)')
      else:
        if I == 32:
          mplus = ','.join(map(str,range(r*vectorlen,(r+1)*vectorlen)))
          self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
          self.assign(f'diff{r}',f'{intIvec}_sub({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))')
          self.assign(rphys,f'{intIvec}_varextract({intIvec}_load(x+pos{r}-8),diff{r})')
          self.assign(rphys,f'{int8vec}_iftopthenelse(diff{r},{rphys},infty)',self.layout[r])
        else:
          # 64-bit lanes: drive the permute with 32-bit indices, two per lane
          mplus = ','.join(map(str,range(2*r*vectorlen,2*(r+1)*vectorlen)))
          self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
          self.assign(f'diff{r}',f'int32x8_sub(int32x8_set({mplus}),int32x8_broadcast(2*pos{r}))')
          self.assign(rphys,f'int32x8_varextract({intIvec}_load(x+pos{r}-4),diff{r})')
          self.assign(rphys,f'{int8vec}_iftopthenelse(diff{r},{rphys},infty)',self.layout[r])
    else:
      self.assign(rphys,f'{xor}{intIvec}_load(x+{vectorlen*r})',self.layout[r])

  # warning: must do store in reverse order
  def store(self,r,partial=False):
    '''Record a store of arch register r to x.  Non-maskstore partial
    stores write an unaligned vector ending at min((r+1)*vectorlen,n),
    which overlaps the previous register's range -- hence the caller
    must store in reverse order (see warning above).'''
    rphys = self.phys[r]
    xor = 'vecxor^' if self.usexor else ''
    if partial:
      if allowmaskstore:
        if not allowmaskload:
          self.assign(f'partial{r}',f'{self.partialmask(r*vectorlen)}')
        self.operations.append(('store',f'{intIvec}_partialstore(partial{r},x+{vectorlen*r},{xor}{rphys})'))
      else:
        if allowmaskload:
          self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
        else:
          # this is why store has to be in reverse order
          if I == 32:
            mplus = ','.join(map(str,range(vectorlen))) # XXX: or could offset by minimum r to reuse a vector from before
            self.operations.append(('store',f'{intIvec}_store(x+pos{r}-8,{intIvec}_varextract({xor}{rphys},{intIvec}_add({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))))'))
          else:
            mplus = ','.join(map(str,range(2*vectorlen))) # XXX: or could offset by minimum r to reuse a vector from before
            self.operations.append(('store',f'{intIvec}_store(x+pos{r}-4,int32x8_varextract({xor}{rphys},int32x8_add(int32x8_set({mplus}),int32x8_broadcast(2*pos{r}))))'))
    else:
      self.operations.append(('store',f'{intIvec}_store(x+{vectorlen*r},{xor}{rphys})'))

  def vecswap(self,r,s):
    '''Swap arch registers r and s (pure bookkeeping; no C emitted).'''
    assert r != s
    self.phys[r],self.phys[s] = self.phys[s],self.phys[r]
    self.layout[r],self.layout[s] = self.layout[s],self.layout[r]

  def minmax(self,r,s):
    '''Record a lanewise compare/exchange: r gets the minima, s the maxima.'''
    assert r != s
    rphys = self.phys[r]
    sphys = self.phys[s]
    newlayout0 = [min(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
    newlayout1 = [max(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
    self.layout[r],self.layout[s] = newlayout0,newlayout1
    self.allocate(r)
    rphysnew = self.phys[r]
    if I == 32:
      self.assign(rphysnew,f'{intIvec}_min({rphys},{sphys})',newlayout0)
      self.assign(sphys,f'{intIvec}_max({rphys},{sphys})',newlayout1)
    else: # XXX: avx-512 has min_epi64 but avx2 does not
      self.assign('t',f'{intIvec}_smaller_mask({rphys},{sphys})')
      self.assign(rphysnew,f'{int8vec}_iftopthenelse(t,{rphys},{sphys})',newlayout0)
      self.assign(sphys,f'{int8vec}_iftopthenelse(t,{sphys},{rphys})',newlayout1)

  def shuffle1(self,r,L):
    r'''Rearrange layout of vector r to match L.'''
    L = list(L)
    oldL = self.layout[r]
    perm = [oldL.index(a) for a in L]
    rphys = self.phys[r]
    # prefer the cheaper per-128-bit-half shuffle when the permutation
    # is the same on both halves
    if vectorlen == 8 and perm[4:] == [perm[0]+4,perm[1]+4,perm[2]+4,perm[3]+4]:
      self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
    elif vectorlen == 4 and perm[2:] == [perm[0]+2,perm[1]+2]:
      self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]})',L)
    elif vectorlen == 8:
      self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]},{perm[4]},{perm[5]},{perm[6]},{perm[7]})',L)
    elif vectorlen == 4:
      self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
    else:
      raise Exception(f'unhandled permutation from {oldL} to {L}')
    self.layout[r] = L

  def shuffle2(self,r,s,L,M,exact=False):
    '''Jointly permute the lanes of registers r and s so their layouts
    become L and M.  Tries progressively more expensive strategies; a
    strategy signals "not applicable" via AssertionError (guard asserts)
    or ValueError (list.index miss), which falls through to the next.
    All raises occur before any operation is recorded, so a failed
    strategy has no side effects.  With exact=False the third strategy
    may deliver L,M with lanes 1 and 2 of each half swapped.'''
    oldL = self.layout[r]
    oldM = self.layout[s]
    rphys = self.phys[r]
    sphys = self.phys[s]

    # strategy 1 (4 lanes): one shuffle_pd-style pick per output register
    try:
      assert vectorlen == 4
      p0 = oldL[:2].index(L[0])
      p1 = oldM[:2].index(L[1])
      p2 = oldL[2:].index(L[2])
      p3 = oldM[2:].index(L[3])
      q0 = oldL[:2].index(M[0])
      q1 = oldM[:2].index(M[1])
      q2 = oldL[2:].index(M[2])
      q3 = oldM[2:].index(M[3])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
      self.assign(sphys,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
      return
    except (AssertionError, ValueError):
      pass

    # strategy 2 (8 lanes): shuffle_ps-style pick, exact target layout
    try:
      assert vectorlen == 8
      p0 = oldL[:4].index(L[0])
      p1 = oldL[:4].index(L[1])
      p2 = oldM[:4].index(L[2])
      p3 = oldM[:4].index(L[3])
      assert p0 == oldL[4:].index(L[4])
      assert p1 == oldL[4:].index(L[5])
      assert p2 == oldM[4:].index(L[6])
      assert p3 == oldM[4:].index(L[7])
      q0 = oldL[:4].index(M[0])
      q1 = oldL[:4].index(M[1])
      q2 = oldM[:4].index(M[2])
      q3 = oldM[:4].index(M[3])
      assert q0 == oldL[4:].index(M[4])
      assert q1 == oldL[4:].index(M[5])
      assert q2 == oldM[4:].index(M[6])
      assert q3 == oldM[4:].index(M[7])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
      self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
      return
    except (AssertionError, ValueError):
      pass

    # strategy 3 (8 lanes, exact=False): as strategy 2 but settle for
    # lanes 1 and 2 of each half being swapped relative to L,M
    try:
      assert not exact
      assert vectorlen == 8
      p0 = oldL[:4].index(L[0])
      p1 = oldL[:4].index(L[2])
      p2 = oldM[:4].index(L[1])
      p3 = oldM[:4].index(L[3])
      assert p0 == oldL[4:].index(L[4])
      assert p1 == oldL[4:].index(L[6])
      assert p2 == oldM[4:].index(L[5])
      assert p3 == oldM[4:].index(L[7])
      q0 = oldL[:4].index(M[0])
      q1 = oldL[:4].index(M[2])
      q2 = oldM[:4].index(M[1])
      q3 = oldM[:4].index(M[3])
      assert q0 == oldL[4:].index(M[4])
      assert q1 == oldL[4:].index(M[6])
      assert q2 == oldM[4:].index(M[5])
      assert q3 == oldM[4:].index(M[7])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = [L[0],L[2],L[1],L[3],L[4],L[6],L[5],L[7]]
      self.layout[s] = [M[0],M[2],M[1],M[3],M[4],M[6],M[5],M[7]]
      self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',self.layout[r])
      self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',self.layout[s])
      return
    except (AssertionError, ValueError):
      pass

    # strategy 4: pure 128-bit-half exchange (permute2x128)
    try:
      half = vectorlen//2
      assert oldL[:half] == L[:half]
      assert oldL[half:] == M[:half]
      assert oldM[:half] == L[half:]
      assert oldM[half:] == M[half:]
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
      self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
      return
    except (AssertionError, ValueError):
      pass

    # strategy 5: per-half shuffles plus constant blends, optionally
    # followed by a half exchange (bigflip)
    for bigflip in False,True:
      try:
        if bigflip:
          half = vectorlen//2
          z0 = L[:half]+M[:half]
          z1 = L[half:]+M[half:]
        else:
          z0 = L
          z1 = M
        blend = []
        shuf0 = [None]*vectorlen
        shuf1 = [None]*vectorlen
        for i in range(vectorlen):
          if oldL[i] == z0[i]:
            blend.append(0)
            shuf0[i] = oldL.index(z1[i])
            shuf1[i] = 0
            if i < vectorlen//2:
              assert shuf0[i] < vectorlen//2
            else:
              assert shuf0[i] == shuf0[i-vectorlen//2]+vectorlen//2
          else:
            blend.append(1)
            assert oldM[i] == z1[i]
            shuf0[i] = 0
            shuf1[i] = oldM.index(z0[i])
            if i < vectorlen//2:
              assert shuf1[i] < vectorlen//2
            else:
              assert shuf1[i] == shuf1[i-vectorlen//2]+vectorlen//2
        blend = ','.join(map(str,blend))
        s0 = ','.join(map(str,shuf0[:vectorlen//2]))
        s1 = ','.join(map(str,shuf1[:vectorlen//2]))
        # XXX: encapsulate temporaries better
        self.assign('u',f'{intIvec}_constextract_eachside({sphys},{s1})')
        self.assign('t',f'{intIvec}_constextract_eachside({rphys},{s0})')
        self.assign(rphys,f'{intIvec}_ifconstthenelse({blend},u,{rphys})',L)
        self.assign(sphys,f'{intIvec}_ifconstthenelse({blend},{sphys},t)',M)
        if bigflip:
          self.allocate(r)
          rphysnew = self.phys[r]
          self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
          self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
        self.layout[r] = L
        self.layout[s] = M
        return
      except (AssertionError, ValueError):
        pass

    raise Exception(f'unhandled permutation from {oldL},{oldM} to {L},{M}')

  def rearrange_onestep(self,comparators):
    '''Perform one register swap or lane shuffle that brings the layout
    closer to the goal: every comparator pairing register 2k lanewise
    with register 2k+1.  Returns True if something changed.'''
    numvectors = len(self.layout)

    for k in range(0,numvectors,2):
      # skip pairs entirely untouched by this stage
      if all(comparators[a] == a for a in self.layout[k]):
        if all(comparators[a] == a for a in self.layout[k+1]):
          continue

      # pull in the register holding the partners if they are elsewhere
      collected = set(self.layout[k]+self.layout[k+1])
      if not all(comparators[a] in collected for a in collected):
        for j in range(numvectors):
          if all(comparators[a] in self.layout[j] for a in self.layout[k]):
            self.vecswap(j,k+1)
            return True

      # some comparator is internal to register k: shuffle the pair so
      # each lane of k faces its partner in k+1
      if any(comparators[a] in self.layout[k] for a in self.layout[k]):
        newlayout0 = list(self.layout[k])
        newlayout1 = list(self.layout[k+1])
        for i in range(vectorlen):
          a = newlayout0[i]
          b = comparators[a]
          if b == newlayout1[i]: continue
          if b in newlayout0:
            j = newlayout0.index(b)
            assert j > i
            newlayout1[i],newlayout0[j] = newlayout0[j],newlayout1[i]
          else:
            j = newlayout1.index(b)
            assert j > i
            newlayout1[i],newlayout1[j] = newlayout1[j],newlayout1[i]
        self.shuffle2(k,k+1,newlayout0,newlayout1)
        return True

      # partners all in k+1 but misaligned: permute k+1 alone
      if [comparators[a] for a in self.layout[k]] != self.layout[k+1]:
        newlayout = [comparators[a] for a in self.layout[k]]
        self.shuffle1(k+1,newlayout)
        return True

    return False

  def rearrange(self,comparators):
    '''Repeat rearrange_onestep until no further change is needed.'''
    comparators = dict(comparators)
    while self.rearrange_onestep(comparators): pass

def fixedsize(nlow,nhigh,xor=False,V=False):
  '''Emit a C routine sorting n elements, nlow <= n <= nhigh, as an
  unrolled vectorized network on nhigh nodes; elements at or beyond n
  are padded with the largest value via partial loads.

  nlow == nhigh emits a fixed-size routine without the n parameter.
  xor: add an {intI} xor argument applied to the data on load and store.
  V: emit only the final merging phase (the V_sort variant) --
     NOTE(review): the reversed first-half layout and the
     lgsubsort == lgn filter below suggest this assumes the two halves
     arrive pre-sorted; confirm against the generated callers.
  '''
  nlow = int(nlow)
  nhigh = int(nhigh)
  assert 0 <= nlow
  assert nlow <= nhigh
  assert nhigh-partiallimit <= nlow
  assert nhigh%vectorlen == 0
  assert nlow%vectorlen == 0

  # lgn: number of sorting phases; smallest value >= 4 with 2**lgn >= nhigh
  lgn = 4
  while 2**lgn < nhigh: lgn += 1
  if nhigh < 2*vectorlen:
    raise Exception(f'unable to handle sizes below {2*vectorlen}')

  # build the generated function's name and argument list
  if V:
    funname = f'{intI}_V_sort_'
  else:
    funname = f'{intI}_sort_'
  funargs = f'{intI} *x'
  if nlow < nhigh:
    funname += f'{nlow}through'
    funargs += ',long long n'
  funname += f'{nhigh}'
  if xor:
    funname += '_xor'
    funargs += f',{intI} xor'
  print('NOINLINE')
  print(f'static void {funname}({funargs})')

  # ===== decide on initial layout of nodes

  numvectors = nhigh//vectorlen
  layout = {}
  if V:
    # first half: registers and lanes in reversed node order
    for k in range(numvectors//2):
      layout[k] = list(reversed(range(vectorlen*(numvectors//2-1-k),vectorlen*(numvectors//2-k))))
    for k in range(numvectors//2,numvectors):
      layout[k] = list(range(vectorlen*k,vectorlen*(k+1)))
  else:
    # interleaved layout: register holds nodes offset, offset+numvectors, ...
    for k in range(0,numvectors,nhigh//vectorlen):
      for offset in range(nhigh//vectorlen):
        layout[k+offset] = list(range(vectorlen*k+offset,vectorlen*(k+nhigh//vectorlen),nhigh//vectorlen))
  layout = [layout[k] for k in range(len(layout))]

  # ===== build network

  S = vectorsortingnetwork(layout,xor=xor)

  # registers whose nodes may lie at or beyond n need partial (padded) loads
  for k in range(numvectors):
    S.load(k,partial=k*vectorlen>=nlow)

  for lgsubsort in range(1,lgn+1):
    if V and lgsubsort < lgn: continue # V variant: final merge phase only
    for stage in reversed(range(lgsubsort)):
      # hand-picked comparator patterns for the small phases; the generic
      # xor-with-stagemask pattern otherwise
      if nhigh >= 16 and (lgsubsort,stage) == (1,0):
        comparators = {a:a^1 for a in range(nhigh)}
      elif nhigh >= 32 and (lgsubsort,stage) == (2,1):
        comparators = {a:a^2 for a in range(nhigh)}
      elif nhigh >= 32 and (lgsubsort,stage) == (2,0):
        comparators = {a:a^(3*(1&((a>>0)^(a>>1)))) for a in range(nhigh)}
      elif nhigh >= 64 and (lgsubsort,stage) == (3,2):
        comparators = {a:a^4 for a in range(nhigh)}
      elif nhigh >= 64 and (lgsubsort,stage) == (3,1):
        comparators = {a:a+[0,0,2,2,-2,-2,0,0][a%8] for a in range(nhigh)}
      elif nhigh >= 64 and (lgsubsort,stage) == (3,0):
        comparators = {a:a+[0,1,-1,1,-1,1,-1,0][a%8] for a in range(nhigh)}
      elif nhigh >= 128 and (lgsubsort,stage) == (4,3):
        comparators = {a:a^8 for a in range(nhigh)}
      elif nhigh >= 128 and (lgsubsort,stage) == (4,2):
        comparators = {a:a+[0,0,0,0,4,4,4,4,-4,-4,-4,-4,0,0,0,0][a%16] for a in range(nhigh)}
      elif nhigh >= 128 and (lgsubsort,stage) == (4,1):
        comparators = {a:a+[0,0,2,2,-2,-2,2,2,-2,-2,2,2,-2,-2,0,0][a%16] for a in range(nhigh)}
      elif nhigh >= 128 and (lgsubsort,stage) == (4,0):
        comparators = {a:a+[0,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,0][a%16] for a in range(nhigh)}
      else:
        if stage == lgsubsort-1:
          stagemask = (2<<stage)-1
        else:
          stagemask = 1<<stage
        comparators = {a:a^stagemask for a in range(nhigh)}

      # pre-shuffle before the full-length stage in the 8-lane case
      # NOTE(review): the comprehension's k shadows the loop's k; harmless
      # in Python 3 (comprehensions have their own scope) but confusing
      if vectorlen == 8 and 2<<stage == nhigh and not V:
        for k in range(0,numvectors,2):
          p = S.layout[k]
          newlayout = [p[k^((((k>>2)^(k>>1))&1)*6)] for k in range(vectorlen)] # XXX
          S.shuffle1(k,newlayout)

      strcomparators = ' '.join(f'{a}:{comparators[a]}' for a in range(nhigh) if comparators[a] > a)
      S.comment(f'stage ({lgsubsort},{stage}) {strcomparators}')

      # permute lanes so every comparator pairs register 2k with 2k+1 lanewise
      S.rearrange(comparators)

      for k in range(0,numvectors,2):
        if all(comparators[a] == a for a in S.layout[k]):
          if all(comparators[a] == a for a in S.layout[k+1]):
            continue

        S.minmax(k,k+1)

  # restore the natural register order: register k+offset should hold
  # node k*vectorlen+offset (and its successors)
  for k in range(0,numvectors,2):
    for offset in 0,1:
      for i in range(numvectors):
        if k*vectorlen+offset in S.layout[i]:
          if i != k+offset:
            S.vecswap(i,k+offset)
          break

  # restore the natural lane order within each register pair
  for k in range(0,numvectors,2):
    y0 = list(range(k*vectorlen,(k+1)*vectorlen))
    y1 = list(range((k+1)*vectorlen,(k+2)*vectorlen))
    S.shuffle2(k,k+1,y0,y1,exact=True)

  # non-maskstore partial stores must run in reverse order (see store())
  for k in range(numvectors) if allowmaskstore or nlow == nhigh else reversed(range(numvectors)):
    S.store(k,partial=k*vectorlen>=nlow)

  S.print()

# ===== V_sort

def threestages(k,down=False,p=None,atleast=None):
  '''Emit a C routine applying the comparator pattern of three merge
  stages (strides 4, 2, 1 over lanes) to k stride-p lanes of x.

  down: reverse each comparator (descending direction).
  p: specialize for a compile-time constant stride.
  atleast: stride is a runtime argument known to be a multiple of this
           value, so the scalar tail loop can be dropped when aligned.
  '''
  assert k in (4,5,6,7,8)

  direction = 'down' if down else 'up'
  print('NOINLINE')
  if p is not None:
    print(f'static void {intI}_threestages_{k}_{direction}_{p}({intI} *x)')
  elif atleast is not None:
    print(f'static void {intI}_threestages_{k}_{direction}_atleast{atleast}({intI} *x,long long p)')
  else:
    print(f'static void {intI}_threestages_{k}_{direction}({intI} *x,long long p,long long n)')

  print('{')
  print('  long long i;')
  if p is not None:
    print(f'  long long p = {p};')
  if p is not None or atleast is not None:
    print(f'  long long n = p;')

  def lane_addr(j):
    # C index expression for element i of the j-th stride-p lane
    if j == 0: return 'i'
    if j == 1: return 'p+i'
    return f'{j}*p+i'

  # comparator pairs of the three stages: stride 4, then 2, then 1
  network = (0,4),(1,5),(2,6),(3,7),(0,2),(1,3),(4,6),(5,7),(0,1),(2,3),(4,5),(6,7)

  for vector in True,False: # must be this order
    if not vector:
      # drop the scalar tail when the stride is known vector-aligned
      if p is not None and p%vectorlen == 0: break
      if atleast is not None and atleast%vectorlen == 0: break

    xtype = intIvec if vector else intI
    if vector:
      print(f'  for (i = 0;i+{vectorlen} <= n;i += {vectorlen}) {{')
    else:
      print('  for (;i < n;++i) {')

    for j in range(k):
      if vector:
        print(f'    {xtype} x{j} = {xtype}_load(&x[{lane_addr(j)}]);')
      else:
        print(f'    {xtype} x{j} = x[{lane_addr(j)}];')

    for lo,hi in network:
      if hi >= k: continue
      if down: lo,hi = hi,lo
      print(f'    {xtype}_MINMAX(x{lo},x{hi});')

    for j in range(k):
      if vector:
        print(f'    {xtype}_store(&x[{lane_addr(j)}],x{j});')
      else:
        print(f'    x[{lane_addr(j)}] = x{j};')

    print('  }')

  print('}')
  print('')

def V_sort():
  '''Emit the merging half of the generated code: unrolled fixed-size
  V_sort networks, the threestages helpers, the recursive power-of-2
  merger {intI}_{Vsort2}_xor, and the general-n {intI}_V_sort driver.'''
  for nlow,nhigh in unrolledVsort:
    fixedsize(nlow,nhigh,V=True)
  for n in unrolledVsortxor:
    fixedsize(n,n,xor=True,V=True)

  threestages(8)
  threestages(7)
  threestages(6)
  threestages(5)
  threestages(4)

  # constant-stride specializations; threestages_min doubles past each
  # specialized stride so the atleast variant covers the rest
  threestages_min = Vsort2_min
  for p in threestages_specialize:
    assert p == threestages_min
    threestages_min *= 2
    threestages(8,p=p)
    threestages(8,down=True,p=p)
  threestages(8,down=True,atleast=threestages_min)

  threestages(6,down=True)

  print(f'''// XXX: currently xor must be 0 or -1
NOINLINE
static void {intI}_{Vsort2}_xor({intI} *x,long long n,{intI} xor)
{{''')

  # dispatch to the unrolled networks for the small power-of-2 sizes
  for n in unrolledVsortxor:
    print(f'  if (n == {n}) {{ {intI}_V_sort_{n}_xor(x,xor); return; }}')

  assert unrolledVsortxor[:3] == (Vsort2_min,Vsort2_min*2,Vsort2_min*4)
  # so n is at least Vsort2_min*8, justifying the following recursive calls

  # specialized 8-way split for n == 8*p with constant p
  for p in threestages_specialize:
    assert p in unrolledVsortxor
    print(f'''  if (n == {p*8}) {{
    if (xor)
      {intI}_threestages_8_down_{p}(x);
    else
      {intI}_threestages_8_up_{p}(x);
    for (long long i = 0;i < 8;++i)
      {intI}_V_sort_{p}_xor(x+{p}*i,xor);
    return;
  }}''')

  # generic case: 8-way split via threestages, then recurse on each eighth
  print(f'''  if (xor)
    {intI}_threestages_8_down_atleast{threestages_min}(x,n>>3);
  else
    {intI}_threestages_8_up(x,n>>3,n>>3);
  for (long long i = 0;i < 8;++i)
    {intI}_{Vsort2}_xor(x+(n>>3)*i,n>>3,xor);
}}

/* q is power of 2; want only merge stages q,q/2,q/4,...,1 */
// XXX: assuming 8 <= q < n <= 2q; q is a power of 2
NOINLINE
static void {intI}_V_sort({intI} *x,long long q,long long n)
{{''')

  assert any(nhigh == Vsort2_min for nlow,nhigh in unrolledVsort)
  for nlow,nhigh in unrolledVsort:
    # favorxor: test the power-of-2 fast path before the range tests
    if nhigh == Vsort2_min and favorxor:
      print(f'  if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}''')
    print(f'  if (n <= {nhigh}) {{ {intI}_V_sort_{nlow}through{nhigh}(x,n); return; }}')

  if not favorxor:
    print(f'  if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}''')

  # general case: split into batches of q/4 via threestages (choosing the
  # lane count by where n falls between 4q/4 and 8q/4), recurse per batch,
  # then finish the partial last batch by hand below
  print(f'''
  // 64 <= q < n < 2q
  q >>= 2;
  // 64 <= 4q < n < 8q

  if (7*q < n) {{
    {intI}_threestages_8_up(x,q,n-7*q);
    {intI}_threestages_7_up(x+n-7*q,q,8*q-n);
  }} else if (6*q < n) {{
    {intI}_threestages_7_up(x,q,n-6*q);
    {intI}_threestages_6_up(x+n-6*q,q,7*q-n);
  }} else if (5*q < n) {{
    {intI}_threestages_6_up(x,q,n-5*q);
    {intI}_threestages_5_up(x+n-5*q,q,6*q-n);
  }} else {{
    {intI}_threestages_5_up(x,q,n-4*q);
    {intI}_threestages_4_up(x+n-4*q,q,5*q-n);
  }}

  // now want to handle each batch of q entries separately

  {intI}_V_sort(x,q>>1,q);
  {intI}_V_sort(x+q,q>>1,q);
  {intI}_V_sort(x+2*q,q>>1,q);
  {intI}_V_sort(x+3*q,q>>1,q);
  x += 4*q;
  n -= 4*q;
  while (n >= q) {{
    {intI}_V_sort(x,q>>1,q);
    x += q;
    n -= q;
  }}

  // have n entries left in last batch, with 0 <= n < q
  if (n <= 1) return;
  while (q >= n) q >>= 1; // empty merge stage
  // now 1 <= q < n <= 2q
  if (q >= 8) {{ {intI}_V_sort(x,q,n); return; }}

  if (n == 8) {{
    {intI}_MINMAX(x[0],x[4]);
    {intI}_MINMAX(x[1],x[5]);
    {intI}_MINMAX(x[2],x[6]);
    {intI}_MINMAX(x[3],x[7]);
    {intI}_MINMAX(x[0],x[2]);
    {intI}_MINMAX(x[1],x[3]);
    {intI}_MINMAX(x[0],x[1]);
    {intI}_MINMAX(x[2],x[3]);
    {intI}_MINMAX(x[4],x[6]);
    {intI}_MINMAX(x[5],x[7]);
    {intI}_MINMAX(x[4],x[5]);
    {intI}_MINMAX(x[6],x[7]);
    return;
  }}
  if (4 <= n) {{
    for (long long i = 0;i < n-4;++i)
      {intI}_MINMAX(x[i],x[4+i]);
    {intI}_MINMAX(x[0],x[2]);
    {intI}_MINMAX(x[1],x[3]);
    {intI}_MINMAX(x[0],x[1]);
    {intI}_MINMAX(x[2],x[3]);
    n -= 4;
    x += 4;
  }}
  if (3 <= n)
    {intI}_MINMAX(x[0],x[2]);
  if (2 <= n)
    {intI}_MINMAX(x[0],x[1]);
}}
''')

# ===== main sort

def main_sort_prep():
  '''Emit the fixed-size helper routines used by the main sort: the
  small-n network plus the unrolled plain and xor variants.'''
  smallsize()
  for bounds in unrolledsort:
    fixedsize(*bounds)
  for size in unrolledsortxor:
    fixedsize(size,size,xor=True)

def main_sort():
  '''Emit the recursive power-of-2 helper {intI}_{sort2}_xor and the
  public entry point {intI}_sort.'''
  print('// XXX: currently xor must be 0 or -1')
  print('NOINLINE')
  print(f'static void {intI}_{sort2}_xor({intI} *x,long long n,{intI} xor)')
  print('{')

  for n in unrolledsortxor:
    print(f'  if (n == {n}) {{ {intI}_sort_{n}_xor(x,xor); return; }}')

  # NOTE(review): the halves are sorted with opposite xor masks and then
  # combined by the V (merge) network -- presumably xor == -1 complements
  # the elements and hence reverses the order; confirm against the
  # generated load/store wrappers
  print(f'  {intI}_{sort2}_xor(x,n>>1,~xor);')
  print(f'  {intI}_{sort2}_xor(x+(n>>1),n>>1,xor);')
  print(f'  {intI}_{Vsort2}_xor(x,n,xor);')
  print('}')

  print(fr'''
void {intI}_sort({intI} *x,long long n)
{{ long long q;
  if (n <= 1) return;
  if (n == 2) {{ {intI}_MINMAX(x[0],x[1]); return; }}
  if (n <= 7) {{ {intI}_sort_3through7(x,n); return; }}''')
  # XXX: n cutoff here should be another variable to optimize

  nmin = 8
  # invariant: n in program is at least nmin

  assert any(nhigh == sort2_min for nlow,nhigh in unrolledsort)
  for nlow,nhigh in unrolledsort:
    # favorxor: test the power-of-2 fast path before the range tests
    if nhigh == sort2_min and favorxor:
      print(f'  if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}''')
    if nlow <= nmin:
      print(f'  if (n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
      while nlow <= nmin and nmin <= nhigh: nmin += 1  # leaves nmin = nhigh+1
    else:
      print(f'  if ({nlow} <= n && n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')

  if not favorxor:
    print(f'  if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}''')

  # qmin: smallest power of 2 with nmin-qmin <= qmin
  qmin = 1
  while qmin < nmin-qmin: qmin += qmin
  assert sort2_min <= qmin

  # sizes inside a pads range: copy into a vector-aligned buffer padded
  # with the largest value, sort the padded buffer, copy back
  for padlow,padhigh in pads:
    padlowdown = padlow
    while padlowdown%vectorlen: padlowdown -= 1
    padhighup = padhigh
    while padhighup%vectorlen: padhighup += 1

    print(fr'''  if ({padlow} <= n && n <= {padhigh}) {{
    {intI} buf[{padhighup}];
    for (long long i = {padlowdown};i < {padhighup};++i) buf[i] = {intI}_largest;
    for (long long i = 0;i < n;++i) buf[i] = x[i];
    {intI}_sort(buf,{padhighup});
    for (long long i = 0;i < n;++i) x[i] = buf[i];
    return;
  }}''')

  assert sort2_min%2 == 0
  assert sort2_min//2 >= Vsort2_min

  # general case: pick a split point, sort the two parts (recursively or
  # via the power-of-2 xor helper), then merge with {intI}_V_sort
  print(fr'''
  q = {qmin};
  while (q < n - q) q += q;
  // {qmin} <= q < n < 2q

  if ({sort2_min*16} <= n && n <= (7*q)>>2) {{
    long long m = (3*q)>>2; // strategy: sort m, sort n-m, merge
    long long r = q>>3; // at least {sort2_min} since q is at least {sort2_min*8}
    {intI}_{sort2}_xor(x,4*r,0);
    {intI}_{sort2}_xor(x+4*r,r,0);
    {intI}_{sort2}_xor(x+5*r,r,-1);
    {intI}_{Vsort2}_xor(x+4*r,2*r,-1);
    {intI}_threestages_6_down(x,r,r);
    for (long long i = 0;i < 6;++i)
      {intI}_{Vsort2}_xor(x+i*r,r,-1);
    {intI}_sort(x+m,n-m);
  }} else if ({sort2_min*2} <= q && n == (3*q)>>1) {{
    // strategy: sort q, sort q/2, merge
    long long r = q>>2; // at least {sort2_min//2} since q is at least {sort2_min*2}
    {intI}_{sort2}_xor(x,4*r,-1);
    {intI}_{sort2}_xor(x+4*r,2*r,0);
    {intI}_threestages_6_up(x,r,r);
    for (long long i = 0;i < 6;++i)
      {intI}_{Vsort2}_xor(x+i*r,r,0);
    return;
  }} else {{
    {intI}_{sort2}_xor(x,q,-1);
    {intI}_sort(x+q,n-q);
  }}

  {intI}_V_sort(x,q,n);
}}''')

# ===== driver

preamble()        # C prologue: includes, macros, vector wrappers
main_sort_prep()  # small-n network and unrolled fixed-size sorts
V_sort()          # merge networks and the V_sort driver
main_sort()       # recursive xor helper and public sort entry point
