#!/usr/bin/env python3
import sys
import sexpdata
import math
import os
from subprocess import Popen, PIPE, DEVNULL

# Debug output (a trace of the first mismatch) is enabled as soon as the
# DEBUG environment variable exists, whatever its value.
debug = "DEBUG" in os.environ

# Keyword arguments for sexpdata.loads: don't convert nil or #t in s-exp
loads_kwargs = {"nil": None, "true": None}

class CommandFailed(Exception):
    """Signals that an external why3 command did not complete successfully.

    Attributes:
        msg: human-readable description of what failed.
    """

    def __init__(self, msg: str) -> None:
        self.msg = msg

class NotEqual(Exception):
    """Signals that two s-expressions differ.

    Attributes:
        path: list of child indices leading to the point of divergence.
        sexp0: the sub-expression found on the first side.
        sexp1: the sub-expression found on the second side.
    """

    def __init__(self, path, sexp0, sexp1):
        self.path, self.sexp0, self.sexp1 = path, sexp0, sexp1

def read(why3, filename):
    """Read a whyml file as an s-expression.

    Runs ``<why3> pp --output=sexp <filename>`` (stderr is discarded) and
    parses what the command printed on stdout.

    Args:
        why3: path or name of the why3 executable.
        filename: whyml file to read.

    Returns:
        The parsed s-expression, or None when the command printed nothing.

    Raises:
        CommandFailed: the command exited with a non-zero return code.
    """
    command = [why3, "pp", "--output=sexp", filename]
    proc = Popen(command, stdout=PIPE, stderr=DEVNULL, encoding='utf8')
    output, _ = proc.communicate()
    if proc.returncode != 0:
        raise CommandFailed("cannot print s-expr for original file (returncode={})"
                            .format(proc.returncode))
    if not output:
        return None
    return sexpdata.loads(output, **loads_kwargs)

def print_and_read(why3, filename):
    """Pretty-print a whyml file and read the result as an s-expression.

    Equivalent to the shell pipeline
    ``<why3> pp --output=mlw <filename> | <why3> pp --output=sexp -``
    with stderr of both commands discarded.

    Args:
        why3: path or name of the why3 executable.
        filename: whyml file to pretty-print.

    Returns:
        The parsed s-expression, or None when the second command printed
        nothing.

    Raises:
        CommandFailed: one of the two commands exited with a non-zero
            return code.
    """
    printer_cmd = [why3, "pp", "--output=mlw", filename]
    parser_cmd = [why3, "pp", "--output=sexp", "-"]
    # The context managers make sure both children are waited for (no zombie
    # left behind for the first stage) and that every pipe end is closed.
    with Popen(printer_cmd, stdout=PIPE, stderr=DEVNULL, encoding='utf8') as p1:
        with Popen(parser_cmd, stdin=p1.stdout, stdout=PIPE, stderr=DEVNULL,
                   encoding='utf8') as p2:
            # p2 holds its own copy of the pipe; drop ours so that p1 receives
            # SIGPIPE instead of blocking forever if p2 exits early.
            p1.stdout.close()
            s, _ = p2.communicate()
    if p2.returncode != 0:
        raise CommandFailed("cannot print s-expr for pretty-printed output (returncode={})"
                            .format(p2.returncode))
    if p1.returncode != 0:
        # The parser may well succeed on truncated or empty input; comparing
        # that against the original would report a misleading mismatch.
        raise CommandFailed("cannot pretty-print original file (returncode={})"
                            .format(p1.returncode))
    return sexpdata.loads(s, **loads_kwargs) if s else None

def is_location(sexp):
    """Tell whether ``sexp`` has the shape of a source location.

    A location is recognised purely by shape: an iterable of exactly four
    items whose types are ``str, int, int, int``. Exact types are compared,
    so a ``bool`` does not count as an ``int``.

    Returns:
        True for such a value, False for anything else, including values
        that cannot be iterated at all.
    """
    try:
        return [type(x) for x in sexp] == [str, int, int, int]
    except TypeError:
        # Not iterable (e.g. a number or None), hence not a location.  Only
        # this error is expected; anything else (Ctrl-C included) propagates.
        return False

# String attributes that keep_id_attr() drops from identifier attribute
# lists before two s-expressions are compared.
IGNORE_ID_ATTRS = [
    "W:unused_variable:N",
    "extraction:array_make",
    "extraction:array",
    "induction",
    "mlw:reference_var",
    "infer",
    "useraxiom",
    "W:non_conservative_extension:N",
    "model_trace:flag",
    "model_trace:first_val",
    "model_trace:sec_val",
    "model_trace:TEMP_NAME",
    "model",
    "W:unmodified_variable:N",
]

