-rwxr-xr-x 12617 djbsort-20180729/verif/minmax
#!/usr/bin/env python2

# Simplify a straight-line bit-vector program into min/max form.
#
# Usage: minmax n < program
#
# stdin holds a sequence of assignments "name = op(args)" or "name = name"
# over the 32-bit inputs in_0_32 .. in_{n-1}_32.  The script builds a DAG of
# the computation, repeatedly applies local rewrite rules (turning If/compare
# patterns into signedmin/signedmax and cancelling inversions, extractions,
# etc.), removes dead nodes, merges duplicate nodes, and finally prints the
# simplified program computing out0 .. out{n-1}.
#
# Every print call takes a single parenthesized argument, so the output is the
# same under Python 2 (print statement) and Python 3 (print function).

import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums

n = int(sys.argv[1])

inputname = tuple('in_%d_32' % i for i in range(n))
outputname = tuple('out%d' % i for i in range(n))

def group(s):
  """Return a pyparsing parse action that tags a matched assignment with s.

  If only one token was matched it is returned unchanged; otherwise the
  tokens [lhs, arg, ...] are replaced by the single token [s, lhs, arg, ...].
  """
  def t(x):
    x = list(x)
    if len(x) == 1: return x
    return [[s] + x]
  return t

lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas,alphas+nums+"_")

# Grammar: one alternative per supported operation, tried in this order.
# Punctuation and the operation keyword are suppressed; the parse action
# turns each matched line into [op, lhs, arg1, arg2, ...].

# constant(bits,value)
assignment = (
  name + equal + Literal('constant').suppress()
  + lparen + number + comma + number + rparen
).setParseAction(group('constant'))

assignment |= (
  name + equal + Literal('invert').suppress()
  + lparen + name + rparen
).setParseAction(group('invert'))

assignment |= (
  name + equal + Literal('Reverse').suppress()
  + lparen + name + rparen
).setParseAction(group('Reverse'))

for binary in ['xor','or','and','add','sub','mul','signedrshift','signedle','signedlt','equal']:
  assignment |= (
    name + equal + Literal(binary).suppress()
    + lparen + name + comma + name + rparen
  ).setParseAction(group(binary))

# If(condition,then,else)
assignment |= (
  name + equal + Literal('If').suppress()
  + lparen + name + comma + name + comma + name + rparen
).setParseAction(group('If'))

# Extract(source,top,bot): bits top..bot inclusive
assignment |= (
  name + equal + Literal('Extract').suppress()
  + lparen + name + comma + number + comma + number + rparen
).setParseAction(group('Extract'))

assignment |= (
  name + equal + Literal('Concat').suppress()
  + lparen + name + ZeroOrMore(comma + name) + rparen
).setParseAction(group('Concat'))

# plain copy "a = b"; must stay last so keyword forms are tried first
assignment |= (
  name + equal + name).setParseAction(group('copy')
)

assignments = ZeroOrMore(assignment) + StringEnd()

program = sys.stdin.read()
program = assignments.parseString(program)
program = list(program)

nextvalue = 0

# indexed by variable name: the DAG node currently holding that variable
value = dict()

# indexed by value (DAG node number):
operation = dict()  # [op] or [op, extra integer args...]
parents = dict()    # list of operand node numbers, in argument order
bits = dict()       # width of the node's result in bits

for v in inputname:
  nextvalue += 1
  value[v] = nextvalue
  operation[nextvalue] = ['input',v]
  parents[nextvalue] = []
  bits[nextvalue] = 32

