# Natural Language Toolkit: Models for first-order languages with lambda
#
# Copyright (C) 2001-2013 NLTK Project
# Author: Ewan Klein <ewan@inf.ed.ac.uk>,
# URL: <http://nltk.sourceforge.net>
# For license information, see LICENSE.TXT

#TODO:
    #- fix tracing
    #- fix iterator-based approach to existentials

"""
This module provides data structures for representing first-order
models.
"""
from __future__ import print_function, unicode_literals

from pprint import pformat
import inspect
import textwrap

from nltk.decorators import decorator # this used in code that is commented out
from nltk.compat import string_types, python_2_unicode_compatible

from nltk.sem.logic import (AbstractVariableExpression, AllExpression,
                            AndExpression, ApplicationExpression, EqualityExpression,
                            ExistsExpression, IffExpression, ImpExpression,
                            IndividualVariableExpression, LambdaExpression,
                            LogicParser, NegatedExpression, OrExpression,
                            Variable, is_indvar)


class Error(Exception):
    """Base class for the exceptions raised by this module."""


class Undefined(Error):
    """Raised when an expression or variable has no value in a model or assignment."""

def trace(f, *args, **kw):
    """
    Call ``f(*args, **kw)``, optionally printing its positional arguments.

    The positional arguments are matched against ``f``'s parameter names.
    If the one bound to a parameter called ``trace`` is truthy, every other
    positional (name, value) pair is printed before the call. A ``trace``
    passed as a keyword argument is not inspected.

    :param f: the function to call
    :return: whatever ``f`` returns
    """
    # inspect.getargspec was removed in Python 3.11; use getfullargspec when
    # it exists (Python 3) and fall back to getargspec on Python 2.
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    argnames = getspec(f)[0]
    bound = dict(zip(argnames, args))
    if bound.pop('trace', None):
        print()
        for item in bound.items():
            print("%s => %s" % item)
    return f(*args, **kw)

def is_rel(s):
    """
    Check whether a set represents a relation (of any arity).

    A relation is either the empty set or a set whose elements are all
    tuples of one and the same length.

    :param s: a set containing tuples of str elements
    :type s: set
    :rtype: bool
    :raises ValueError: if ``s`` is non-empty and its elements are not all
        tuples of the same length
    """
    # we have the empty relation, i.e. set()
    if not s:
        return True
    if not all(isinstance(el, tuple) for el in s):
        raise ValueError("Set %r contains elements that are not tuples" % (s,))
    # Compare the lengths themselves: max()/min() order tuples
    # lexicographically, so the max/min elements need not be the
    # longest/shortest ones.
    if len(set(len(el) for el in s)) == 1:
        return True
    raise ValueError("Set %r contains sequences of different lengths" % (s,))

def set2rel(s):
    """
    Convert a set containing individuals (strings or numbers) into a set of
    unary tuples. Any tuples of strings already in the set are passed through
    unchanged.

    For example:
      - set(['a', 'b']) => set([('a',), ('b',)])
      - set([3, 27]) => set([('3',), ('27',)])

    :type s: set
    :rtype: set of tuple of str
    """
    new = set()
    for elem in s:
        if isinstance(elem, string_types):
            new.add((elem,))
        elif isinstance(elem, int):
            # The trailing comma must sit outside str(...) so that a unary
            # tuple is added rather than a bare string.
            new.add((str(elem),))
        else:
            new.add(elem)
    return new

def arity(rel):
    """
    Check the arity of a relation.

    :type rel: set of tuples
    :return: the length of the tuples in ``rel``, or 0 for the empty relation
    :rtype: int
    """
    if len(rel) == 0:
        return 0
    # Inspect a single member instead of copying the whole set into a list.
    return len(next(iter(rel)))