def keep_id_attr(at):
    """Decide whether an identifier attribute takes part in the comparison.

    Only attributes of the shape
    ``(ATstr ((attr_string "...") (attr_tag N)))`` whose string is listed in
    IGNORE_ID_ATTRS are dropped; everything else, including values that do
    not have that shape at all, is kept.

    Returns:
        False when the attribute must be ignored, True otherwise.
    """
    try:
        # (ATstr ((attr_string "...") (attr_tag N)))
        variant, fields = at
        field0, _ = fields  # must be exactly two fields
        field, value = field0
        is_string_attr = (variant.value() == 'ATstr'
                          and field.value() == 'attr_string')
    except (TypeError, ValueError, AttributeError):
        # Wrong arity, not a sequence, or heads without .value(): this is
        # not a string attribute, so keep it.  Unrelated errors propagate.
        return True
    if not is_string_attr:
        return True
    return value not in IGNORE_ID_ATTRS

def ignore_id_attrs(sexp):
    """Filter the attribute list of an ``(id_ats (at ...))`` field.

    Args:
        sexp: any s-expression.

    Returns:
        For a two-element ``(id_ats <attributes>)`` field, a new two-element
        list with the same head and only the attributes accepted by
        keep_id_attr(). Any other value is returned unchanged (the very same
        object).
    """
    try:
        # (id_ats (at ...))
        field, value = sexp
        if field.value() != 'id_ats':
            return sexp
        kept = [at for at in value if keep_id_attr(at)]
    except (TypeError, ValueError, AttributeError):
        # Not a two-element sequence, head without .value(), or attribute
        # list that cannot be iterated: nothing to filter.
        return sexp
    return [field, kept]

def is_field(sexp, field_name):
    """Test for an s-expression of the form ``(<field_name> _)``.

    Args:
        sexp: any s-expression.
        field_name: expected name of the head symbol.

    Returns:
        True when ``sexp`` has exactly two elements and the first one is an
        object whose ``value()`` equals ``field_name``; False otherwise.
    """
    try:
        return len(sexp) == 2 and sexp[0].value() == field_name
    except (TypeError, AttributeError, LookupError):
        # No length, not indexable, or a head that is not a symbol object.
        return False

# Pairs of node names that assert_equal() treats as the same construct,
# listed in both orders.
_EQUIVALENT_HEADS = (
    ("Tinfix", "Tinnfix"), ("Tinnfix", "Tinfix"),
    ("Tbinop", "Tbinnop"), ("Tbinnop", "Tbinop"),
    ("Einfix", "Einnfix"), ("Einnfix", "Einfix"),
)


def assert_equal(path, sexp0, sexp1):
    """Check that two s-expressions are equal up to irrelevant differences.

    The following differences are tolerated:
      * locations (see is_location) and ``attr_tag`` fields;
      * two NaN floats;
      * a ``PTparen``, ``Pparen`` or one-element ``Ptuple`` wrapper present
        on one side only;
      * ``Tinfix``/``Tinnfix``, ``Tbinop``/``Tbinnop`` and
        ``Einfix``/``Einnfix`` used for the same node;
      * identifier attributes filtered out by ignore_id_attrs.

    Args:
        path: list of child indices locating ``sexp0`` inside the top-level
            expression of the first side; used for error reporting only.
        sexp0: sub-expression of the first side.
        sexp1: sub-expression of the second side.

    Raises:
        NotEqual: at the first difference found, carrying the path to it and
            the two differing sub-expressions.
    """
    if sexp0 == sexp1:
        return
    if is_location(sexp0) and is_location(sexp1):
        return  # Don't bother with locations
    if is_field(sexp0, "attr_tag") and is_field(sexp1, "attr_tag"):
        return  # Don't bother with tags
    if (isinstance(sexp0, float) and math.isnan(sexp0)
            and isinstance(sexp1, float) and math.isnan(sexp1)):
        return  # nan != nan
    if not (isinstance(sexp0, list) and isinstance(sexp1, list)):
        raise NotEqual(path, sexp0, sexp1)

    # Index, within the list designated by `path`, of the current sexp0[0].
    # It becomes non-zero when the head symbol is sliced off below, so that
    # the paths handed to the recursive calls keep matching the original.
    offset = 0
    while True:  # Ignore additional parentheses
        try:
            head0 = sexp0[0].value()
            head1 = sexp1[0].value()
            if head0 == "PTparen" and head1 != "PTparen":
                inner = sexp0[1]
                path, sexp0, offset = path + [offset + 1], inner, 0
            elif head1 == "PTparen" and head0 != "PTparen":
                sexp1 = sexp1[1]
            elif head0 == "Pparen" and head1 != "Pparen":
                inner = sexp0[1][0][1]
                path, sexp0, offset = path + [offset + 1, 0, 1], inner, 0
            elif head1 == "Pparen" and head0 != "Pparen":
                sexp1 = sexp1[1][0][1]
            # NOTE(review): a one-element Ptuple is unwrapped with the same
            # [1][0][1] route as Pparen although its operand is length-checked
            # like a list of patterns -- confirm against the s-exp layout.
            elif head0 == "Ptuple" and len(sexp0[1]) == 1 and head1 != "Ptuple":
                inner = sexp0[1][0][1]
                path, sexp0, offset = path + [offset + 1, 0, 1], inner, 0
            elif head1 == "Ptuple" and len(sexp1[1]) == 1 and head0 != "Ptuple":
                sexp1 = sexp1[1][0][1]
            elif (head0, head1) in _EQUIVALENT_HEADS:
                # Same node under two names: compare the arguments only.
                sexp0 = sexp0[1:]
                sexp1 = sexp1[1:]
                offset += 1
            else:
                break
        except (AttributeError, IndexError, TypeError):
            # Head is not a symbol, a list is empty or too short, or an atom
            # was reached: nothing (more) to unwrap.  The comparison below
            # reports any real difference as NotEqual.
            break

    if not (isinstance(sexp0, list) and isinstance(sexp1, list)):
        # Unwrapping exposed an atom on at least one side; the scalar rules
        # at the top of the function settle the comparison.
        return assert_equal(path, sexp0, sexp1)

    sexp0 = ignore_id_attrs(sexp0)
    sexp1 = ignore_id_attrs(sexp1)
    if len(sexp0) != len(sexp1):
        raise NotEqual(path, sexp0, sexp1)
    for i, (s0, s1) in enumerate(zip(sexp0, sexp1), start=offset):
        assert_equal(path + [i], s0, s1)