# Build the DAG, checking operand widths (these are asserts, so they are
# skipped under python -O).
for p in program:
  if p[1] in value:
    raise Exception('%s assigned twice' % p[1])
  if p[0] == 'copy':
    # aliases the existing node; no new node is created
    value[p[1]] = value[p[2]]
    continue
  nextvalue += 1
  operation[nextvalue] = [p[0]]
  if p[0] == 'constant':
    parents[nextvalue] = []
    operation[nextvalue] += [int(p[2]),int(p[3])]
    bits[nextvalue] = int(p[2])
  elif p[0] in ['signedle','signedlt','equal']:
    # binary operation producing bit
    assert bits[value[p[2]]] == bits[value[p[3]]]
    parents[nextvalue] = [value[v] for v in p[2:]]
    bits[nextvalue] = 1
  elif p[0] in ['invert','Reverse']:
    # unary size-preserving operation
    parents[nextvalue] = [value[v] for v in p[2:]]
    bits[nextvalue] = bits[value[p[2]]]
  elif p[0] in ['xor','or','and','add','sub','mul','signedrshift']:
    # binary size-preserving operation
    assert bits[value[p[2]]] == bits[value[p[3]]]
    parents[nextvalue] = [value[v] for v in p[2:]]
    bits[nextvalue] = bits[value[p[2]]]
  elif p[0] == 'If':
    assert bits[value[p[2]]] == 1
    assert bits[value[p[3]]] == bits[value[p[4]]]
    parents[nextvalue] = [value[v] for v in p[2:]]
    bits[nextvalue] = bits[value[p[3]]]
  elif p[0] == 'Extract':
    top = int(p[3])
    bot = int(p[4])
    assert top >= bot
    assert bits[value[p[2]]] > top
    assert bot >= 0
    operation[nextvalue] += [top,bot]
    parents[nextvalue] = [value[p[2]]]
    bits[nextvalue] = top + 1 - bot
  elif p[0] == 'Concat':
    # first argument is the most significant part (see the Extract/Concat
    # rewrite below, which walks the arguments from the last one upward)
    parents[nextvalue] = [value[v] for v in p[2:]]
    bits[nextvalue] = sum(bits[v] for v in parents[nextvalue])
  else:
    raise Exception('unknown internal operation %s' % p[0])
  value[p[1]] = nextvalue

# swapping table used by the invert/min/max rewrite
flip = {'signedmax':'signedmin','signedmin':'signedmax'}

