#!/usr/bin/env python3

import os
import sys
import multiprocessing
import subprocess
import traceback
import resource
from datetime import datetime,timezone

# Compiler invocation prefix: $CC (default gcc) followed by $CFLAGS; later
# commands are built from it and split on whitespace.
compiler = ' '.join([os.getenv('CC','gcc'),os.getenv('CFLAGS','')])

# Each run writes into its own directory named by UTC timestamp and PID,
# so concurrent or repeated runs never collide.
top = '{}-pid{}'.format(datetime.now(timezone.utc).strftime('results-%Y%m%d-%H%M%S-UTC'),os.getpid())
os.makedirs(top)

# Repoint the convenience symlink 'results' at this run's directory.
try:
  os.remove('results')
except FileNotFoundError:
  pass  # no previous run: nothing to replace
os.symlink(top,'results')

# Number of worker processes: the CPUs this process may run on (falling back
# to the total CPU count where sched_getaffinity is unavailable), overridable
# through $THREADS, and never less than 1.
try:
  os_threads = len(os.sched_getaffinity(0))
except AttributeError:
  os_threads = multiprocessing.cpu_count()
# An empty/blank THREADS (e.g. `THREADS= ./script`) is treated as unset rather
# than crashing in int('').
os_threads = max(1,int(os.getenv('THREADS','').strip() or os_threads))

# Raise the soft open-file limit to the hard limit (the tuple becomes
# (hard, hard)); presumably because many workers and pipes are open at once.
resource.setrlimit(resource.RLIMIT_NOFILE,2*resource.getrlimit(resource.RLIMIT_NOFILE)[1:])

def cputime():
  """Return user CPU seconds consumed so far by this process plus its waited-for children."""
  return sum(resource.getrusage(who).ru_utime for who in (resource.RUSAGE_SELF,resource.RUSAGE_CHILDREN))

# ===== handle library in non-system locations

def find_library(name,searchpath):
  """Return the first existing '<dir>/<name>' or '<dir>/<name>.1' along searchpath.

  searchpath is a colon-separated directory list (LD_LIBRARY_PATH format).
  Directories are tried in order; within one directory the unversioned name
  is preferred over the '.1' name. Empty entries are skipped. If nothing
  matches, name is returned unchanged.
  """
  for libdir in searchpath.split(':'):
    if not libdir:
      continue  # an unset/empty variable would otherwise probe '/<name>'
    for fn in f'{libdir}/{name}',f'{libdir}/{name}.1':
      if os.path.exists(fn):
        # Return immediately: first match wins. (A plain `break` here would
        # only leave the inner loop and keep searching with a rewritten name.)
        return fn
  return name

lib = find_library('libdjbsort.so',os.getenv('LD_LIBRARY_PATH',''))

# ===== endianness

# Compile and run a tiny probe that reports the host byte order through its
# exit status: 12 = little-endian, 21 = big-endian, 99 = neither pattern.
os.makedirs(f'{top}/run',exist_ok=True)
binary = f'{top}/run/endianness'
with open(f'{binary}.c','w') as f:
  f.write('''#include <stdlib.h>
#include <inttypes.h>

int main()
{
  uint16_t a = 0x0201;
  uint32_t b = 0x04030201;
  uint64_t c = 0x0807060504030201ULL;
  char *A = (char *) &a;
  char *B = (char *) &b;
  char *C = (char *) &c;

  if (A[0] == 1 && A[1] == 2 &&
      B[0] == 1 && B[1] == 2 && B[2] == 3 && B[3] == 4 &&
      C[0] == 1 && C[1] == 2 && C[2] == 3 && C[3] == 4 &&
      C[4] == 5 && C[5] == 6 && C[6] == 7 && C[7] == 8)
    return 12;

  if (A[0] == 2 && A[1] == 1 &&
      B[0] == 4 && B[1] == 3 && B[2] == 2 && B[3] == 1 &&
      C[0] == 8 && C[1] == 7 && C[2] == 6 && C[3] == 5 &&
      C[4] == 4 && C[5] == 3 && C[6] == 2 && C[7] == 1)
    return 21;

  return 99;
}
''')

