#!/usr/bin/env python3

import sys
import claripy

# When true, verify() and verify_specific_size() print the parsed patterns,
# their variables, the generated formulas and the collected constraints.
verbose = False

# Constant patterns that may appear in a rule. Each name maps to a function
# taking a bit width and returning the constant's value at that width.
# Insertion order matters: parse() offers these names to the parser in
# iteration order, so the longer '0star1' is listed ahead of its prefix '0star'.
constpatterns = {
  # width-independent values
  'bit0': (lambda width: 0),
  'bit1': (lambda width: 1),
  'const5': (lambda width: 5),
  'const26': (lambda width: 26),
  'const27': (lambda width: 27),
  'const31': (lambda width: 31),
  # width-dependent bit patterns, named after their binary spelling
  '0star1': (lambda width: 1),
  '0star': (lambda width: 0),
  '01star': (lambda width: 2**(width-1)-1),
  '10star': (lambda width: 2**(width-1)),
  '1star': (lambda width: 2**width-1),
}

def parse():
  '''Read rewrite rules "LEFT -> RIGHT" from stdin and parse both sides.

  Lines that do not split into exactly two parts on '->' are skipped.
  Each side is parsed into nested tuples of the form (op, *operands):
  variables (runs of uppercase letters) become ('var', NAME), and each
  occurrence of a constant pattern (a key of constpatterns) becomes the
  1-tuple ('NAME#k',), where k numbers that pattern's occurrences within
  the current line, starting at 1.

  Returns a list of (line, left, right) triples, line being the stripped
  input line.
  Raises Exception naming the text that failed to parse, chained to the
  underlying pyparsing error.
  '''
  from pyparsing import StringEnd,Literal,Word,ZeroOrMore,Forward,ParseBaseException,nums
  def lit(x): return Literal(x).suppress()

  # wrap the tokens matched by x as the tuple (name, *tokens)
  def group(name,x):
    def action(r): return name,*r
    return x.copy().set_parse_action(action)

  # like group, but tagged 'name#k' where k counts occurrences of name;
  # counter is cleared at the start of every input line
  counter = {}
  def countinggroup(name,x):
    def action(r):
      counter[name] = counter.get(name,0)+1
      return f'{name}#{counter[name]}',*r
    return x.copy().set_parse_action(action)

  # name(arg1,arg2,...) grouped under name; just the bare name if no args
  def fun(name,*args):
    what = lit(name)
    if len(args) > 0:
      what += lparen
      for pos,arg in enumerate(args):
        if pos > 0: what += comma
        what += arg
      what += rparen
    return group(name,what)

  lparen = lit('(')
  rparen = lit(')')
  comma = lit(',')

  # '|' builds a pyparsing MatchFirst: alternatives are tried in order, so a
  # bare variable comes last (after 'assign', which starts with a variable)
  Term = Forward()
  Termvar = group('var',Word('ABCDEFGHIJKLMNOPQRSTUVWXYZ'))
  X = lparen+Term+rparen
  X |= group('assign',Termvar+lit('=')+Term)
  X |= fun('Bits',Word(nums),Term)
  X |= group('Concat',lit('Concat')+lparen+Term+ZeroOrMore(comma+Term)+rparen)
  X |= fun('Extract',Term,Word(nums),Word(nums))
  # variables/patternbits/formula handle SignExt as well as ZeroExt
  for op in 'SignExt','ZeroExt':
    X |= fun(op,Term,Word(nums))
  for c in constpatterns:
    X |= countinggroup(c,lit(c))
  for op in 'invert','Reverse':
    X |= fun(op,Term)
  for op in 'add','sub','and','xor','or','equal','unsignedgt','unsignedge','unsignedlt','unsignedle','signedlt','signedle','signedmin','signedmax','lshift','unsignedrshift','signedrshift':
    X |= fun(op,Term,Term)
  X |= fun('If',Term,Term,Term)
  X |= Termvar
  Term <<= X

  # a side must consist of exactly one term with nothing after it
  wholeterm = Term+StringEnd()

  result = []
  for line in sys.stdin:
    counter.clear()
    line = line.strip()
    sides = line.split('->')
    if len(sides) != 2: continue
    parsed = []
    for side in sides:
      try:
        parsed.append(wholeterm.parse_string(side)[0])
      except ParseBaseException as e:
        raise Exception(f'failure to parse {side}') from e
    left,right = parsed
    result.append((line,left,right))

  return result