def trace(path, sexp, sexp1):
    """Rebuild ``sexp`` with an error marker at the position given by ``path``.

    Args:
        path: list of child indices to follow inside ``sexp``.
        sexp: the expression to annotate.
        sexp1: the expression reported as found instead.

    Returns:
        When ``path`` is empty, ``(ERROR (EXPECTED sexp) (FOUND sexp1))``.
        Otherwise a copy of the list ``sexp`` in which the child at index
        ``path[0]`` is replaced by its own trace; an index that is out of
        range leaves the list as it is. None when ``path`` is not empty but
        ``sexp`` is not a list.
    """
    if path == []:
        expected = [sexpdata.Symbol("EXPECTED"), sexp]
        found = [sexpdata.Symbol("FOUND"), sexp1]
        return [sexpdata.Symbol("ERROR"), expected, found]
    if type(sexp) is not list:
        return None
    target, remaining = path[0], path[1:]
    return [trace(remaining, child, sexp1) if index == target else child
            for index, child in enumerate(sexp)]

def test(why3, filename):
    """Check that pretty-printing ``filename`` preserves its s-expression.

    Prints the file name followed by ``ok``, ``FAILED`` or
    ``COMMAND FAILED: <reason>``. When the module-level ``debug`` flag is set,
    a failed comparison additionally dumps the original s-expression with an
    error marker at the first difference.

    Args:
        why3: path or name of the why3 executable.
        filename: whyml file to check.

    Returns:
        True when both readings are considered equal, False otherwise.
    """
    print(f"  {filename}: ", end='', flush=True)
    try:
        sexp0 = read(why3, filename)
        sexp1 = print_and_read(why3, filename)
        assert_equal([], sexp0, sexp1)
    except NotEqual as mismatch:
        print("FAILED")
        if debug:
            annotated = trace(mismatch.path, sexp0, mismatch.sexp1)
            sexpdata.dump(annotated or "NO TRACE", sys.stdout)
        return False
    except CommandFailed as failure:
        print("COMMAND FAILED:", failure.msg)
        return False
    else:
        print("ok")
        return True

def main():
    """Command-line entry point: ``<script> WHY3 [FILE...]``.

    Runs test() on each file in turn and exits with status 0 when every file
    passes and 1 otherwise. Exits with status 2 and a usage message when the
    why3 executable is not given.
    """
    if len(sys.argv) < 2:
        prog = sys.argv[0] if sys.argv else "script"
        print("usage: {} WHY3 [FILE...]".format(prog), file=sys.stderr)
        sys.exit(2)
    why3, *files = sys.argv[1:]
    # all() consumes the generator lazily, so testing stops at the first
    # file that fails.
    res = all(test(why3, f) for f in files)
    # sys.exit rather than the builtin exit, which is only installed by the
    # site module and is meant for interactive use.
    sys.exit(0 if res else 1)

# Run the checks only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