command = f'{compiler} -o {binary} {binary}.c'
subprocess.run(command.split(),check=True)
status = subprocess.run([binary]).returncode
# Anything other than 12/21 (99 = unrecognized byte order, negative = killed
# by a signal) is reported explicitly instead of as a bare KeyError.
try:
  endianness = {12:'littleendian',21:'bigendian'}[status]
except KeyError:
  raise RuntimeError(f'unsupported byte order: {binary} exited with status {status}') from None

# ===== figure out implementations supported by emulated CPU
# (not necessarily the same as the physical CPU; run under ./outsim)

def _list_impls(bits):
  """Return the "<implementation> <compiler>" lines djbsort reports for int{bits}.

  Writes and compiles {top}/run/int{bits}-ic, then runs it under ./outsim so
  the list reflects the emulated CPU rather than the physical one.
  """
  os.makedirs(f'{top}/run',exist_ok=True)
  exe = f'{top}/run/int{bits}-ic'
  source = fr'''#include <stdio.h>
#include <djbsort.h>

int main()
{{
  long long numimpl = djbsort_numimpl_int{bits}();
  for (long long i = 0;i < numimpl;++i)
    printf("%s %s\n",djbsort_dispatch_int{bits}_implementation(i),djbsort_dispatch_int{bits}_compiler(i));
  fflush(stdout);
  return 0;
}}
'''
  with open(f'{exe}.c','w') as src:
    src.write(source)

  subprocess.run(f'{compiler} -o {exe} {exe}.c -ldjbsort'.split(),check=True)

  listing = subprocess.run(f'./outsim {exe} {lib}'.split(),stdout=subprocess.PIPE,check=True,universal_newlines=True)
  return listing.stdout.splitlines()

impls = {bits:_list_impls(bits) for bits in (32,64)}

# ===== main binaries

def _build_sort_binary(bits):
  """Write and compile {top}/run/int{bits}, the program that sorts stdin with one implementation.

  The generated program takes argv: n, implementation index, expected
  implementation name, expected compiler string. It exits 111 if the names
  do not match the dispatch table or on short reads/writes; otherwise it
  reads n raw int{bits} values from stdin, sorts them with
  djbsort_dispatch_int{bits}(impl), and writes them back to stdout.
  """
  Ibytes = bits//8
  exe = f'{top}/run/int{bits}'
  source = fr'''#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <djbsort.h>

static int64_t ato64(const char *s)
{{
  int64_t result = 0;
  while (*s) result = result*10 + (*s++-'0');
  return result;
}}

int main(int argc,char **argv)
{{
  if (!argv[1]) exit(111);
  int64_t n = ato64(argv[1]);
  int64_t impl = ato64(argv[2]);
  char *implname = argv[3];
  char *implcompiler = argv[4];

  if (strcmp(implname,djbsort_dispatch_int{bits}_implementation(impl))) exit(111);
  if (strcmp(implcompiler,djbsort_dispatch_int{bits}_compiler(impl))) exit(111);

  int{bits}_t *x = (int{bits}_t *) malloc(n*{Ibytes});

  for (long long i = 0;i < n;++i)
    if (read(0,(char *) &x[i],{Ibytes}) != {Ibytes})
      exit(111);

  djbsort_dispatch_int{bits}(impl)(x,n);

  for (long long i = 0;i < n;++i)
    if (write(1,(char *) &x[i],{Ibytes}) != {Ibytes})
      exit(111);

  return 0;
}}
'''
  os.makedirs(f'{top}/run',exist_ok=True)
  with open(f'{exe}.c','w') as src:
    src.write(source)
  subprocess.run(f'{compiler} -o {exe} {exe}.c -ldjbsort'.split(),check=True)

for bits in 32,64:
  _build_sort_binary(bits)

# ===== main driver

def _record(category,bits,i,n,content):
  """Write content to {top}/{category}/int{bits}/impl{i}/{n}, creating parent directories."""
  directory = f'{top}/{category}/int{bits}/impl{i}'
  os.makedirs(directory,exist_ok=True)
  with open(f'{directory}/{n}','w') as f:
    f.write(content)