def variables(pattern,forcebits=None):
  '''List the variables occurring in a parsed pattern, with their widths.

  Returns a tuple of (bits,name) pairs, one per occurrence, where name is
  either a variable name or a numbered constant pattern such as '01star#1',
  and bits is the width the surrounding context forces on that occurrence,
  or None when the width is unconstrained.

  forcebits is the width required of pattern itself (None if not known).
  Width conflicts noticed here fail an assert; an unknown op raises
  Exception.
  '''
  op = pattern[0]

  if len(pattern) == 1:
    # a numbered constant pattern; bit0/bit1 are always 1 bit wide
    constname,_ = op.split('#')
    if constname in ('bit0','bit1'):
      forcebits = 1 if forcebits is None else forcebits
      assert forcebits == 1
    return ((forcebits,op),)

  if op == 'var':
    return ((forcebits,pattern[1]),)

  if op == 'assign':
    target = pattern[1]
    assert target[0] == 'var'
    name = target[1]
    # the assigned name has the width of the assigned value
    return variables(pattern[2],forcebits=forcebits)+((forcebits,name),)

  if op in ('Reverse','invert'):
    return variables(pattern[1],forcebits=forcebits)

  if op == 'Bits':
    declared = int(pattern[1])
    forcebits = declared if forcebits is None else forcebits
    assert forcebits == declared
    return variables(pattern[2],forcebits=forcebits)

  if op in ('SignExt','ZeroExt'):
    # the operand is narrower than the result by the extension amount
    inner = None if forcebits is None else forcebits-int(pattern[2])
    return variables(pattern[1],forcebits=inner)

  if op == 'Extract':
    # the source width is not determined by the slice
    return variables(pattern[1])

  if op == 'If':
    return (variables(pattern[1],forcebits=1)
      +variables(pattern[2],forcebits=forcebits)
      +variables(pattern[3],forcebits=forcebits))

  if op in ('equal','unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt'):
    # comparisons yield 1 bit; the operand width is left open
    assert forcebits is None or forcebits == 1
    return variables(pattern[1])+variables(pattern[2])

  if op in ('signedmin','signedmax','xor','and','or','add','sub','lshift','unsignedrshift','signedrshift'):
    return variables(pattern[1],forcebits=forcebits)+variables(pattern[2],forcebits=forcebits)

  if op == 'Concat':
    # each piece's width is open; only their sum is the result width
    return sum((variables(piece) for piece in pattern[1:]),())

  sys.stdout.flush()
  raise Exception(f'unrecognized op {pattern[0]}')

def patternbits(pattern,sa):
  '''Compute the bit width of a parsed pattern under a size assignment.

  sa is a {name: bits} dict covering every variable name and every
  numbered constant pattern (e.g. 'const5#1') in pattern; including all of
  them is the caller's responsibility (a missing name raises KeyError).

  Returns the width of pattern, or None if sa is incompatible with it:
  operand widths that differ where they must agree, a Bits(n,...) whose
  operand is not n bits, bit0/bit1 not 1 bit wide, a constant that does
  not fit in its width, or an Extract whose range is reversed or reaches
  past the operand's top bit.
  '''
  if len(pattern) == 1:
    # numbered constant pattern such as 'const5#1'
    constname,_ = pattern[0].split('#')
    if constname in ('bit0','bit1'):
      if sa[pattern[0]] != 1: return
    bits = sa[pattern[0]]
    value = constpatterns[constname](bits)
    if value < 0 or value >= 2**bits: return # constant does not fit
    return bits
  if pattern[0] == 'var':
    return sa[pattern[1]]
  if pattern[0] == 'assign':
    assert pattern[1][0] == 'var'
    subbits = patternbits(pattern[2],sa)
    if subbits != sa[pattern[1][1]]: return
    return subbits
  if pattern[0] in ('Reverse','invert'):
    return patternbits(pattern[1],sa)
  if pattern[0] == 'Bits':
    forcebits = int(pattern[1])
    subbits = patternbits(pattern[2],sa)
    if subbits != forcebits: return # wrong bits or otherwise incompatible
    return forcebits
  if pattern[0] in ('SignExt','ZeroExt'):
    morebits = int(pattern[2])
    subbits = patternbits(pattern[1],sa)
    if subbits is None: return
    return subbits+morebits
  if pattern[0] == 'Extract':
    high,low = map(int,pattern[2:])
    subbits = patternbits(pattern[1],sa)
    if subbits is None: return
    if subbits <= high: return # slice reaches past the top bit
    if low > high: return # reversed range would give a width <= 0
    return 1+high-low
  if pattern[0] == 'If':
    if patternbits(pattern[1],sa) != 1: return # condition must be 1 bit
    firstbits = patternbits(pattern[2],sa)
    if firstbits is None: return
    secondbits = patternbits(pattern[3],sa)
    if secondbits is None: return
    if firstbits != secondbits: return
    return firstbits
  if pattern[0] in ('equal','unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt'):
    # same-width operands, 1-bit result
    firstbits = patternbits(pattern[1],sa)
    if firstbits is None: return
    secondbits = patternbits(pattern[2],sa)
    if secondbits is None: return
    if firstbits != secondbits: return
    return 1
  if pattern[0] in ('signedmin','signedmax','xor','and','or','add','sub','lshift','unsignedrshift','signedrshift'):
    # same-width operands, result of that width
    firstbits = patternbits(pattern[1],sa)
    if firstbits is None: return
    secondbits = patternbits(pattern[2],sa)
    if secondbits is None: return
    if firstbits != secondbits: return
    return firstbits
  if pattern[0] == 'Concat':
    result = 0
    for p in pattern[1:]:
      bits = patternbits(p,sa)
      if bits is None: return
      result += bits
    return result
  sys.stdout.flush()
  raise Exception(f'unrecognized op {pattern[0]}')

