-rwxr-xr-x 23974 djbsort-20190516/verif/minmax
#!/usr/bin/env python3

"""Simplify a straight-line bit-vector program into signed min/max form.

Usage: minmax N < program

N (sys.argv[1]) is the number of inputs and outputs.  The program on stdin
is a sequence of assignments ``name = op(args...)`` (ops listed in the
grammar below) or plain copies ``name = othername``.  Inputs are expected to
be named ``in_0_32`` ... (each treated as 32 bits wide) and outputs
``out0`` ... .

The program is turned into a DAG of numbered values.  A fixed set of
rewrite rules is then applied until nothing changes.  Most rules recognise
code idioms for signed minimum/maximum and replace them with explicit
``signedmin``/``signedmax`` nodes.  Between rounds, identical sibling nodes
are merged and nodes with no consumers are dropped.  The result is printed
to stdout as ``vK = op(vA,...,imm...)`` lines followed by ``outI = vK``
lines.
"""

import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums

n = int(sys.argv[1])

inputname = tuple('in_%d_32' % i for i in range(n))
outputname = tuple('out%d' % i for i in range(n))


def group(s):
    """Return a parse action that tags an assignment's tokens with op name s.

    The action turns the matched tokens [lhs, arg, ...] into the single
    item [s, lhs, arg, ...].  A lone token is passed through unchanged.
    """
    def t(x):
        x = list(x)
        if len(x) == 1:
            return x
        return [[s] + x]
    return t


lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas,alphas+nums+"_")

# name = constant(bits,value)
assignment = (
    name + equal + Literal('constant').suppress()
    + lparen + number + comma + number + rparen
).setParseAction(group('constant'))

for unary in ['invert','Reverse']:
    assignment |= (
        name + equal + Literal(unary).suppress()
        + lparen + name + rparen
    ).setParseAction(group(unary))

for binary in ['xor','or','and','add','sub','mul','signedrshift','signedle','signedlt','equal']:
    assignment |= (
        name + equal + Literal(binary).suppress()
        + lparen + name + comma + name + rparen
    ).setParseAction(group(binary))

# name = If(cond,then,else)
assignment |= (
    name + equal + Literal('If').suppress()
    + lparen + name + comma + name + comma + name + rparen
).setParseAction(group('If'))

# name = Extract(src,top,bot)
assignment |= (
    name + equal + Literal('Extract').suppress()
    + lparen + name + comma + number + comma + number + rparen
).setParseAction(group('Extract'))

# name = SignExt(src,morebits)
assignment |= (
    name + equal + Literal('SignExt').suppress()
    + lparen + name + comma + number + rparen
).setParseAction(group('SignExt'))

for manyary in ['Concat']:
    assignment |= (
        name + equal + Literal(manyary).suppress()
        + lparen + name + ZeroOrMore(comma + name) + rparen
    ).setParseAction(group(manyary))

# name = othername (alias; tried last so op(...) forms win)
assignment |= (
    name + equal + name
).setParseAction(group('copy'))

assignments = ZeroOrMore(assignment) + StringEnd()

program = sys.stdin.read()
program = assignments.parseString(program)
program = list(program)

# Every distinct value gets an integer id 1, 2, ...
nextvalue = 0

# indexed by variable name: value id
value = dict()

# indexed by value id:
#   operation: [opname, immediates...], e.g. ['constant',bits,number],
#              ['Extract',top,bot], ['SignExt',morebits], ['input',name]
#   parents:   operand value ids, in operand order
#   bits:      bit width
operation = dict()
parents = dict()
bits = dict()

for v in inputname:
    nextvalue += 1
    value[v] = nextvalue
    operation[nextvalue] = ['input',v]
    parents[nextvalue] = []
    bits[nextvalue] = 32

