1:- module(ccp_graph, [ graph_switches/2, prune_graph/4, top_value/2, top_goal/1
    2                     , graph_fold/4, graph_inside/3, igraph_sample_tree/3, igraph_entropy/3
    3                     , tree_stats/2, sw_trees_stats/3, accum_stats/3, graph_counts/6
    4                     ]).

Inference and statistics on explanation hypergraphs

This module provides algorithms on explanation hypergraphs, based on the ideas of Sato (in PRISM), Klein and Manning [1] and Goodman [2]. Also provided are methods for sampling explanations from their posterior distribution [2] and computing the entropy of the posterior distribution using a method which, to my knowledge, has not been published before.
[1] D. Klein and C. D. Manning. Parsing and hypergraphs. In New developments in parsing technology, pages 351–372. Springer, 2004.
[2] J. Goodman. Parsing inside-out. PhD thesis, Division of Engineering and Applied Sciences, Harvard University, 1998.
tree ---> goal - list(tree).
igraph == f_graph(pair,float,float,float)
       == list(pair(goal,weighted(list(weighted(list(weighted(factor)))))))
weighted(X) == pair(float,X).


   27:- use_module(library(dcg_pair)).   28:- use_module(library(dcg_macros)).   29:- use_module(library(lambdaki)).   30:- use_module(library(typedef)).   31:- use_module(library(math),        [stoch/3]).   32:- use_module(library(listutils),   [cons//1, foldr/4, zip/3]).   33:- use_module(library(callutils),   [mr/5, (*)/4, const/3, true1/1]).   34:- use_module(library(data/pair),   [fst/2, snd/2]).   35:- use_module(library(rbutils),     [rb_in/3, rb_add//2, rb_app//2, rb_get//2]).   36:- use_module(library(autodiff2),   [back/1, deriv/3]).   37:- use_module(library(lazymath),    [max/3, add/3, mul/3, exp/2, log/2, lse/2, stoch/2, log_stoch/2]).   38:- use_module(effects,   [dist/3]).   39:- use_module(switches,  [map_swc/3, map_swc/4]).   40
   41:- multifile sr_inj/4, sr_proj/5, sr_times/4, sr_plus/4, sr_unit/2, sr_zero/2, m_zero/2.   42
   43:- type graph == list(pair(goal, list(list(factor)))).
   44:- type counts_method ---> vit; io(scaling).
   45:- type scaling ---> lin; log.
 top_value(+Pairs:list(pair(goal,A)), -X:A) is semidet
Extract the value associated with the goal '^top':top from a list of goal-value pairs. This can be applied to explanation graphs or the results of graph_fold/4.
   % top_value(+GoalValues, -Value): Value is paired with '^top':top in GoalValues.
top_value(GoalValues, Value) :-
   TopGoal = '^top':top,
   memberchk(TopGoal-Value, GoalValues).
 prune_graph(+P:pred(+tcall(F,_,D),-D), +Top:goal, +G1:f_graph(F,A,B,C), -G2:f_graph(F,A,B,C)) is det
f_graph(F,A,B,C) == list(pair(goal,tcall(F,A,list(tcall(F,B,list(tcall(F,C,factor)))))))

Prune a graph or annotated graph to keep only goals reachable from a given top goal. With apologies, the type is quite complicated. The input and output graphs are lists of goals paired with annotated explanations. The type of an annotation is described by the type constructor F: F(E,D) is the type of a D annotated with an E. The first argument P knows how to strip off any type of annotation and return the D. This is how we dig down into the annotated explanations to find out which subgoals are referenced. For example, if F = pair, then P should be snd. If F(E,D) = D (i.e. no annotation), then P should be (=). Since PlDoc won't accept higher-order type terms, we write F(E,D) as tcall(F,E,D), where tcall is like call for types.

   % prune_graph(+Mapper, +Top, +GoalList0, -GoalList): keep only goals reachable
% from Top. The rbtree of visited goals, listed in key order, is the result.
prune_graph(Mapper, Top, GoalList0, GoalList) :-
   list_to_rbtree(GoalList0, Graph),
   rb_empty(Visited0),
   children(Top, Mapper, Graph, Visited0, Visited),
   rb_visit(Visited, GoalList).
   % children(+Factor, +Mapper, +Graph)// : threads an rbtree of visited goals.
% SA 2017/10 - Temporarily weaken pattern matching to handle arbitrary factors
% children(_:=_, _, _) --> !.
% children(\_,   _, _) --> !.
children(Mod:Goal, Mapper, Graph) --> !,
   { rb_lookup(Mod:Goal, Entry, Graph) },
   rb_add(Mod:Goal, Entry),
   { call(Mapper, Entry, Explanations) },
   % Mapper strips annotations both from each explanation and from each of
   % its factors before the factors are visited.
   foldl(mr(Mapper, foldl(mr(Mapper, new_children(Mapper, Graph)))), Explanations).
% Anything that is not a Mod:Goal term contributes no children.
children(_, _, _) --> !.
   % new_children(+Mapper, +Graph, +Factor)// : visit Factor unless already visited.
new_children(Mapper, Graph, Factor) -->
   (  rb_get(Factor, _)
   -> []
   ;  children(Factor, Mapper, Graph)
   ).
 graph_switches(+G:graph, -SWs:list(switch(_))) is det
Extract list of switches referenced in an explanation graph.
   % graph_switches(+Graph, -Switches): sorted, duplicate-free list of switches
% occurring in SW:=Value factors of Graph; [] when there are none.
graph_switches(Graph, Switches) :-
   (  setof(SW, graph_sw(Graph, SW), Switches)
   -> true
   ;  Switches = []
   ).

% graph_sw(+Graph, -SW): SW occurs as SW:=_ in some explanation of Graph.
graph_sw(Graph, SW) :-
   member(_-Explanations, Graph),
   member(Explanation, Explanations),
   member(SW:=_, Explanation).
 graph_fold(+SR:sr(A,B,C,T), ?P:params(T), +G:graph, -R:list(pair(goal,W))) is det
Folds the semiring SR over the explanation graph G. Produces R, a list of pairs of goals in the original graph with the result of the fold for that goal. Different semirings can produce many kinds of parsing analyses. The algebra is not strictly a semiring: the times and plus operators have types A, B -> B and B, C -> C respectively, because this makes it easier to avoid unnecessary operations like list appending.

An algebra of type sr(A,B,C,T) must provide 4 operators and 2 values:

inject  : T, factor -> A
times   : A, B -> B
plus    : B, C -> C
project : goal, C -> A, W
unit    : B
zero    : C

Semirings are extensible using multifile predicates sr_inj/4, sr_proj/5, sr_times/4, sr_plus/4, sr_unit/2 and sr_zero/2.

Available semirings in this module:

r(pred(+T, -A), pred(+C, -A), pred(+A, +B, -B), pred(+B, +C, -C))
A term containing the operators in restricted forms as callable terms. The unit and zero for the times and plus operators respectively are looked up in m_zero/2.
best
Finds the best single explanation for each goal, working internally with log probabilities. Parameters are assumed to be probabilities.
ann(sr(A, B, C, T))
Annotates the original hypergraph with the results of any semiring analysis.
sr(A1, B1, C1, T)-sr(A2, B2, C2, T)
For each goal, return a pair of results from any two semiring analyses.

Various standard analyses can be obtained by using the appropriate semiring:

r(=, =, mul, add)
Inside algorithm from linear probabilities.
r(=, lse, add, cons)
Inside algorithm with log-scaling from log probabilities.
r(=, =, mul, max)
Viterbi probabilities.
  % graph_fold(+SR, ?Params, +Graph, -GoalSums): fold semiring SR over Graph.
graph_fold(SR, Params, Graph, GoalSums) :-
   rb_empty(Map0),
   % A single pass over the goals; Map shares one value per goal or switch-value key.
   foldl(sr_sum(SR), Graph, GoalSums, Map0, Map),
   fmap_sws(Map, Switches),
   % Link the parameters of every value of every referenced switch to the map.
   maplist(fmap_collate_sw(sr_param(SR), true1, Map), Switches, Params).
  % sr_sum(+SR, +GoalExpls, -GoalResult)// : sum the products of all explanations
% of one goal, then project the sum; the projection is shared under Goal.
sr_sum(SR, Goal-Explanations, Goal-Result) -->
   fmap(Goal, Projected),
   { sr_zero(SR, Zero) },
   run_right(foldr(sr_add_prod(SR), Explanations), Zero, Sum),
   { sr_proj(SR, Goal, Sum, Projected, Result) }.
  % sr_add_prod(+SR, +Explanation)// : multiply the factors of Explanation,
% starting from the semiring unit, and add the product to the running sum.
sr_add_prod(SR, Explanation) -->
   { sr_unit(SR, One) },
   run_right(foldr(sr_factor(SR), Explanation), One, Product)
      <\> sr_plus(SR, Product).
  % sr_factor(+SR, +Factor)// : goal and switch-value factors get their value from
% (or create it in) the map; \P factors are injected directly from P.
sr_factor(SR, Mod:Head) --> !,
   fmap(Mod:Head, Value) <\> sr_times(SR, Value).
sr_factor(SR, Switch:=Val) --> !,
   fmap(Switch:=Val, Value) <\> sr_times(SR, Value).
sr_factor(SR, \Prob) -->
   { sr_inj(SR, Prob, \Prob, Value) },
   \> sr_times(SR, Value).
  % sr_param(+SR, +Factor, ?MapValue, ?Param): MapValue is the injection of Param
% for Factor under SR; succeeds at most once.
sr_param(SR, Factor, MapValue, Param) :-
   sr_inj(SR, Param, Factor, MapValue),
   !.
  % --------- semirings ---------
% sr_inj(+SR, +Param, +Factor, -Injected): inject a parameter into SR's A type.
sr_inj(id,           _,     Factor, Factor).
sr_inj(r(Inj,_,_,_), Param, _,      X)           :- call(Inj, Param, X).
sr_inj(best,         Param, Factor, LogP-Factor) :- log(Param, LogP).
sr_inj(ann(SR),      Param, Factor, X-Factor)    :- sr_inj(SR, Param, Factor, X).
sr_inj(SR1-SR2,      Param, Factor, X1-X2) :-
   sr_inj(SR1, Param, Factor, X1),
   sr_inj(SR2, Param, Factor, X2).
  % sr_proj(+SR, +Goal, +Sum, -Projected, -Result): turn a goal's sum into the
% value seen by goals that use it (Projected) and the reported Result.
sr_proj(id,            Goal, Sum,     Goal,           Sum).
sr_proj(r(_,Proj,_,_), _,    Sum,     Value,          Value) :- call(Proj, Sum, Value).
sr_proj(best,          Goal, LP-Expl, LP-(Goal-Expl), LP-Expl).
sr_proj(ann(SR),       Goal, S-Expls, V-Goal,         R-Expls) :-
   sr_proj(SR, Goal, S, V, R).
sr_proj(SR1-SR2,       Goal, S1-S2,   V1-V2,          R1-R2) :-
   sr_proj(SR1, Goal, S1, V1, R1),
   sr_proj(SR2, Goal, S2, V2, R2).
  % sr_plus(+SR, +Product)// : add one explanation's Product to the running sum.
sr_plus(id,           Expl)   --> cons(Expl).
sr_plus(r(_,_,_,Add), X)      --> call(Add, X).
sr_plus(best,         X)      --> v_max(X).
sr_plus(ann(SR),      X-Expl) --> sr_plus(SR, X) <\> cons(X-Expl).
sr_plus(SR1-SR2,      X1-X2)  --> sr_plus(SR1, X1) <\> sr_plus(SR2, X2).
  % sr_times(+SR, +Value)// : multiply one factor's Value into the running product.
sr_times(id,           Factor)   --> cons(Factor).
sr_times(r(_,_,Mul,_), X)        --> call(Mul, X).
sr_times(best,         X-Factor) --> add(X) <\> cons(Factor).
sr_times(ann(SR),      X-Factor) --> sr_times(SR, X) <\> cons(X-Factor).
sr_times(SR1-SR2,      X1-X2)    --> sr_times(SR1, X1) <\> sr_times(SR2, X2).
  % sr_zero(+SR, -Zero): starting value of the sum over a goal's explanations.
sr_zero(id,           []).
sr_zero(r(_,_,_,Add), Zero)    :- m_zero(Add, Zero).
sr_zero(best,         Zero-_)  :- m_zero(max, Zero).
sr_zero(ann(SR),      Zero-[]) :- sr_zero(SR, Zero).
sr_zero(SR1-SR2,      Z1-Z2) :-
   sr_zero(SR1, Z1),
   sr_zero(SR2, Z2).
  % sr_unit(+SR, -Unit): starting value of the product over an explanation's factors.
sr_unit(id,           []).
sr_unit(r(_,_,Mul,_), Unit)    :- m_zero(Mul, Unit).
sr_unit(best,         0.0-[]).
sr_unit(ann(SR),      Unit-[]) :- sr_unit(SR, Unit).
sr_unit(SR1-SR2,      U1-U2) :-
   sr_unit(SR1, U1),
   sr_unit(SR2, U2).
  % v_max(+Candidate, +Current, -Best): keep the Score-Item pair with the larger
% score (ties keep Candidate); the choice is delayed until both scores are ground.
v_max(ScoreX-X, ScoreY-Y, Best) :-
   when(ground(ScoreX-ScoreY),
        (  ScoreX >= ScoreY
        -> Best = ScoreX-X
        ;  Best = ScoreY-Y
        )).
 graph_inside(+G:graph, ?P:sw_params, -IG:igraph) is det
  % graph_inside(+Graph, ?Params, -IGraph): annotate Graph with linear-scale inside
% values by folding the annotating semiring ann(r(=,=,mul,add)).
graph_inside(Graph, Params, IGraph) :-
   InsideSR = ann(r(=, =, mul, add)),
   graph_fold(InsideSR, Params, Graph, IGraph).
 igraph_sample_tree(+IG:igraph, +H:goal, -Ts:list(tree)) is det
Uses the probabilistic choice effect dist/3 to sample a tree from a graph annotated with inside probabilities, as produced by graph_inside/3.
  % igraph_sample_tree(+IGraph, +Head, -Subtrees): normalise the weights of Head's
% explanations with stoch/3, choose one with dist/3, and sample its subgoals.
igraph_sample_tree(IGraph, Head, Subtrees) :-
   memberchk(Head-(_-WeightedExpls), IGraph), % Head should be unique in graph
   zip(Weights, Expls, WeightedExpls),
   stoch(Weights, Probs, _),
   dist(Probs, Expls, Chosen),
   maplist(sample_subexpl_tree(IGraph), Chosen, Subtrees).
  % sample_subexpl_tree(+IGraph, +WeightedFactor, -Tree): a subgoal Mod:Goal becomes
% (Mod:Goal)-Subtrees by recursive sampling; other factors are returned as is.
sample_subexpl_tree(IGraph, _-(Mod:Goal), (Mod:Goal)-Subtrees) :- !,
   igraph_sample_tree(IGraph, Mod:Goal, Subtrees).
sample_subexpl_tree(_, _-Factor, Factor).
 igraph_entropy(+S:scaling, +IG:igraph, -Es:list(pair(goal,float))) is det
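Computes, for each goal in an annotated explanation graph, the entropy of the distribution over its explanation trees. Scaling states whether the explanation weights in IG are linear (lin) or log (log) probabilities.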
  % igraph_entropy(+Scaling, +IGraph, -GoalEntropies)
% GoalEntropies lists Goal-Entropy pairs in the same order as IGraph. The rbtree
% threaded through the fold only lets a goal's entropy refer to the (possibly
% not yet computed) entropies of its subgoals, and is discarded afterwards.
% Do not re-unify GoalEntropies with rb_visit/2 of the final tree: that list is
% sorted by goal and may hold extra subgoal keys. It would only match when
% IGraph happens to be sorted in standard order.
igraph_entropy(Scaling, IGraph, GoalEntropies) :-
   rb_empty(Map0),
   foldl(goal_entropy(Scaling), IGraph, GoalEntropies, Map0, _Map).
  % goal_entropy(+Scaling, +GoalEntry, -GoalEntropy)// : a goal's entropy (shared
% under Goal in the map) is the sum of its normalised explanations' contributions.
goal_entropy(Scaling, Goal-(_-WeightedExpls), Goal-Entropy) -->
   fmap(Goal, Entropy),
   {  zip(Weights, Expls, WeightedExpls),
      scaling_stoch(Scaling, Weights, Probs)
   },
   run_right(foldl(expl_entropy(Scaling), Probs, Expls), 0.0, Entropy).
  % scaling_stoch(+Scaling, +Weights, -Probs): normalise weights given on a linear
% (lin) or log (log) scale; the results stay on the same scale.
scaling_stoch(lin, Weights, Probs) :- stoch(Weights, Probs).
scaling_stoch(log, Weights, Probs) :- log_stoch(Weights, Probs).
  % expl_entropy(+Scaling, +Pe, +Expl)// : add the entropy contribution of one
% explanation with probability Pe (on Scaling) to the running goal entropy.
expl_entropy(Scaling, Pe, Expl) -->
   % The contribution depends on subgoal entropies that may not be known yet,
   % so it is delayed until both they and Pe are ground.
   {  when(ground(FactorsEntropy-Pe),
           expl_entropy(Scaling, Pe, FactorsEntropy, Contribution))
   },
   run_right(foldl(mr(snd, factor_entropy), Expl), 0.0, FactorsEntropy)
      <\> add(Contribution).
  % expl_entropy(+Scaling, +Pe, +HFactors, -HE): entropy contribution
% HE = Pe*(HFactors - log Pe) of one explanation with probability Pe, where
% HFactors is the summed entropy of its subgoal factors. With Scaling = log,
% Pe is given as a log probability. A zero-probability explanation contributes
% 0.0 (the usual 0*log 0 = 0 convention). Without this guard it would raise an
% arithmetic error from log(0) or 0*inf.
expl_entropy(lin, Pe, HFactors, HE) :-
   (  Pe =:= 0
   -> HE = 0.0
   ;  HE is Pe*(HFactors - log(Pe))
   ).
expl_entropy(log, LogPe, HFactors, HE) :-
   Pe is exp(LogPe),
   (  Pe =:= 0
   -> HE = 0.0
   ;  HE is Pe*(HFactors - LogPe)
   ).
  % factor_entropy(+Factor)// : add the (shared) entropy of a subgoal factor to the
% accumulator; any other kind of factor contributes nothing.
factor_entropy(Mod:Head) --> !,
   fmap(Mod:Head, H) <\> add(H).
factor_entropy(_) --> [].
 graph_counts(+Meth:counts_method, +PSc:scaling, +G:graph, P:sw_params, C:sw_params, LP:float) is det
Compute expected switch counts C from explanation graph G with switch parameters P. Uses automatic differentiation of the expression for the log of the inside probability LP of the graph. Params can be unbound - binding them later triggers the computations required to yield numerical values in the result.
counts_method ---> io(scaling); vit.
  % graph_counts(+Method, +PSc, +Graph, ?Params, -Eta, -LogProb)
% Keep the goal order. The graph is folded with fresh internal parameters P0,
% and deriv/3 and back/1 are applied while they are still unbound. Only then
% is the caller's Params unified with the internal Params0.
graph_counts(Method, PSc, Graph, Params, Eta, LogProb) :-
   method_scaling_semiring(Method, InsideScaling, SR, ToLogProb),
   graph_fold(SR, P0, Graph, IGraph),
   autodiff2:expand_wsums,
   call(ToLogProb*top_value, IGraph, LogProb),
   scaling_log_params(InsideScaling, PSc, P0, Params0, LogP0),
   map_swc(deriv(LogProb), LogP0, Eta),
   back(LogProb),
   Params = Params0.
  % method_scaling_semiring(?Method, -InsideScaling, -Semiring, -ToLogProb):
% for each counting method, the scale used by the fold, the semiring, and the
% predicate mapping the top goal's value to a log probability.
method_scaling_semiring(vit,
                        log, r(=, =, autodiff2:add, autodiff2:max), =).
method_scaling_semiring(io(lin),
                        lin, r(=, =, autodiff2:mul, autodiff2:add), autodiff2:log).
method_scaling_semiring(io(log),
                        log, r(=, autodiff2:lse, autodiff2:add, cons), =).
method_scaling_semiring(io(log_wsum),
                        log, r(=, autodiff2:lse, autodiff2:add_to_wsum, cons), =).
  % scaling_log_params(+InsideScaling, +ParamScaling, ?P0, ?Params, ?LogP0):
% relates the fold's parameters P0 (on InsideScaling), the caller's Params
% (on ParamScaling) and the log parameters LogP0.
scaling_log_params(lin, lin, P,    P,    LogP) :- map_swc(autodiff2:llog, P, LogP).
scaling_log_params(lin, log, P,    LogP, LogP) :- map_swc(autodiff2:exp, LogP, P).
scaling_log_params(log, lin, LogP, P,    LogP) :- map_swc(autodiff2:log, P, LogP).
scaling_log_params(log, log, LogP, LogP, LogP).
 accum_stats(+Acc:pred(fmap(int),fmap(int)), +GSWs:pred(fmap(int),list(switch(_))), -Stats:sw_params) is det
  % accum_stats(+Acc, +GetSWs, -Stats): run Acc// over an empty count map, then
% collate the counts (0 where missing) for every switch produced by GetSWs.
:- meta_predicate accum_stats(//,2,-).
accum_stats(Acc, GetSWs, Stats) :-
   rb_empty(Counts0),
   call_dcg(Acc, Counts0, Counts),
   call(GetSWs, Counts, Switches),
   maplist(fmap_collate_sw(right, =(0), Counts), Switches, Stats).
 tree_stats(+T:tree, -C:sw_params) is det
  % tree_stats(+Tree, -Counts): switch-value counts for one tree, reported for the
% switches found among the counted keys.
tree_stats(Tree, Counts) :-
   accum_stats(tree_stats(Tree), fmap_sws, Counts).

% sw_trees_stats(+SWs, +Trees, -Stats): counts accumulated over a list of trees,
% reported for the given switches SWs.
sw_trees_stats(SWs, Trees, Stats) :-
   accum_stats(tree_stats(_-Trees), const(SWs), Stats).
  % tree_stats(+Tree)// : accumulate switch counts over the subtrees of Tree.
tree_stats(_-Subtrees) --> foldl(subtree_stats, Subtrees).

% subtree_stats(+Subtree)// : goal nodes recurse into their children; a switch
% value updates its count with succ (or inserts 1 if new); \P factors are ignored.
subtree_stats(_-Children) --> foldl(subtree_stats, Children).
subtree_stats(SW:=Val) -->
   (  rb_app(SW:=Val, succ)
   -> []
   ;  rb_add(SW:=Val, 1)
   ).
subtree_stats(\_) --> [].
  285% --- Factor-value map, used internally --------------
  286% =| fmap(A) == rbtree(factor, A).
 fmap(+F:factor, ?X:A, +M1:fmap(A), -M2:fmap(A)) is det
Unify X with the value stored under key F in M1 if present; otherwise add F-X, giving M2.
  % fmap(+Key, ?Value)// : insert Key-Value if Key is new, else unify Value with
% the value already stored under Key.
fmap(Key, Value) -->
   (  rb_add(Key, Value)
   -> []
   ;  rb_get(Key, Value)
   ).
 fmap_sws(+M:fmap(A), -SWs:list(switch(_))) is det
Collect sorted list of switches from keys in M.
  % fmap_sws(+Map, -Switches): sorted, duplicate-free switches SW with a key SW:=_ in Map.
fmap_sws(Map, Switches) :-
   rb_fold(emit_if_sw, Map, Found, []),
   sort(Found, Switches).

% emit_if_sw(+Key-Value)// : emit SW for a switch-value key SW:=_, nothing otherwise.
emit_if_sw((SW:=_)-_) --> !, [SW].
emit_if_sw(_-_)       --> [].
 fmap_collate_sw(+Conv:pred(factor,+A,-B), +Def:pred(-B), +M:fmap(A), +SW:switch(_), -SWX:pair(switch(_),list(B))) is det
Collect parameter data for each value of a switch. Either the data is extracted from the map and converted using Conv, or created using Def.
  % fmap_collate_sw(+Conv, +Def, +Map, +SW, -SWXs): SWXs = SW-Xs with one entry per
% value of SW (obtained via call(SW,_,Values,[])), each converted from Map by Conv
% or, when absent from Map, produced by Def.
:- meta_predicate fmap_collate_sw(3,1,+,+,?).
fmap_collate_sw(Conv, Def, Map, SW, SW-Xs) :-
   call(SW, _, Values, []),
   maplist(sw_val_or_default(Conv, Def, Map, SW), Values, Xs).
  304sw_val_or_default(Conv,Def,Map,SW,Val,X) :-
  305   rb_lookup(SW:=Val, P, Map) -> call(Conv,SW:=Val,P,X); call(Def,X)