# claripy distinguishes between BitVec 1 and Bool
def bitvec1tobool(v):
  '''Turn a 1-bit bitvector into a claripy Bool that is true iff the bit is 1.'''
  one = claripy.BVV(1,1)
  return v == one

def booltobitvec1(v):
  '''Turn a claripy Bool into a 1-bit bitvector: 1 when true, 0 when false.'''
  truebit = claripy.BVV(1,1)
  falsebit = claripy.BVV(0,1)
  return claripy.If(v,truebit,falsebit)

# appends to constraints list in place
def formula(pattern,sa,constraints):
  '''Translate a parsed pattern into a claripy bitvector expression.

  sa maps every variable name and numbered constant pattern occurring in
  pattern to its bit width. Variables become claripy.BVS symbols created
  with explicit_name=True, so every occurrence of a name is the same
  symbol; constant patterns become claripy.BVV values computed from
  constpatterns. Each 'assign' node appends the constraint
  "variable == value" to the caller's constraints list.

  Comparison ops, including 'equal', return 1-bit bitvectors rather than
  claripy Bools. This matches the 1-bit width patternbits reports for them,
  so their results can feed If, Concat, bitwise ops, or be compared with
  another 1-bit side of a rule.

  Raises Exception for an unrecognized op.
  '''
  op = pattern[0]
  if len(pattern) == 1:
    constname,_ = op.split('#')
    bits = sa[op]
    value = constpatterns[constname](bits)
    return claripy.BVV(value,bits)
  if op == 'var':
    bits = sa[pattern[1]]
    return claripy.BVS(pattern[1],bits,explicit_name=True)
  if op == 'assign':
    result = formula(pattern[2],sa,constraints)
    constraints.append(formula(pattern[1],sa,constraints) == result)
    return result
  if op == 'Bits': return formula(pattern[2],sa,constraints)
  if op == 'Reverse': return claripy.Reverse(formula(pattern[1],sa,constraints))
  if op == 'invert': return formula(pattern[1],sa,constraints).__invert__()
  if op == 'SignExt': return claripy.SignExt(int(pattern[2]),formula(pattern[1],sa,constraints))
  if op == 'ZeroExt': return claripy.ZeroExt(int(pattern[2]),formula(pattern[1],sa,constraints))
  if op == 'Extract': return claripy.Extract(int(pattern[2]),int(pattern[3]),formula(pattern[1],sa,constraints))
  if op == 'If': return claripy.If(bitvec1tobool(formula(pattern[1],sa,constraints)),formula(pattern[2],sa,constraints),formula(pattern[3],sa,constraints))
  if op == 'Concat': return claripy.Concat(*(formula(p,sa,constraints) for p in pattern[1:]))
  if len(pattern) == 3:
    op1 = formula(pattern[1],sa,constraints)
    op2 = formula(pattern[2],sa,constraints)
    # claripy comparisons produce Bools; convert to 1-bit bitvectors
    if op == 'equal': return booltobitvec1(op1==op2)
    if op == 'add': return op1+op2
    if op == 'sub': return op1-op2
    if op == 'xor': return op1^op2
    if op == 'and': return op1&op2
    if op == 'or': return op1|op2
    if op == 'unsignedge': return booltobitvec1(op1.UGE(op2))
    if op == 'unsignedgt': return booltobitvec1(op1.UGT(op2))
    if op == 'unsignedle': return booltobitvec1(op1.ULE(op2))
    if op == 'unsignedlt': return booltobitvec1(op1.ULT(op2))
    if op == 'signedle': return booltobitvec1(op1.SLE(op2))
    if op == 'signedlt': return booltobitvec1(op1.SLT(op2))
    if op == 'signedmin': return claripy.If(op1.SLE(op2),op1,op2)
    if op == 'signedmax': return claripy.If(op1.SLE(op2),op2,op1)
    if op == 'lshift': return op1.__lshift__(op2)
    if op == 'unsignedrshift': return op1.LShR(op2)
    if op == 'signedrshift': return op1.__rshift__(op2)
  sys.stdout.flush()
  raise Exception(f'unrecognized op {pattern[0]}')