for p in program:
    if p[1] in value:
        raise Exception('%s assigned twice' % p[1])
    if p[0] == 'copy':
        value[p[1]] = value[p[2]]
        continue
    nextvalue += 1
    operation[nextvalue] = [p[0]]
    if p[0] == 'constant':
        parents[nextvalue] = []
        operation[nextvalue] += [int(p[2]),int(p[3])]
        bits[nextvalue] = int(p[2])
    elif p[0] in ['signedle','signedlt','equal']:
        # binary operation producing bit
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = 1
    elif p[0] in ['invert','Reverse']:
        # unary size-preserving operation
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] in ['xor','or','and','add','sub','mul','signedrshift']:
        # binary size-preserving operation
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] == 'If':
        assert bits[value[p[2]]] == 1
        assert bits[value[p[3]]] == bits[value[p[4]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[3]]]
    elif p[0] == 'Extract':
        top = int(p[3])
        bot = int(p[4])
        assert top >= bot
        assert bits[value[p[2]]] > top
        assert bot >= 0
        operation[nextvalue] += [top,bot]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = top + 1 - bot
    elif p[0] == 'SignExt':
        morebits = int(p[3])
        operation[nextvalue] += [morebits]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = bits[value[p[2]]] + morebits
    elif p[0] == 'Concat':
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = sum(bits[v] for v in parents[nextvalue])
    else:
        raise Exception('unknown internal operation %s' % p[0])
    value[p[1]] = nextvalue


FLIP = {'signedmax':'signedmin','signedmin':'signedmax'}


def is_signed_le(f,q,r,width):
    """Return True if node f computes q <= r (signed, width bits) by the idiom

        s=sub(q,r); Q=q[top]; R=r[top]; S=s[top]; t=Q^R; u=Q^S; v=t&u;
        e=equal(s,constant(width,0)); f=(S^v)|e      (top = width-1)

    S^v is the sign of q-r corrected for signed overflow (i.e. q < r), and
    e adds the q == r case.  Every edge is checked, including that f's S is
    the same node as u's S, so unrelated circuits are not misidentified.
    """
    if operation[f][0] != 'or': return False
    Sv,e = parents[f]
    if operation[e][0] != 'equal':
        e,Sv = Sv,e
        if operation[e][0] != 'equal': return False
    if operation[Sv][0] != 'xor': return False
    S,v = parents[Sv]
    if operation[v][0] != 'and':
        v,S = S,v
        if operation[v][0] != 'and': return False
    s,c = parents[e]
    if operation[c] != ['constant',width,0]:
        c,s = parents[e]
        if operation[c] != ['constant',width,0]: return False
    t,u = parents[v]
    if operation[t][0] != 'xor': return False
    if operation[u][0] != 'xor': return False
    # XXX: also handle reverse order of t,u
    sign = ['Extract',width-1,width-1]
    Q,R = parents[t]
    if operation[Q] != sign: return False
    if operation[R] != sign: return False
    # XXX: also handle reverse order of Q,R
    Q2,S2 = parents[u]
    if Q2 != Q: return False
    if S2 != S: return False
    if operation[S] != sign: return False
    if parents[Q] != [q]: return False
    if parents[R] != [r]: return False
    if parents[S] != [s]: return False
    if operation[s][0] != 'sub': return False
    return parents[s] == [q,r]


