-rwxr-xr-x 7003 djbsort-20180710/verif/minmax
#!/usr/bin/env python2
"""Simplify a straight-line bit-vector program into a Min/Max network.

Usage: minmax N < program

The program read from stdin is a sequence of single-assignment statements

    x = SLE(a,b)            1-bit comparison; treated here as a <= b
                            (signed, presumably, as in Z3's SLE)
    x = If(c,a,b)           c ? a : b, where c must be 1 bit wide and
                            a, b must have equal widths
    x = Extract(a,TOP,BOT)  bits TOP..BOT of a
    x = Concat(a,b,...)     concatenation, a supplying the high-order bits
    x = a                   plain copy

over the 32-bit inputs in_0_32 ... in_{N-1}_32, and must define the
outputs out0 ... out{N-1}.  The script builds a dataflow graph and
repeatedly

  * rewrites If(SLE(s,t),s,t) as Min(s,t) and If(SLE(t,s),s,t) as Max(s,t),
  * rewrites If(c,s[top:bot],t[top:bot]) as If(c,s,t)[top:bot],
  * merges sibling nodes that apply the same operation to the same operands
    (operand order ignored for Min/Max) and drops unused nodes,
  * replaces Concat(s[w-1:a], s[a-1:b], ..., s[k-1:0]) by s when those
    pieces rebuild all w bits of s,

until a full round changes nothing.  It then prints the simplified program
as "vK = Op(vI,...)" lines, each followed where needed by "outI = vK".

Only single-argument print(...) is used, so the script behaves the same
under Python 2 and also parses under Python 3.
"""
import sys
from pyparsing import (StringEnd, Literal, Word, ZeroOrMore, OneOrMore,
                       Optional, Forward, alphas, nums)

# Pseudo node id placed in a children set when the node feeds an output.
OUTPUT = -1
# Width in bits of every program input (they are named in_I_32).
INPUT_BITS = 32


def group(s):
    """Return a parse action that tags a matched assignment with name s.

    The action turns the tokens [dest, arg, ...] into the single token
    [s, dest, arg, ...]; a lone token is passed through unchanged.
    """
    def t(x):
        x = list(x)
        if len(x) == 1:
            return x
        return [[s] + x]
    return t


def make_grammar():
    """Return a pyparsing element that matches a whole program.

    Each parsed assignment becomes [op, dest, operand, ...] where op is one
    of 'SLE', 'If', 'Extract', 'Concat', 'copy'.  Extract's two bit
    positions stay as decimal strings.
    """
    lparen = Literal('(').suppress()
    rparen = Literal(')').suppress()
    comma = Literal(',').suppress()
    equal = Literal('=').suppress()
    number = Word(nums)
    name = Word(alphas, alphas + nums + "_")

    assignSLE = (name + equal + Literal('SLE').suppress() + lparen
                 + name + comma + name + rparen
                 ).setParseAction(group('SLE'))
    assignIf = (name + equal + Literal('If').suppress() + lparen
                + name + comma + name + comma + name + rparen
                ).setParseAction(group('If'))
    assignExtract = (name + equal + Literal('Extract').suppress() + lparen
                     + name + comma + number + comma + number + rparen
                     ).setParseAction(group('Extract'))
    assignConcat = (name + equal + Literal('Concat').suppress() + lparen
                    + name + ZeroOrMore(comma + name) + rparen
                    ).setParseAction(group('Concat'))
    assigncopy = (name + equal + name).setParseAction(group('copy'))

    # Order matters: the bare copy form must be tried last.
    assignment = (assignSLE | assignIf | assignExtract | assignConcat
                  | assigncopy)
    return ZeroOrMore(assignment) + StringEnd()