@python_2_unicode_compatible
class Valuation(dict):
    """
    A mapping from non-logical constants to their model-theoretic values.

    Keys are strings naming the constants to be interpreted. Each value is
    an individual (a string), a truth value (a bool), or an n-ary relation.
    Sets given to the constructor are normalized by ``set2rel`` into sets
    of tuples of strings.

    Indexing with a constant that is not present raises ``Undefined``
    (not ``KeyError``).
    """
    def __init__(self, iter):
        """
        :param iter: a list of (symbol, value) pairs.
        :raises ValueError: if a value is not a string, bool or set.
        """
        dict.__init__(self)
        for (symbol, value) in iter:
            if isinstance(value, (string_types, bool)):
                self[symbol] = value
            elif isinstance(value, set):
                self[symbol] = set2rel(value)
            else:
                msg = textwrap.fill("Error in initializing Valuation. "
                                    "Unrecognized value for symbol '%s':\n%s" % (symbol, value), width=66)
                raise ValueError(msg)

    def __getitem__(self, key):
        if key not in self:
            raise Undefined("Unknown expression: '%s'" % key)
        return dict.__getitem__(self, key)

    def __str__(self):
        return pformat(self)

    @property
    def domain(self):
        """
        The set of individuals occurring in the values: string values
        themselves, plus every non-None member of each relation tuple.
        Boolean values contribute nothing.
        """
        individuals = set()
        for value in self.values():
            if isinstance(value, string_types):
                individuals.add(value)
            elif not isinstance(value, bool):
                for row in value:
                    individuals.update(item for item in row if item is not None)
        return individuals

    @property
    def symbols(self):
        """Sorted list of the constants (keys) this Valuation interprets."""
        return sorted(self)


@python_2_unicode_compatible
class Assignment(dict):
    """
    A dictionary which represents an assignment of values to variables.

    An assignment can only assign values from its domain.

    If an unknown expression *a* is passed to a model *M*\ 's
    interpretation function *i*, *i* will first check whether *M*\ 's
    valuation assigns an interpretation to *a* as a constant, and if
    this fails, *i* will delegate the interpretation of *a* to
    *g*. *g* only assigns values to individual variables (i.e.,
    members of the class ``IndividualVariableExpression`` in the ``logic``
    module). If a variable is not assigned a value by *g*, it will raise
    an ``Undefined`` exception.

    A variable *Assignment* is a mapping from individual variables to
    entities in the domain. Individual variables are usually indicated
    with the letters ``'x'``, ``'y'``, ``'w'`` and ``'z'``, optionally
    followed by an integer (e.g., ``'x0'``, ``'y332'``).  Assignments are
    created using the ``Assignment`` constructor, which also takes the
    domain as a parameter.

        >>> from nltk.sem.evaluate import Assignment
        >>> dom = set(['u1', 'u2', 'u3', 'u4'])
        >>> g3 = Assignment(dom, [('x', 'u1'), ('y', 'u2')])
        >>> g3
        {'y': 'u2', 'x': 'u1'}

    There is also a ``print`` format for assignments which uses a notation
    closer to that in logic textbooks:

        >>> print(g3)
        g[u2/y][u1/x]

    It is also possible to update an assignment using the ``add`` method:

        >>> dom = set(['u1', 'u2', 'u3', 'u4'])
        >>> g4 = Assignment(dom)
        >>> g4.add('x', 'u1')
        {'x': 'u1'}

    With no arguments, ``purge()`` is equivalent to ``clear()`` on a dictionary:

        >>> g4.purge()
        >>> g4
        {}

    :param domain: the domain of discourse
    :type domain: set
    :param assign: a list of (varname, value) associations
    :type assign: list
    """

    def __init__(self, domain, assign=None):
        dict.__init__(self)
        self.domain = domain
        if assign:
            for (var, val) in assign:
                assert val in self.domain,\
                       "'%s' is not in the domain: %s" % (val, self.domain)
                assert is_indvar(var),\
                       "Wrong format for an Individual Variable: '%s'" % var
                self[var] = val
        self._addvariant()

    def __getitem__(self, key):
        if key in self:
            return dict.__getitem__(self, key)
        else:
            raise Undefined("Not recognized as a variable: '%s'" % key)

    def copy(self):
        """
        Return a new ``Assignment`` over the same domain with the same
        variable bindings.
        """
        new = Assignment(self.domain)
        new.update(self)
        # dict.update() bypasses add(), so the pretty-printing variant must
        # be rebuilt here; otherwise str() of the copy shows no bindings.
        new._addvariant()
        return new

    def purge(self, var=None):
        """
        Remove one or all keys (i.e. logic variables) from an
        assignment, and update ``self.variant``.

        :param var: a Variable acting as a key for the assignment.
        :raises Undefined: if ``var`` is given but not bound.
        """
        if var:
            # Unknown variables raise Undefined (as __getitem__ does)
            # rather than a bare KeyError from ``del``.
            if var not in self:
                raise Undefined("Not recognized as a variable: '%s'" % var)
            del self[var]
        else:
            self.clear()
        self._addvariant()
        return None

    def __str__(self):
        """
        Pretty printing for assignments. {'x', 'u'} appears as 'g[u/x]'
        """
        gstring = "g"
        for (val, var) in self.variant:
            gstring += "[%s/%s]" % (val, var)
        return gstring

    def _addvariant(self):
        """
        Rebuild ``self.variant``: the list of (value, variable) pairs
        that ``__str__`` renders.
        """
        self.variant = [(val, var) for (var, val) in self.items()]
        return None

    def add(self, var, val):
        """
        Add a new variable-value pair to the assignment, and update
        ``self.variant``.

        :return: this assignment, so calls can be chained.
        """
        assert val in self.domain,\
               "%s is not in the domain %s" % (val, self.domain)
        assert is_indvar(var),\
               "Wrong format for an Individual Variable: '%s'" % var
        self[var] = val
        self._addvariant()
        return self