def doit(arg):
  """Run the unroll -> minmax -> decompose pipeline for one task.

  arg is (bits, n, i): integer width (32 or 64), array length, and index of
  the implementation in impls[bits]. Each stage receives the previous
  stage's stdout on stdin. Per stage, the CPU time is recorded under
  <op>time, nonempty stdout/stderr under <op>stdout/<op>stderr, and a
  failure (could not start, or nonzero exit) under <op>failed.

  Returns (bits, n, i, status): status is 'success', or '<op>failed' for the
  first stage that failed.
  """
  bits,n,i = arg
  binary = f'{top}/run/int{bits}'

  # impls lines are "<implementation> <compiler>"; the compiler part normally
  # contains spaces, so split only at the first one.
  implname,_,implcompiler = impls[bits][i].partition(' ')

  runstarttime = cputime()

  result = ''
  for op,command in (
    ('unroll',['./unroll',f'{bits}',f'{n}',f'{i}',implname,implcompiler,binary,endianness,lib]),
    ('minmax',['./minmax',f'{bits}',f'{n}']),
    ('decompose',['./decompose',f'{bits}',f'{n}']),
  ):
    taskstarttime = cputime()
    try:
      proc = subprocess.run(command,input=result,stdout=subprocess.PIPE,stderr=subprocess.PIPE,universal_newlines=True)
    except OSError:  # also covers FileNotFoundError (stage program missing)
      _record(f'{op}failed',bits,i,n,traceback.format_exc())
      return bits,n,i,f'{op}failed'
    out,err = proc.stdout,proc.stderr
    _record(f'{op}time',bits,i,n,f'{cputime()-taskstarttime:.6f}\n')
    if out:
      _record(f'{op}stdout',bits,i,n,out)
    if err:
      _record(f'{op}stderr',bits,i,n,err)
    if proc.returncode != 0:
      _record(f'{op}failed',bits,i,n,f'exit code {proc.returncode}\n')
      return bits,n,i,f'{op}failed'
    result = out  # feed this stage's output to the next stage

  _record('successtime',bits,i,n,f'{cputime()-runstarttime:.6f}\n')
  _record('success',bits,i,n,'')  # empty marker file
  return bits,n,i,'success'

# Save each implementation's "<name> <compiler>" line as impldata/int{bits}/impl{i},
# so result paths like .../impl{i}/... can be mapped back to an implementation.
for bits in (32,64):
  outdir = f'{top}/impldata/int{bits}'
  for i,line in enumerate(impls[bits]):
    os.makedirs(outdir,exist_ok=True)
    with open(f'{outdir}/impl{i}','w') as out:
      out.write(line+'\n')

def inputs():
  """Lazily yield each whitespace-separated token on stdin converted with int()."""
  for line in sys.stdin:
    yield from map(int,line.split())

def todo():
  """Yield (bits, n, i) task tuples.

  For each input size n, yields every int32 implementation index and then
  every int64 implementation index.
  """
  for n in inputs():
    for bits in 32,64:
      yield from ((bits,n,i) for i,_ in enumerate(impls[bits]))

numresults = 0
numsuccesses = 0
firstfailure = None

# Fan tasks out over os_threads worker processes. imap yields results in task
# order; each one is appended to the overview file and flushed immediately so
# progress is visible while the run is in flight.
with open(f'{top}/overview','w') as overview, multiprocessing.Pool(os_threads) as pool:
  for bits,n,i,result in pool.imap(doit,todo(),chunksize=1):
    where = f'{result}/int{bits}/impl{i}/{n}'
    numresults += 1
    if result == 'success':
      numsuccesses += 1
    elif firstfailure is None:
      firstfailure = where
    overview.write(f'{where} {impls[bits][i]}\n')
    overview.flush()

print(f'results: {numresults}')
print(f'successes: {numsuccesses}')
numfailures = numresults-numsuccesses
if numfailures:
  print(f'failures: {numfailures}; first is {firstfailure}')