progress = True
while progress:
    progress = False

    # rewrite invert(invert(c)) as c
    # rewrite Reverse(Reverse(c)) as c
    for v in operation:
        if operation[v][0] in ['invert','Reverse']:
            x = parents[v][0]
            if operation[x][0] == operation[v][0]:
                y = parents[x][0]
                # v becomes a structural copy of y (same op, same operands)
                operation[v] = operation[y]
                parents[v] = parents[y]
                progress = True

    # rewrite s^(s^t) as t
    for v in operation:
        if operation[v][0] != 'xor': continue
        s,x = parents[v]
        if operation[x][0] != 'xor':
            x,s = parents[v]
            if operation[x][0] != 'xor': continue
        y,t = parents[x]
        if y != s:
            t,y = parents[x]
            if y != s: continue
        operation[v] = operation[t]
        parents[v] = parents[t]
        progress = True

    # rewrite (s^t)^(s^u) as xor(t,u)
    for v in operation:
        if operation[v][0] != 'xor': continue
        x,y = parents[v]
        if operation[x][0] != 'xor': continue
        if operation[y][0] != 'xor': continue
        done = False
        for s,t in [parents[x],reversed(parents[x])]:
            if done: continue
            for s2,u in [parents[y],reversed(parents[y])]:
                if done: continue
                if s == s2:
                    parents[v] = [t,u]
                    progress = True
                    done = True

    # rewrite (s^(t^u))^t as s^u
    for v in operation:
        if operation[v][0] != 'xor': continue
        done = False
        for x,t in [parents[v],reversed(parents[v])]:
            if done: continue
            if operation[x][0] != 'xor': continue
            for s,y in [parents[x],reversed(parents[x])]:
                if done: continue
                if operation[y][0] != 'xor': continue
                for t2,u in [parents[y],reversed(parents[y])]:
                    if done: continue
                    if t2 == t:
                        parents[v] = [s,u]
                        progress = True
                        done = True

    # rewrite q^(r^signedmax(q,r)) as signedmin(q,r)
    # rewrite q^(r^signedmin(q,r)) as signedmax(q,r)
    # (sound because q^r == signedmin(q,r)^signedmax(q,r))
    for v in operation:
        if operation[v][0] != 'xor': continue
        done = False
        for q,x in [parents[v],reversed(parents[v])]:
            if done: break
            if operation[x][0] != 'xor': continue
            for r,y in [parents[x],reversed(parents[x])]:
                if done: break
                if operation[y][0] not in ['signedmax','signedmin']: continue
                if parents[y] not in ([q,r],[r,q]): continue
                operation[v] = [FLIP[operation[y][0]]]
                parents[v] = [q,r]
                progress = True
                done = True

    # rewrite (q^r)^signedmax(q,r) as signedmin(q,r)
    # rewrite (q^r)^signedmin(q,r) as signedmax(q,r)
    for v in operation:
        if operation[v][0] != 'xor': continue
        x,y = parents[v]
        # move the min/max operand (whichever slot it is in) into y
        if operation[y][0] not in ['signedmax','signedmin']:
            y,x = parents[v]
        if operation[y][0] not in ['signedmax','signedmin']: continue
        if operation[x][0] != 'xor': continue
        q,r = parents[x]
        if parents[y] not in ([q,r],[r,q]): continue
        operation[v] = [FLIP[operation[y][0]]]
        parents[v] = [q,r]
        progress = True

    # rewrite (s^q)^(r^signedmin(q,r)) as s^signedmax(q,r)
    # Matches are collected first because the rewrite allocates new nodes,
    # which must not happen while iterating over operation.
    xorxorxormin = []
    for v in operation:
        if operation[v][0] != 'xor': continue
        done = False
        for x,y in [parents[v],reversed(parents[v])]:
            if done: break
            if operation[x][0] != 'xor': continue
            if operation[y][0] != 'xor': continue
            for s,q in [parents[x],reversed(parents[x])]:
                if done: break
                for r,t in [parents[y],reversed(parents[y])]:
                    if done: break
                    # on mismatch try the other operand order of y
                    if operation[t][0] != 'signedmin': continue
                    if parents[t] not in ([q,r],[r,q]): continue
                    xorxorxormin += [(v,s,q,r)]
                    done = True
    for v,s,q,r in xorxorxormin:
        nextvalue += 1
        operation[nextvalue] = ['signedmax']
        parents[nextvalue] = [q,r]
        bits[nextvalue] = bits[v]
        operation[v] = ['xor']
        parents[v] = [s,nextvalue]
        progress = True

    # rewrite (s^q)^(r^signedmax(q,r)) as s^signedmin(q,r)
    xorxorxormax = []
    for v in operation:
        if operation[v][0] != 'xor': continue
        done = False
        for x,y in [parents[v],reversed(parents[v])]:
            if done: break
            if operation[x][0] != 'xor': continue
            if operation[y][0] != 'xor': continue
            for s,q in [parents[x],reversed(parents[x])]:
                if done: break
                for r,t in [parents[y],reversed(parents[y])]:
                    if done: break
                    # on mismatch try the other operand order of y
                    if operation[t][0] != 'signedmax': continue
                    if parents[t] not in ([q,r],[r,q]): continue
                    xorxorxormax += [(v,s,q,r)]
                    done = True
    for v,s,q,r in xorxorxormax:
        nextvalue += 1
        operation[nextvalue] = ['signedmin']
        parents[nextvalue] = [q,r]
        bits[nextvalue] = bits[v]
        operation[v] = ['xor']
        parents[v] = [s,nextvalue]
        progress = True

    # rewrite q^(s^(r^signedmin(q,r))) as s^signedmax(q,r)
    xorxorxormin = []
    for v in operation:
        if operation[v][0] != 'xor': continue
        done = False
        for q,x in [parents[v],reversed(parents[v])]:
            if done: break
            if operation[x][0] != 'xor': continue
            for s,y in [parents[x],reversed(parents[x])]:
                if done: break
                if operation[y][0] != 'xor': continue
                for r,t in [parents[y],reversed(parents[y])]:
                    if done: break
                    # on mismatch try the other operand order of y
                    if operation[t][0] != 'signedmin': continue
                    if parents[t] not in ([q,r],[r,q]): continue
                    xorxorxormin += [(v,s,q,r)]
                    done = True
    for v,s,q,r in xorxorxormin:
        nextvalue += 1
        operation[nextvalue] = ['signedmax']
        parents[nextvalue] = [q,r]
        bits[nextvalue] = bits[v]
        operation[v] = ['xor']
        parents[v] = [s,nextvalue]
        progress = True

    # rewrite s=sub(q,r);Q=q[31];R=r[31];S=s[31];t=Q^R;u=Q^S;v=t&u;e=s==0;f=(S^v)|e;If(f,q,r)
    # as signedmin(q,r)
    signedmin = []
    for y in operation:
        if operation[y][0] != 'If': continue
        f,q,r = parents[y]
        if is_signed_le(f,q,r,bits[y]):
            signedmin += [(y,q,r)]
    for y,q,r in signedmin:
        operation[y] = ['signedmin']
        parents[y] = [q,r]
        progress = True

    # rewrite s=sub(q,r);Q=q[31];R=r[31];S=s[31];t=Q^R;u=Q^S;v=t&u;e=s==0;f=(S^v)|e;If(f,r,q)
    # as signedmax(q,r)
    signedmax = []
    for y in operation:
        if operation[y][0] != 'If': continue
        f,r,q = parents[y]
        if is_signed_le(f,q,r,bits[y]):
            signedmax += [(y,q,r)]
    for y,q,r in signedmax:
        operation[y] = ['signedmax']
        parents[y] = [q,r]
        progress = True

    # rewrite s=q-r;t=q^r;u=s^q;v=u&t;w=v^s;x=signedrshift(w,constant(b,b-1));y=x&t
    # as q^signedmax(q,r)
    # (x is all ones exactly when q<r, so y is q^r when q<r and 0 otherwise)
    xorsignedmax = []
    for y in operation:
        if operation[y][0] != 'and': continue
        x,t = parents[y]
        if operation[x][0] != 'signedrshift': continue
        if operation[t][0] != 'xor': continue
        w,dist = parents[x]
        if operation[dist][0] != 'constant': continue
        if operation[dist][2] != operation[dist][1] - 1: continue
        if operation[w][0] != 'xor': continue
        v,s = parents[w]
        if operation[v][0] != 'and':
            s,v = parents[w]
            if operation[v][0] != 'and': continue
        if operation[s][0] != 'sub': continue
        u,t2 = parents[v]
        if t2 != t:
            t2,u = parents[v]
            if t2 != t: continue
        if operation[u][0] != 'xor': continue
        q,s2 = parents[u]
        if s2 != s:
            s2,q = parents[u]
            if s2 != s: continue
        r,q2 = parents[t]
        if q2 != q:
            q2,r = parents[t]
            if q2 != q: continue
        if parents[s] != [q,r]: continue
        xorsignedmax += [(y,q,r)]
    for y,q,r in xorsignedmax:
        nextvalue += 1
        operation[nextvalue] = ['signedmax']
        parents[nextvalue] = [q,r]
        bits[nextvalue] = bits[y]
        operation[y] = ['xor']
        parents[y] = [q,nextvalue]
        progress = True

    # rewrite s=r-q;t=q^r;u=s^r;v=u&t;w=v^s;x=signedrshift(w,constant(b,b-1));y=x&t
    # as q^signedmin(q,r)
    xorsignedmin = []
    for y in operation:
        if operation[y][0] != 'and': continue
        x,t = parents[y]
        if operation[x][0] != 'signedrshift': continue
        if operation[t][0] != 'xor': continue
        w,dist = parents[x]
        if operation[dist][0] != 'constant': continue
        if operation[dist][2] != operation[dist][1] - 1: continue
        if operation[w][0] != 'xor': continue
        v,s = parents[w]
        if operation[v][0] != 'and':
            s,v = parents[w]
            if operation[v][0] != 'and': continue
        if operation[s][0] != 'sub': continue
        u,t2 = parents[v]
        if t2 != t:
            t2,u = parents[v]
            if t2 != t: continue
        if operation[u][0] != 'xor': continue
        r,s2 = parents[u]
        if s2 != s:
            s2,r = parents[u]
            if s2 != s: continue
        r2,q = parents[t]
        if r2 != r:
            q,r2 = parents[t]
            if r2 != r: continue
        if parents[s] != [r,q]: continue
        xorsignedmin += [(y,q,r)]
    for y,q,r in xorsignedmin:
        nextvalue += 1
        operation[nextvalue] = ['signedmin']
        parents[nextvalue] = [q,r]
        bits[nextvalue] = bits[y]
        operation[y] = ['xor']
        parents[y] = [q,nextvalue]
        progress = True

    # signedrshift(SignExt(s,bits),constant(2*bits,bits-1))[bits-1:0]
    # => signedrshift(s,constant(bits,bits-1))
    simplifyrshift = []
    for v in operation:
        if operation[v] != ['Extract',bits[v]-1,0]: continue
        x = parents[v][0]
        if operation[x][0] != 'signedrshift': continue
        y,c = parents[x]
        if operation[y] != ['SignExt',bits[v]]: continue
        if operation[c] != ['constant',2*bits[v],bits[v]-1]: continue
        s = parents[y][0]
        simplifyrshift += [(v,s)]
    for v,s in simplifyrshift:
        nextvalue += 1
        operation[nextvalue] = ['constant',bits[v],bits[v]-1]
        parents[nextvalue] = []
        bits[nextvalue] = bits[v]
        operation[v] = ['signedrshift']
        parents[v] = [s,nextvalue]
        progress = True

    # if c is constant(b,2**(b-1)-1) (the largest signed value):
    #   rewrite signedmin(s,c) as s
    #   rewrite signedmax(s,c) as c
    for v in operation:
        b = bits[v]
        if operation[v][0] not in ['signedmin','signedmax']: continue
        s,c = parents[v]
        if operation[c] != ['constant',b,2**(b-1)-1]:
            c,s = parents[v]
            if operation[c] != ['constant',b,2**(b-1)-1]: continue
        if operation[v][0] == 'signedmin':
            operation[v] = operation[s]
            parents[v] = parents[s]
        else:
            operation[v] = operation[c]
            parents[v] = parents[c]
        progress = True

    # if c is constant(b,2**(b-1)) (bit pattern of the smallest signed value):
    #   rewrite signedmin(s,c) as c
    #   rewrite signedmax(s,c) as s
    for v in operation:
        b = bits[v]
        if operation[v][0] not in ['signedmin','signedmax']: continue
        s,c = parents[v]
        if operation[c] != ['constant',b,2**(b-1)]:
            c,s = parents[v]
            if operation[c] != ['constant',b,2**(b-1)]: continue
        if operation[v][0] == 'signedmax':
            operation[v] = operation[s]
            parents[v] = parents[s]
        else:
            operation[v] = operation[c]
            parents[v] = parents[c]
        progress = True

    # if c is constant(b,2**b-1):
    #   rewrite xor(c,s) as invert(s)
    for v in operation:
        b = bits[v]
        if operation[v][0] != 'xor': continue
        c,s = parents[v]
        if operation[c] != ['constant',b,2**b-1]:
            s,c = parents[v]
            if operation[c] != ['constant',b,2**b-1]: continue
        operation[v] = ['invert']
        parents[v] = [s]
        progress = True

    # rewrite signedmin(invert(s),invert(t)) as invert(signedmax(s,t))
    # rewrite signedmax(invert(s),invert(t)) as invert(signedmin(s,t))
    invertminmax = []
    for v in operation:
        if operation[v][0] in ['signedmin','signedmax']:
            x,y = parents[v]
            if operation[x][0] == 'invert' and operation[y][0] == 'invert':
                invertminmax += [(v,parents[x][0],parents[y][0])]
    for v,s,t in invertminmax:
        nextvalue += 1
        operation[nextvalue] = [FLIP[operation[v][0]]]
        parents[nextvalue] = [s,t]
        bits[nextvalue] = bits[v]
        operation[v] = ['invert']
        parents[v] = [nextvalue]
        progress = True

    # rewrite invert(s)[high:low] as invert(s[high:low])
    extractinvert = []
    for v in operation:
        if operation[v][0] != 'Extract': continue
        x = parents[v][0]
        if operation[x][0] != 'invert': continue
        s = parents[x][0]
        extractinvert += [(v,s)]
    for v,s in extractinvert:
        nextvalue += 1
        operation[nextvalue] = operation[v]
        parents[nextvalue] = [s]
        bits[nextvalue] = bits[v]
        operation[v] = ['invert']
        parents[v] = [nextvalue]
        progress = True

    # rewrite s[bits-1:0] as s
    for v in operation:
        if operation[v][0] != 'Extract': continue
        s = parents[v][0]
        if operation[v][1:] != [bits[s]-1,0]: continue
        operation[v] = operation[s]
        parents[v] = parents[s]
        progress = True

    # rewrite s[high:low][high2:low2] as s[high2+low:low2+low]
    for v in operation:
        if operation[v][0] != 'Extract': continue
        x = parents[v][0]
        if operation[x][0] != 'Extract': continue
        s = parents[x][0]
        high2,low2 = operation[v][1:]
        low = operation[x][2]
        operation[v] = ['Extract',high2 + low,low2 + low]
        parents[v] = [s]
        progress = True

    # rewrite Reverse(s)[high:low] as Reverse(s[...]) if possible
    # (only for whole-byte ranges, since Reverse permutes bytes)
    extractreverse = []
    for v in operation:
        if operation[v][0] != 'Extract': continue
        x = parents[v][0]
        if operation[x][0] != 'Reverse': continue
        high,low = operation[v][1:]
        if low % 8 != 0: continue
        if high % 8 != 7: continue
        extractreverse += [(v,parents[x][0],high,low)]
    for v,s,high,low in extractreverse:
        nextvalue += 1
        operation[nextvalue] = ['Extract',bits[s] - 1 - low,bits[s] - 1 - high]
        parents[nextvalue] = [s]
        bits[nextvalue] = high - low + 1
        assert bits[nextvalue] == bits[v]
        operation[v] = ['Reverse']
        parents[v] = [nextvalue]
        progress = True

    # rewrite Concat(...)[high:low] as ...[high-pos:low-pos] if possible
    # (the last Concat operand holds the lowest bits)
    for v in operation:
        if operation[v][0] != 'Extract': continue
        x = parents[v][0]
        if operation[x][0] != 'Concat': continue
        high,low = operation[v][1:]
        pos = 0
        for y in reversed(parents[x]):
            if pos <= low and high < pos + bits[y]:
                operation[v] = ['Extract',high - pos,low - pos]
                parents[v] = [y]
                progress = True
                break
            pos += bits[y]

    # rewrite If(signedle(s,t),s,t) as signedmin(s,t)
    # rewrite If(signedlt(s,t),s,t) as signedmin(s,t)
    for v in operation:
        if operation[v][0] != 'If': continue
        x = parents[v][0]
        if operation[x][0] not in ['signedle','signedlt']: continue
        if parents[v][1] == parents[x][0] and parents[v][2] == parents[x][1]:
            operation[v] = ['signedmin']
            parents[v] = parents[x]
            progress = True

    # rewrite If(signedle(t,s),s,t) as signedmax(s,t)
    # rewrite If(signedlt(t,s),s,t) as signedmax(s,t)
    for v in operation:
        if operation[v][0] != 'If': continue
        x = parents[v][0]
        if operation[x][0] not in ['signedle','signedlt']: continue
        if parents[v][1] == parents[x][1] and parents[v][2] == parents[x][0]:
            operation[v] = ['signedmax']
            parents[v] = parents[x]
            progress = True

    # rewrite If(c,constant(1,1),constant(1,0)) as c
    for v in operation:
        if operation[v][0] == 'If':
            c,x,y = parents[v]
            if operation[x] == ['constant',1,1] and operation[y] == ['constant',1,0]:
                operation[v] = operation[c]
                parents[v] = parents[c]
                progress = True

    # rewrite xor(c,constant(1,1)) as invert(c)
    for v in operation:
        if operation[v][0] == 'xor':
            c,x = parents[v]
            if operation[x] == ['constant',1,1]:
                operation[v] = ['invert']
                parents[v] = [c]
                progress = True

    # rewrite equal(c,constant(1,0)) as invert(c)
    for v in operation:
        if operation[v][0] == 'equal':
            c,x = parents[v]
            if operation[x] == ['constant',1,0]:
                operation[v] = ['invert']
                parents[v] = [c]
                progress = True

    # children[z]: consumers of z; -1 marks "z is a program output"
    children = dict()
    for z in operation:
        children[z] = set()
    for v in outputname:
        children[value[v]].add(-1)
    for z in operation:
        for x in parents[z]:
            children[x].add(z)

    # nodes nobody consumes are dead
    deleting = set(v for v in operation if len(children[v]) == 0)
    # nodes already claimed this round (dead, or part of a merge pair)
    merging = deleting.copy()
    merge = []

    # merge pairs of siblings with identical op and operands
    # (min/max are commutative, so operand order is ignored for them)
    for x in operation:
        c = list(children[x])
        for y,z in [(c[i],c[j]) for j in range(len(c)) for i in range(j)]:
            if y == -1: continue
            if z == -1: continue
            if operation[y] != operation[z]: continue
            parentsmatch = parents[y] == parents[z]
            if operation[y][0] in ['signedmin','signedmax']:
                if set(parents[y]) == set(parents[z]):
                    parentsmatch = True
            if not parentsmatch: continue
            assert bits[y] == bits[z]
            if y in merging: continue
            if z in merging: continue
            merge += [(y,z)]
            merging.add(y)
            merging.add(z)

    for y,z in merge:
        # print('eliminating %s in favor of %s' % (z,y))
        for t in children[z]:
            if t == -1:
                for v in outputname:
                    if value[v] == z:
                        value[v] = y
            else:
                for j in range(len(parents[t])):
                    if parents[t][j] == z:
                        parents[t][j] = y
        deleting.add(z)
        progress = True

    for v in deleting:
        del operation[v]
        del parents[v]
        del bits[v]
        progress = True


# ids whose definition has already been printed
done = set()


def do(v):
    """Print the definition of v, preceded by any not-yet-printed operands.

    Nodes are emitted in post-order (operands first, in operand order), each
    at most once.  An explicit stack is used instead of recursion so deep
    DAGs cannot exceed Python's recursion limit; output order is the same
    as a recursive depth-first walk.
    """
    if v in done:
        return
    done.add(v)
    stack = [(v, iter(parents[v]))]
    while stack:
        node, pending = stack[-1]
        for x in pending:
            if x not in done:
                done.add(x)
                stack.append((x, iter(parents[x])))
                break
        else:
            stack.pop()
            if operation[node][0] == 'input':
                print('v%d = %s' % (node,operation[node][1]))
            else:
                p = ['v%s' % x for x in parents[node]]
                p += ['%s' % x for x in operation[node][1:]]
                print('v%d = %s(%s)' % (node,operation[node][0],','.join(p)))


for v in outputname:
    do(value[v])
    print('%s = v%d' % (v,value[v]))