@python_2_unicode_compatible
class Model(object):
    """
    A first order model is a domain *D* of discourse and a valuation *V*.

    A domain *D* is a set, and a valuation *V* is a map that associates
    expressions with values in the model.
    The domain of *V* should be a subset of *D*.

    Construct a new ``Model``.

    :type domain: set
    :param domain: A set of entities representing the domain of discourse of the model.
    :type valuation: Valuation
    :param valuation: the valuation of the model.
    :raises Error: if the domain of ``valuation`` is not a subset of ``domain``.
    """

    def __init__(self, domain, valuation):
        assert isinstance(domain, set)
        self.domain = domain
        self.valuation = valuation
        if not domain.issuperset(valuation.domain):
            raise Error("The valuation domain, %s, must be a subset of the model's domain, %s"
                        % (valuation.domain, domain))

    def __repr__(self):
        return "(%r, %r)" % (self.domain, self.valuation)

    def __str__(self):
        return "Domain = %s,\nValuation = \n%s" % (self.domain, self.valuation)

    def evaluate(self, expr, g, trace=None):
        """
        Call the ``LogicParser`` to parse input expressions, and
        provide a handler for ``satisfy``
        that blocks further propagation of the ``Undefined`` error.

        :param expr: An ``Expression`` of ``logic``.
        :type g: Assignment
        :param g: an assignment to individual variables.
        :return: the value computed by ``satisfy``, or the string
            ``'Undefined'`` if an ``Undefined`` error was raised.
        """
        try:
            lp = LogicParser()
            parsed = lp.parse(expr)
            value = self.satisfy(parsed, g, trace=trace)
            if trace:
                print()
                print("'%s' evaluates to %s under M, %s" %  (expr, value, g))
            return value
        except Undefined:
            if trace:
                print()
                print("'%s' is undefined under M, %s" %  (expr, g))
            return 'Undefined'


    def satisfy(self, parsed, g, trace=None):
        """
        Recursive interpretation function for a formula of first-order logic.

        Raises an ``Undefined`` error (via ``i``) when an atomic expression
        is neither a symbol of the valuation nor a variable bound by ``g``.

        :return: a truth value for complex formulas; for a lambda
            abstraction, a dict mapping each domain element to the value of
            the body; for atomic expressions, the result of ``i``.

        :param parsed: An expression of ``logic``.
        :type g: Assignment
        :param g: an assignment to individual variables. It is not modified.
        """

        if isinstance(parsed, ApplicationExpression):
            function, arguments = parsed.uncurry()
            if isinstance(function, AbstractVariableExpression):
                #It's a predicate expression ("P(x,y)"), so used uncurried arguments
                funval = self.satisfy(function, g)
                argvals = tuple(self.satisfy(arg, g) for arg in arguments)
                return argvals in funval
            else:
                #It must be a lambda expression, so use curried form
                funval = self.satisfy(parsed.function, g)
                argval = self.satisfy(parsed.argument, g)
                return funval[argval]
        elif isinstance(parsed, NegatedExpression):
            return not self.satisfy(parsed.term, g)
        elif isinstance(parsed, AndExpression):
            return self.satisfy(parsed.first, g) and \
                   self.satisfy(parsed.second, g)
        elif isinstance(parsed, OrExpression):
            return self.satisfy(parsed.first, g) or \
                   self.satisfy(parsed.second, g)
        elif isinstance(parsed, ImpExpression):
            return (not self.satisfy(parsed.first, g)) or \
                   self.satisfy(parsed.second, g)
        elif isinstance(parsed, IffExpression):
            return self.satisfy(parsed.first, g) == \
                   self.satisfy(parsed.second, g)
        elif isinstance(parsed, EqualityExpression):
            return self.satisfy(parsed.first, g) == \
                   self.satisfy(parsed.second, g)
        elif isinstance(parsed, AllExpression):
            new_g = g.copy()
            for u in self.domain:
                new_g.add(parsed.variable.name, u)
                if not self.satisfy(parsed.term, new_g):
                    return False
            return True
        elif isinstance(parsed, ExistsExpression):
            new_g = g.copy()
            for u in self.domain:
                new_g.add(parsed.variable.name, u)
                if self.satisfy(parsed.term, new_g):
                    return True
            return False
        elif isinstance(parsed, LambdaExpression):
            cf = {}
            var = parsed.variable.name
            # Bind the lambda variable on a copy, as the quantifier branches
            # do: mutating g would leak the binding (to the last domain
            # element) back to the caller, e.g. into the argument of an
            # enclosing application.
            new_g = g.copy()
            for u in self.domain:
                new_g.add(var, u)
                val = self.satisfy(parsed.term, new_g)
                # NB the dict would be a lot smaller if we do this:
                # if val: cf[u] = val
                # But then need to deal with cases where f(a) should yield
                # a function rather than just False.
                cf[u] = val
            return cf
        else:
            return self.i(parsed, g, trace)

    #@decorator(trace_eval)
    def i(self, parsed, g, trace=False):
        """
        An interpretation function.

        Assuming that ``parsed`` is atomic:

        - if ``parsed`` is a non-logical constant, calls the valuation *V*
        - else if ``parsed`` is an individual variable, calls assignment *g*
        - else raises ``Undefined``.

        :param parsed: an ``Expression`` of ``logic``.
        :type g: Assignment
        :param g: an assignment to individual variables.
        :return: a semantic value
        """
        name = parsed.variable.name
        # If parsed is a propositional letter 'p', 'q', etc, it could be in the valuation
        # and also be an IndividualVariableExpression. We want to catch this first case.
        # So there is a procedural consequence to the ordering of clauses here.
        # (Direct dict membership; valuation.symbols would re-sort every key.)
        if name in self.valuation:
            return self.valuation[name]
        elif isinstance(parsed, IndividualVariableExpression):
            return g[name]
        else:
            raise Undefined("Can't find a value for %s" % parsed)

    def satisfiers(self, parsed, varex, g, trace=None, nesting=0):
        """
        Generate the entities from the model's domain that satisfy an open formula.

        :param parsed: an open formula
        :type parsed: Expression
        :param varex: the relevant free individual variable in ``parsed``.
        :type varex: VariableExpression or str
        :param g: a variable assignment
        :type g:  Assignment
        :param trace: if truthy, print progress; values above 1 are passed on
            (decremented) to ``satisfy``.
        :param nesting: indentation depth of the trace output.
        :return: a set of the entities that satisfy ``parsed``.
        :raises Undefined: if ``varex`` is not free in ``parsed``.
        """

        spacer = '   '
        indent = spacer + (spacer * nesting)

        if isinstance(varex, string_types):
            var = Variable(varex)
        else:
            var = varex

        if var not in parsed.free():
            raise Undefined("%s is not free in %s" % (var.name, parsed))

        if trace:
            print()
            print((spacer * nesting) + "Open formula is '%s' with assignment %s" % (parsed, g))

        # Trace level handed to satisfy(); invariant across the loop.
        if trace and trace > 1:
            lowtrace = trace - 1
        else:
            lowtrace = 0

        candidates = []
        for u in self.domain:
            new_g = g.copy()
            new_g.add(var.name, u)
            value = self.satisfy(parsed, new_g, lowtrace)

            if trace:
                print(indent + "(trying assignment %s)" % new_g)

            # parsed == False under g[u/var]?  (== rather than `not` so that
            # non-boolean values such as an empty dict still count as satisfiers)
            if value == False:
                if trace:
                    print(indent + "value of '%s' under %s is False" % (parsed, new_g))
            # so g[u/var] is a satisfying assignment
            else:
                candidates.append(u)
                if trace:
                    print(indent + "value of '%s' under %s is %s" % (parsed, new_g, value))

        return set(candidates)