class Graph(object):
    """Dataflow graph of a parsed program.

    Nodes are positive integers handed out in creation order, so the N
    inputs are nodes 1..N.  For each node v:

      operation[v]  ['input', name], ['Extract', top, bot], or [opname]
                    with opname one of SLE, If, Concat, Min, Max
      parents[v]    operand node ids, in operand order
      bits[v]       result width in bits

    value maps each program variable name to the node holding its value.
    """

    def __init__(self, inputname):
        self.nextvalue = 0
        self.value = {}
        self.operation = {}
        self.parents = {}
        self.bits = {}
        for v in inputname:
            self.value[v] = self.new_node(['input', v], [], INPUT_BITS)

    def new_node(self, operation, parents, bits):
        """Create a node and return its id."""
        self.nextvalue += 1
        node = self.nextvalue
        self.operation[node] = operation
        self.parents[node] = parents
        self.bits[node] = bits
        return node

    # ------------------------------------------------------------------
    # Construction

    def add_assignment(self, p):
        """Add one parsed assignment [op, dest, operand, ...].

        Raises Exception if dest is already defined (including an input
        name) or op is unknown.  Raises AssertionError on inconsistent
        widths and KeyError if an operand names an undefined variable.
        """
        op, dest, args = p[0], p[1], p[2:]
        if dest in self.value:
            raise Exception('%s assigned twice' % dest)
        if op == 'copy':
            # A copy aliases the existing node; no new node is created.
            self.value[dest] = self.value[args[0]]
            return

        if op == 'SLE':
            operation = [op]
            parents = [self.value[arg] for arg in args]
            bits = 1
        elif op == 'If':
            operation = [op]
            parents = [self.value[arg] for arg in args]
            cond, then_node, else_node = parents
            assert self.bits[cond] == 1
            assert self.bits[then_node] == self.bits[else_node]
            bits = self.bits[then_node]
        elif op == 'Extract':
            source = self.value[args[0]]
            top = int(args[1])
            bot = int(args[2])
            assert top >= bot
            assert self.bits[source] > top
            assert bot >= 0
            operation = [op, top, bot]
            parents = [source]
            bits = top + 1 - bot
        elif op == 'Concat':
            operation = [op]
            parents = [self.value[arg] for arg in args]
            bits = sum(self.bits[x] for x in parents)
        else:
            raise Exception('unknown internal operation %s' % op)
        self.value[dest] = self.new_node(operation, parents, bits)

    # ------------------------------------------------------------------
    # Rewrite passes; each returns True if it changed the graph.

    def rewrite_minmax(self):
        """Rewrite If(SLE(s,t),s,t) as Min(s,t) and If(SLE(t,s),s,t) as Max(s,t)."""
        changed = False
        for v in self.operation:
            if self.operation[v][0] != 'If':
                continue
            cond, then_node, else_node = self.parents[v]
            if self.operation[cond][0] != 'SLE':
                continue
            lhs, rhs = self.parents[cond]
            # Min is checked first, so If(SLE(s,s),s,s) becomes Min.
            if (then_node, else_node) == (lhs, rhs):
                self.operation[v] = ['Min']
            elif (then_node, else_node) == (rhs, lhs):
                self.operation[v] = ['Max']
            else:
                continue
            # Fresh list: later passes edit parent lists in place, and the
            # SLE node must not share this one.
            self.parents[v] = [lhs, rhs]
            changed = True
        return changed

    def push_extract_through_if(self):
        """Rewrite If(c,s[top:bot],t[top:bot]) as If(c,s,t)[top:bot]."""
        candidates = []
        for z in self.operation:
            if self.operation[z][0] != 'If':
                continue
            cond, x, y = self.parents[z]
            # Comparing the whole operation lists also checks top/bot match.
            if (self.operation[x][0] == 'Extract'
                    and self.operation[y] == self.operation[x]):
                candidates.append(z)
        for z in candidates:
            cond, x, y = self.parents[z]
            s = self.parents[x][0]
            t = self.parents[y][0]
            assert self.bits[s] == self.bits[t]
            inner = self.new_node(['If'], [cond, s, t], self.bits[s])
            self.operation[z] = list(self.operation[x])
            self.parents[z] = [inner]
        return bool(candidates)

    def children_map(self, outputname):
        """Return node -> set of nodes using it (OUTPUT marks an output)."""
        children = dict((z, set()) for z in self.operation)
        for v in outputname:
            children[self.value[v]].add(OUTPUT)
        for z in self.operation:
            for x in self.parents[z]:
                children[x].add(z)
        return children

    def replace_uses(self, old, new, users, outputname):
        """Make every user of node old (outputs included) use node new."""
        for t in users:
            if t == OUTPUT:
                for v in outputname:
                    if self.value[v] == old:
                        self.value[v] = new
            else:
                operands = self.parents[t]
                for j, x in enumerate(operands):
                    if x == old:
                        operands[j] = new

    def _same_value(self, y, z):
        """Return True if nodes y and z provably compute the same value."""
        if self.operation[y] != self.operation[z]:
            return False
        if self.parents[y] == self.parents[z]:
            return True
        # Min and Max are commutative, so operand order is irrelevant.
        return (self.operation[y][0] in ('Min', 'Max')
                and set(self.parents[y]) == set(self.parents[z]))

    def merge_and_prune(self, outputname):
        """Merge duplicate sibling nodes, then delete unused nodes.

        Each node takes part in at most one merge per call; remaining
        duplicates are picked up on later rounds.
        """
        children = self.children_map(outputname)
        deleting = set(v for v in self.operation if not children[v])
        merging = set(deleting)
        merge = []
        for x in self.operation:
            c = list(children[x])
            for j in range(len(c)):
                for i in range(j):
                    y, z = c[i], c[j]
                    if y == OUTPUT or z == OUTPUT:
                        continue
                    if not self._same_value(y, z):
                        continue
                    assert self.bits[y] == self.bits[z]
                    if y in merging or z in merging:
                        continue
                    merge.append((y, z))
                    merging.add(y)
                    merging.add(z)
        for y, z in merge:
            # Keep y, drop z.
            self.replace_uses(z, y, children[z], outputname)
            deleting.add(z)
        for v in deleting:
            del self.operation[v]
            del self.parents[v]
            del self.bits[v]
        return bool(deleting)

    def _concat_source(self, z):
        """Return s if node z is Concat(s[w-1:a], ..., s[k-1:0]) with w == bits[s].

        Otherwise return None.
        """
        if self.operation[z][0] != 'Concat':
            return None
        pieces = self.parents[z]
        if not all(self.operation[y][0] == 'Extract' for y in pieces):
            return None
        source = self.parents[pieces[0]][0]
        if not all(self.parents[y][0] == source for y in pieces):
            return None
        # The pieces always rebuild the low bits[z] bits of source. Only a
        # full-width rebuild equals source; a narrower one is a truncation.
        if self.bits[source] != self.bits[z]:
            return None
        remaining = self.bits[z]
        for y in pieces:
            top, bot = self.operation[y][1:]
            if top != remaining - 1:
                return None
            remaining = bot
        return source

    def cancel_concat_extract(self, outputname):
        """Redirect users of a Concat that rebuilds its source to that source."""
        children = self.children_map(outputname)
        found = []
        for z in self.operation:
            source = self._concat_source(z)
            if source is not None:
                found.append((z, source))
        for z, source in found:
            # The Concat itself becomes unused and is pruned next round.
            self.replace_uses(z, source, children[z], outputname)
        return bool(found)

    def simplify(self, outputname):
        """Run every rewrite pass, in order, until a round changes nothing."""
        while True:
            # Build the list first so all four passes run every round.
            changed = [
                self.rewrite_minmax(),
                self.push_extract_through_if(),
                self.merge_and_prune(outputname),
                self.cancel_concat_extract(outputname),
            ]
            if not any(changed):
                return

    # ------------------------------------------------------------------
    # Output

    def format_node(self, v):
        """Return the output line defining node v."""
        op = self.operation[v]
        if op[0] == 'input':
            return 'v%d = %s' % (v, op[1])
        args = ['v%s' % x for x in self.parents[v]]
        args += ['%s' % x for x in op[1:]]
        return 'v%d = %s(%s)' % (v, op[0], ','.join(args))

    def _emit_node(self, root, done):
        """Print root and its not-yet-printed ancestors, operands first.

        Iterative post-order walk, so deep graphs cannot exhaust Python's
        recursion limit. Nodes are printed in the same order a recursive
        walk over parents[] would print them.
        """
        if root in done:
            return
        done.add(root)
        stack = [(root, iter(self.parents[root]))]
        while stack:
            node, pending = stack[-1]
            for x in pending:
                if x not in done:
                    done.add(x)
                    stack.append((x, iter(self.parents[x])))
                    break
            else:
                stack.pop()
                print(self.format_node(node))

    def emit(self, outputname):
        """Print every node the outputs need, then the output bindings."""
        done = set()
        for v in outputname:
            self._emit_node(self.value[v], done)
            print('%s = v%d' % (v, self.value[v]))


def main():
    """Read N from argv[1] and the program from stdin; print the result."""
    n = int(sys.argv[1])
    inputname = tuple('in_%d_32' % i for i in range(n))
    outputname = tuple('out%d' % i for i in range(n))
    program = list(make_grammar().parseString(sys.stdin.read()))
    graph = Graph(inputname)
    for p in program:
        graph.add_assignment(p)
    graph.simplify(outputname)
    graph.emit(outputname)


if __name__ == '__main__':
    main()