def verify_specific_size(left,right,sa):
  '''Check with an SMT solver that left and right agree for one size assignment.

  sa maps every variable and numbered constant pattern of the rule to a bit
  width. The solver searches for inputs that satisfy the 'assign'
  constraints of both sides and on which the two formulas differ. If it
  finds any, one counterexample is printed (a value for every entry of sa)
  and Exception is raised; otherwise a success line is printed.
  '''
  constraints = []
  leftformula = formula(left,sa,constraints)
  rightformula = formula(right,sa,constraints)
  if verbose:
    print('left formula',leftformula)
    print('right formula',rightformula)
    print('constraints',constraints)
  smt = claripy.Solver(timeout=-1) # XXX warning: by default (without timeout=-1), Solver will produce wrong results after 5 minutes
  smt.add(leftformula != rightformula)
  for c in constraints:
    smt.add(c)
  if not smt.satisfiable():
    print('SMT solver says ok')
    return

  print('found mismatch:')
  for v in sa:
    # numbered constant patterns ('const5#1', 'bit1#1', ...) are concrete
    # values, not solver symbols: evaluate the constant itself rather than a
    # fresh unconstrained symbol whose value would be arbitrary
    if '#' in v:
      expr = formula((v,),sa,constraints)
    else:
      expr = formula(('var',v),sa,constraints)
    vexample = smt.eval(expr,1)
    print(f'  {v} = {vexample[0]} = {hex(vexample[0])}')
  sys.stdout.flush()
  raise Exception('aborting given mismatch')

def sizeassignments(varbits):
  '''Enumerate candidate bit-width assignments for a list of variables.

  varbits is a sequence of (name,bits) pairs. A pair with bits None has a
  free width and is tried at each of a fixed set of common widths; any
  other pair keeps its given width. Yields tuples ((name,width),...) in
  the order of varbits, with the last variable varying fastest. An empty
  varbits yields a single empty tuple.
  '''
  if not varbits:
    yield ()
    return
  (name,fixed),rest = varbits[0],varbits[1:]
  # XXX: in principle should check for sizes actually used in minmax
  candidates = (1,8,16,32,64,128,256,512)
  widths = candidates if fixed is None else (fixed,)
  for width in widths:
    head = ((name,width),)
    for tail in sizeassignments(rest):
      yield head+tail

def verify(line,left,right):
  '''Verify one parsed rule at every compatible assignment of bit widths.

  Collects the variables of both sides (merging widths the sides force on
  them; contradictory forced widths fail an assert), enumerates candidate
  width assignments with sizeassignments, keeps those under which
  patternbits gives both sides the same width, and checks each one with
  verify_specific_size, which raises on a counterexample.
  Raises Exception if no width assignment is compatible with the rule.
  '''
  print('rule:',line)
  leftvars = variables(left)
  rightvars = variables(right)
  if verbose:
    print('left',left)
    print('right',right)
    print('left variables',leftvars)
    print('right variables',rightvars)
  printvarorder = []
  var2bits = {}
  for bits,var in leftvars+rightvars:
    if var not in var2bits: var2bits[var] = bits
    if bits is not None:
      if var2bits[var] is None: var2bits[var] = bits
      assert var2bits[var] == bits
    if var not in printvarorder: printvarorder.append(var)
  varbits = sorted(var2bits.items())

  # Banner labels do not depend on the size assignment, so build them once.
  # bit0/bit1 occurrences are omitted; a constant pattern that occurs only
  # once in the rule is shown without its '#k' suffix.
  labels = []
  for var in printvarorder:
    if var.startswith('bit'): continue
    label = var
    if '#' in var:
      constname = var.split('#')[0]
      occurrences = sum(1 for other in printvarorder if other.startswith(constname+'#'))
      if occurrences == 1: label = constname
    labels.append((var,label))

  foundsa = False
  for sa in sizeassignments(varbits):
    sa = dict(sa)
    leftbits = patternbits(left,sa)
    if leftbits is None: continue
    rightbits = patternbits(right,sa)
    if rightbits is None: continue
    if leftbits != rightbits: continue

    foundsa = True

    banner = 'considering sizes'+''.join(f' {label}:{sa[var]}' for var,label in labels)
    banner += f' output:{leftbits}'
    print(banner)

    verify_specific_size(left,right,sa)

  if not foundsa:
    sys.stdout.flush()
    raise Exception('does this have a size assignment?')
  print('=====')
  sys.stdout.flush()

if __name__ == '__main__':
  # read every rule from stdin, then verify them in order; an exception from
  # verify() (mismatch or no compatible sizes) stops the run
  rules = parse()
  for line,left,right in rules:
    verify(line,left,right)