# Apply rewrites until a full pass changes nothing.  Rewrites replace a node's
# operation/parents in place (node numbers stay stable).  New nodes are
# collected in lists first and only added after the scan, so the dicts do not
# grow while being iterated.
progress = True
while progress:
  progress = False

  # rewrite invert(invert(c)) as c
  # rewrite Reverse(Reverse(c)) as c
  for v in operation:
    if operation[v][0] in ['invert','Reverse']:
      x = parents[v][0]
      if operation[x][0] == operation[v][0]:
        y = parents[x][0]
        operation[v] = operation[y]
        parents[v] = parents[y]
        progress = True

  # rewrite xor(s,xor(s,t)) as t
  for v in operation:
    if operation[v][0] != 'xor': continue
    s,x = parents[v]
    if operation[x][0] != 'xor':
      x,s = parents[v]
      if operation[x][0] != 'xor': continue
    y,t = parents[x]
    if y != s:
      t,y = parents[x]
      if y != s: continue
    operation[v] = operation[t]
    parents[v] = parents[t]
    progress = True

  # if c is constant(b,2**(b-1)-1), the largest b-bit signed value:
  #   rewrite signedmin(s,c) as s
  #   rewrite signedmax(s,c) as c
  for v in operation:
    b = bits[v]
    if operation[v][0] not in ['signedmin','signedmax']: continue
    s,c = parents[v]
    if operation[c] != ['constant',b,2**(b-1)-1]:
      c,s = parents[v]
      if operation[c] != ['constant',b,2**(b-1)-1]: continue
    if operation[v][0] == 'signedmin':
      operation[v] = operation[s]
      parents[v] = parents[s]
      progress = True
      continue
    if operation[v][0] == 'signedmax':
      operation[v] = operation[c]
      parents[v] = parents[c]
      progress = True
      continue

  # if c is constant(b,2**(b-1)), the smallest b-bit signed value:
  #   rewrite signedmin(s,c) as c
  #   rewrite signedmax(s,c) as s
  for v in operation:
    b = bits[v]
    if operation[v][0] not in ['signedmin','signedmax']: continue
    s,c = parents[v]
    if operation[c] != ['constant',b,2**(b-1)]:
      c,s = parents[v]
      if operation[c] != ['constant',b,2**(b-1)]: continue
    if operation[v][0] == 'signedmax':
      operation[v] = operation[s]
      parents[v] = parents[s]
      progress = True
      continue
    if operation[v][0] == 'signedmin':
      operation[v] = operation[c]
      parents[v] = parents[c]
      progress = True
      continue

  # if c is constant(b,2**b-1) (all ones):
  #   rewrite xor(c,s) as invert(s)
  for v in operation:
    b = bits[v]
    if operation[v][0] != 'xor': continue
    c,s = parents[v]
    if operation[c] != ['constant',b,2**b-1]:
      s,c = parents[v]
      if operation[c] != ['constant',b,2**b-1]: continue
    operation[v] = ['invert']
    parents[v] = [s]
    progress = True

  # rewrite signedmin(invert(s),invert(t)) as invert(signedmax(s,t))
  # rewrite signedmax(invert(s),invert(t)) as invert(signedmin(s,t))
  invertminmax = []
  for v in operation:
    if operation[v][0] in ['signedmin','signedmax']:
      x,y = parents[v]
      if operation[x][0] == 'invert' and operation[y][0] == 'invert':
        invertminmax += [(v,parents[x][0],parents[y][0])]
  for v,s,t in invertminmax:
    nextvalue += 1
    operation[nextvalue] = [flip[operation[v][0]]]
    parents[nextvalue] = [s,t]
    bits[nextvalue] = bits[v]
    operation[v] = ['invert']
    parents[v] = [nextvalue]
    progress = True

  # rewrite invert(s)[high:low] as invert(s[high:low])
  extractinvert = []
  for v in operation:
    if operation[v][0] != 'Extract': continue
    x = parents[v][0]
    if operation[x][0] != 'invert': continue
    s = parents[x][0]
    extractinvert += [(v,s)]
  for v,s in extractinvert:
    nextvalue += 1
    operation[nextvalue] = operation[v]
    parents[nextvalue] = [s]
    bits[nextvalue] = bits[v]
    operation[v] = ['invert']
    parents[v] = [nextvalue]
    progress = True

  # rewrite s[bits-1:0] as s
  for v in operation:
    if operation[v][0] != 'Extract': continue
    s = parents[v][0]
    if operation[v][1:] != [bits[s]-1,0]: continue
    operation[v] = operation[s]
    parents[v] = parents[s]
    progress = True

  # rewrite s[high:low][high2:low2] as s[high2+low:low2+low]
  for v in operation:
    if operation[v][0] != 'Extract': continue
    x = parents[v][0]
    if operation[x][0] != 'Extract': continue
    s = parents[x][0]
    high2,low2 = operation[v][1:]
    high,low = operation[x][1:]
    operation[v] = ['Extract',high2 + low,low2 + low]
    parents[v] = [s]
    progress = True

  # rewrite Reverse(s)[high:low] as Reverse(s[...]) if possible; only fires
  # for byte-aligned ranges (presumably Reverse reverses byte order -- the
  # index mapping below assumes that)
  extractreverse = []
  for v in operation:
    if operation[v][0] != 'Extract': continue
    x = parents[v][0]
    if operation[x][0] != 'Reverse': continue
    high,low = operation[v][1:]
    if low % 8 != 0: continue
    if high % 8 != 7: continue
    extractreverse += [(v,parents[x][0],high,low)]
  for v,s,high,low in extractreverse:
    nextvalue += 1
    operation[nextvalue] = ['Extract',bits[s] - 1 - low,bits[s] - 1 - high]
    parents[nextvalue] = [s]
    bits[nextvalue] = high - low + 1
    assert bits[nextvalue] == bits[v]
    operation[v] = ['Reverse']
    parents[v] = [nextvalue]
    progress = True

  # rewrite Concat(...)[high:low] as ...[high-pos:low-pos] if the range lies
  # entirely inside one Concat argument (pos = that argument's low bit)
  for v in operation:
    if operation[v][0] != 'Extract': continue
    x = parents[v][0]
    if operation[x][0] != 'Concat': continue
    high,low = operation[v][1:]
    pos = 0
    for y in reversed(parents[x]):
      if pos <= low and high < pos + bits[y]:
        operation[v] = ['Extract',high - pos,low - pos]
        parents[v] = [y]
        progress = True
        break
      pos += bits[y]

  # rewrite If(signedle(s,t),s,t) as signedmin(s,t)
  # rewrite If(signedlt(s,t),s,t) as signedmin(s,t)
  for v in operation:
    if operation[v][0] == 'If':
      x = parents[v][0]
      if operation[x][0] in ['signedle','signedlt']:
        if parents[v][1] == parents[x][0]:
          if parents[v][2] == parents[x][1]:
            operation[v] = ['signedmin']
            parents[v] = parents[x]
            progress = True

  # rewrite If(signedle(t,s),s,t) as signedmax(s,t)
  # rewrite If(signedlt(t,s),s,t) as signedmax(s,t)
  for v in operation:
    if operation[v][0] == 'If':
      x = parents[v][0]
      if operation[x][0] in ['signedle','signedlt']:
        if parents[v][1] == parents[x][1]:
          if parents[v][2] == parents[x][0]:
            operation[v] = ['signedmax']
            parents[v] = parents[x]
            progress = True

  # rewrite If(c,constant(1,1),constant(1,0)) as c
  for v in operation:
    if operation[v][0] == 'If':
      c,x,y = parents[v]
      if operation[x] == ['constant',1,1]:
        if operation[y] == ['constant',1,0]:
          operation[v] = operation[c]
          parents[v] = parents[c]
          progress = True

  # rewrite xor(c,constant(1,1)) as invert(c)
  for v in operation:
    if operation[v][0] == 'xor':
      c,x = parents[v]
      if operation[x] == ['constant',1,1]:
        operation[v] = ['invert']
        parents[v] = [c]
        progress = True

  # rewrite equal(c,constant(1,0)) as invert(c)
  for v in operation:
    if operation[v][0] == 'equal':
      c,x = parents[v]
      if operation[x] == ['constant',1,0]:
        operation[v] = ['invert']
        parents[v] = [c]
        progress = True

  # Compute users of every node; -1 marks a node that is a program output.
  children = dict()
  for z in operation:
    children[z] = set()
  for v in outputname:
    children[value[v]].add(-1)
  for z in operation:
    for x in parents[z]:
      children[x].add(z)

  # Nodes with no users are dead.
  deleting = set(v for v in operation if len(children[v]) == 0)

  # Find pairs of sibling nodes computing the same operation on the same
  # operands (operand order ignored for signedmin/signedmax).  Each node takes
  # part in at most one merge per pass; dead nodes are never merged.
  merging = deleting.copy()
  merge = []
  for x in operation:
    c = list(children[x])
    for y,z in [(c[i],c[j]) for j in range(len(c)) for i in range(j)]:
      if y == -1: continue
      if z == -1: continue
      if operation[y] != operation[z]: continue
      parentsmatch = False
      if parents[y] == parents[z]: parentsmatch = True
      if operation[y][0] in ['signedmin','signedmax']:
        if set(parents[y]) == set(parents[z]):
          parentsmatch = True
      if not parentsmatch: continue
      assert bits[y] == bits[z]
      if y in merging: continue
      if z in merging: continue
      merge += [(y,z)]
      merging.add(y)
      merging.add(z)

  # Redirect every use of z (including outputs) to y, then drop z.
  for y,z in merge:
    # print('eliminating %s in favor of %s' % (z,y))
    for t in children[z]:
      if t == -1:
        for v in outputname:
          if value[v] == z:
            value[v] = y
      else:
        for j in range(len(parents[t])):
          if parents[t][j] == z:
            parents[t][j] = y
    deleting.add(z)
    progress = True

  for v in deleting:
    del operation[v]
    del parents[v]
    del bits[v]
    progress = True

done = set()

def _emit(v):
  """Print the single line defining node v as 'vN = op(operands,extras)'."""
  if operation[v][0] == 'input':
    print('v%d = %s' % (v,operation[v][1]))
  else:
    p = ['v%s' % x for x in parents[v]]
    p += ['%s' % x for x in operation[v][1:]]
    print('v%d = %s(%s)' % (v,operation[v][0],','.join(p)))

def do(v):
  """Print node v after all of its not-yet-printed ancestors.

  Each node is printed at most once (tracked in done).  The walk uses an
  explicit stack instead of recursion so that long dependency chains cannot
  exceed Python's recursion limit; the printing order equals that of a
  recursive post-order walk visiting parents left to right.
  """
  if v in done: return
  done.add(v)
  stack = [(v,iter(parents[v]))]
  while stack:
    node,pending = stack[-1]
    for x in pending:
      if x not in done:
        done.add(x)
        stack.append((x,iter(parents[x])))
        break
    else:
      # all parents of node have been printed
      stack.pop()
      _emit(node)

for v in outputname:
  do(value[v])
  print('%s = v%d' % (v,value[v]))