#//////////////////////////////////////////////////////////////////////
# Demos
#//////////////////////////////////////////////////////////////////////
# Width, in characters, of the '*' separator lines printed by the demos.
mult = 30

# Demo 1: Propositional Logic
#################
def propdemo(trace=None):
    """
    Print the truth values of some propositional formulas in a small model.

    Builds the module globals ``val1``, ``dom1``, ``m1`` and ``g1``. The
    propositional letters are given boolean values and the domain is empty.

    :param trace: if truthy, evaluate each formula with tracing instead of
        printing a one-line summary.
    """
    global val1, dom1, m1, g1
    val1 = Valuation([('P', True), ('Q', True), ('R', False)])
    dom1 = set()
    m1 = Model(dom1, val1)
    g1 = Assignment(dom1)

    banner = '*' * mult
    print()
    print(banner)
    print("Propositional Formulas Demo")
    print(banner)
    print('(Propositional constants treated as nullary predicates)')
    print()
    print("Model m1:\n", m1)
    print(banner)

    sentences = [
        '(P & Q)',
        '(P & R)',
        '- P',
        '- R',
        '- - P',
        '- (P & R)',
        '(P | R)',
        '(R | P)',
        '(R | R)',
        '(- P | R)',
        '(P | - P)',
        '(P -> Q)',
        '(P -> R)',
        '(R -> P)',
        '(P <-> P)',
        '(R <-> R)',
        '(P <-> R)',
    ]

    for sentence in sentences:
        if not trace:
            print("The value of '%s' is: %s" % (sentence, m1.evaluate(sentence, g1)))
            continue
        print()
        m1.evaluate(sentence, g1, trace)

# Demo 2: FOL Model
#############

def folmodel(quiet=False, trace=None):
    """
    Build the example first-order model and, unless ``quiet``, print it
    together with some interpretations and predications.

    Sets the module globals ``val2``, ``v2``, ``dom2``, ``m2`` and ``g2``;
    ``foldemo`` and ``satdemo`` read ``m2`` and ``g2`` after calling this.

    :param quiet: if true, only build the model and print nothing.
    :param trace: accepted for a uniform demo signature; currently unused.
    """

    global val2, v2, dom2, m2, g2

    v2 = [('adam', 'b1'), ('betty', 'g1'), ('fido', 'd1'),
          ('girl', set(['g1', 'g2'])), ('boy', set(['b1', 'b2'])), ('dog', set(['d1'])),
          ('love', set([('b1', 'g1'), ('b2', 'g2'), ('g1', 'b1'), ('g2', 'b1')]))]
    val2 = Valuation(v2)
    dom2 = val2.domain
    m2 = Model(dom2, val2)
    g2 = Assignment(dom2, [('x', 'b1'), ('y', 'g2')])

    if quiet:
        return

    print()
    print('*' * mult)
    print("Models Demo")
    print("*" * mult)
    print("Model m2:\n", "-" * 14,"\n", m2)
    print("Variable assignment = ", g2)

    exprs = ['adam', 'boy', 'love', 'walks', 'x', 'y', 'z']
    lp = LogicParser()
    parsed_exprs = [lp.parse(e) for e in exprs]

    print()
    for parsed in parsed_exprs:
        try:
            print("The interpretation of '%s' in m2 is %s" % (parsed, m2.i(parsed, g2)))
        except Undefined:
            print("The interpretation of '%s' in m2 is Undefined" % parsed)

    # Argument lists must be tuples: ('adam') is just the string 'adam',
    # which would be iterated character by character.
    applications = [('boy', ('adam',)), ('walks', ('adam',)), ('love', ('adam', 'y')), ('love', ('y', 'adam'))]

    for (fun, args) in applications:
        arglist = ', '.join(args)
        try:
            funval = m2.i(lp.parse(fun), g2)
            argsval = tuple(m2.i(lp.parse(arg), g2) for arg in args)
            print("%s(%s) evaluates to %s" % (fun, arglist, argsval in funval))
        except Undefined:
            print("%s(%s) evaluates to Undefined" % (fun, arglist))

# Demo 3: FOL
#########

def foldemo(trace=None):
    """
    Interpretation of closed expressions in a first-order model.

    Builds the example model via ``folmodel(quiet=True)`` and evaluates each
    formula in ``m2`` under the assignment ``g2``.

    :param trace: if truthy, evaluate with tracing instead of printing a
        one-line summary per formula.
    """
    folmodel(quiet=True)

    print()
    print('*' * mult)
    print("FOL Formulas Demo")
    print('*' * mult)

    formulas = [
    'love (adam, betty)',
    '(adam = mia)',
    '\\x. (boy(x) | girl(x))',
    '\\x. boy(x)(adam)',
    '\\x y. love(x, y)',
    '\\x y. love(x, y)(adam)(betty)',
    '\\x y. love(x, y)(adam, betty)',
    '\\x y. (boy(x) & love(x, y))',
    '\\x. exists y. (boy(x) & love(x, y))',
    'exists z1. boy(z1)',
    'exists x. (boy(x) &  -(x = adam))',
    'exists x. (boy(x) & all y. love(y, x))',
    'all x. (boy(x) | girl(x))',
    'all x. (girl(x) -> exists y. boy(y) & love(x, y))',    # Every girl loves a boy.
    'exists x. (boy(x) & all y. (girl(y) -> love(y, x)))',  # There is a boy that every girl loves.
    'exists x. (boy(x) & all y. (girl(y) -> love(x, y)))',  # A boy loves every girl.
    'all x. (dog(x) -> - girl(x))',
    'exists x. exists y. (love(x, y) & love(x, y))'
    ]


    for fmla in formulas:
        # Start every formula from an empty assignment so that no bindings
        # carry over from the previous evaluation.
        g2.purge()
        if trace:
            m2.evaluate(fmla, g2, trace)
        else:
            print("The value of '%s' is: %s" % (fmla, m2.evaluate(fmla, g2)))


# Demo 4: Satisfaction
#############

def satdemo(trace=None):
    """
    Satisfiers of an open formula in a first order model.

    Builds the example model via ``folmodel(quiet=True)``, then prints, for
    each open formula, the set of values of ``x`` that satisfy it in ``m2``.

    :param trace: if truthy, print the model first and pass the trace level
        on to ``Model.satisfiers``.
    """

    print()
    print('*' * mult)
    print("Satisfiers Demo")
    print('*' * mult)

    folmodel(quiet=True)

    formulas = [
               'boy(x)',
               '(x = x)',
               '(boy(x) | girl(x))',
               '(boy(x) & girl(x))',
               'love(adam, x)',
               'love(x, adam)',
               '-(x = adam)',
               'exists z22. love(x, z22)',
               'exists y. love(y, x)',
               'all y. (girl(y) -> love(x, y))',
               'all y. (girl(y) -> love(y, x))',
               'all y. (girl(y) -> (boy(x) & love(y, x)))',
               '(boy(x) & all y. (girl(y) -> love(x, y)))',
               '(boy(x) & all y. (girl(y) -> love(y, x)))',
               '(boy(x) & exists y. (girl(y) & love(y, x)))',
               '(girl(x) -> dog(x))',
               'all y. (dog(y) -> (x = y))',
               'exists y. love(y, x)',
               'exists y. (love(adam, y) & love(y, x))'
                ]

    if trace:
        print(m2)

    # Echo each formula before parsing it (so a parse error points at the
    # offending formula) and keep the result: each formula is parsed once.
    lp = LogicParser()
    parsed = []
    for fmla in formulas:
        print(fmla)
        parsed.append(lp.parse(fmla))

    for p in parsed:
        g2.purge()
        print("The satisfiers of '%s' are: %s" % (p, m2.satisfiers(p, 'x', g2, trace)))


def demo(num=0, trace=None):
    """
    Run one or all of the demos.

     - num = 1: propositional logic demo
     - num = 2: first order model demo
     - num = 3: first order sentences demo
     - num = 4: satisfaction of open formulas demo
     - any other value: run all the demos

    :param trace: trace = 1, or trace = 2 for more verbose tracing
    """
    demos = {
        1: propdemo,
        2: folmodel,
        3: foldemo,
        4: satdemo}

    # Check membership explicitly instead of catching KeyError around the
    # call: a KeyError raised *inside* a demo must not be mistaken for an
    # unknown ``num`` and trigger a run of every demo.
    if num in demos:
        demos[num](trace=trace)
    else:
        for key in sorted(demos):
            demos[key](trace=trace)


if __name__ == "__main__":
    # When run as a script, show only the first-order model demo (num=2),
    # without tracing.
    demo(2, trace=0)
