```
:- module(ccp_graph, [ graph_switches/2, prune_graph/4, top_value/2, top_goal/1
                     , graph_fold/4, graph_inside/3, igraph_sample_tree/3, igraph_entropy/3
                     , tree_stats/2, sw_trees_stats/3, accum_stats/3, graph_counts/6
                     ]).
```

# Inference and statistics on explanation hypergraphs

This module provides algorithms on explanation hypergraphs, based on the ideas of Sato (in PRISM), Klein and Manning [1] and Goodman [2]. Also provided are methods for sampling explanations from their posterior distribution [2] and for computing the entropy of the posterior distribution using a method which, to my knowledge, has not been published before.

[1] D. Klein and C. D. Manning. Parsing and hypergraphs. In New developments in parsing technology, pages 351–372. Springer, 2004.

[2] J. Goodman. Parsing inside-out. PhD thesis, Division of Engineering and Applied Sciences, Harvard University, 1998.

```
tree ---> goal - list(tree).
igraph == f_graph(pair,float,float,float)
       == list(pair(goal,weighted(list(weighted(list(weighted(factor)))))))
weighted(X) == pair(float,X).
```


```
:- use_module(library(dcg_pair)).
:- use_module(library(dcg_macros)).
:- use_module(library(lambdaki)).
:- use_module(library(typedef)).
:- use_module(library(math),        [stoch/3]).
:- use_module(library(listutils),   [cons//1, foldr/4, zip/3]).
:- use_module(library(callutils),   [mr/5, (*)/4, const/3, true1/1]).
:- use_module(library(data/pair),   [fst/2, snd/2]).
:- use_module(library(rbutils),     [rb_in/3, rb_add//2, rb_app//2, rb_get//2]).
:- use_module(library(autodiff2),   [back/1, deriv/3]).
:- use_module(library(lazymath),    [max/3, add/3, mul/3, exp/2, log/2, lse/2, stoch/2, log_stoch/2]).
:- use_module(effects,   [dist/3]).
:- use_module(switches,  [map_swc/3, map_swc/4]).

:- multifile sr_inj/4, sr_proj/5, sr_times/4, sr_plus/4, sr_unit/2, sr_zero/2, m_zero/2.

:- type graph == list(pair(goal, list(list(factor)))).
:- type counts_method ---> vit; io(scaling).
:- type scaling ---> lin; log.
```

top_value(+Pairs:list(pair(goal,A)), -X:A) is semidet
Extract the value associated with the goal `'^top':top` from a list of goal-value pairs. This can be applied to explanation graphs or the results of graph_fold/4.
```
top_value(Pairs, Top) :- memberchk(('^top':top)-Top, Pairs).
top_goal('^top':top).
```

prune_graph(+P:pred(+tcall(F,_,D),-D), +Top:goal, +G1:f_graph(F,A,B,C), -G2:f_graph(F,A,B,C)) is det
`f_graph(F,A,B,C) == list(pair(goal,tcall(F,A,list(tcall(F,B,list(tcall(F,C,factor)))))))`

Prune a graph or annotated graph to keep only goals reachable from a given top goal. With apologies, the type is quite complicated. The input and output graphs are lists of goals paired with annotated explanations. The type of an annotation is described by the type constructor F: `F(E,D)` is the type of a `D` annotated with an `E`. The first argument P knows how to strip off any type of annotation and return the `D`. This is how we dig down into the annotated explanations to find out which subgoals are referenced. For example, if `F = pair`, then P should be `snd`. If `F(E,D) = D` (i.e. no annotation), then P should be `(=)`. Since PlDoc won't accept higher-order type terms, we write `F(E,D)` as `tcall(F,E,D)`, where `tcall` is like `call` for types.

```
prune_graph(Mapper, Top, GL1, GL2) :-
   list_to_rbtree(GL1,G1),
   rb_empty(E), children(Top,Mapper,G1,E,G2),
   rb_visit(G2,GL2).

% SA 2017/10 - Temporarily weaken pattern matching to handle arbitrary factors
% children(_:=_, _, _) --> !.
% children(\_,   _, _) --> !.
children(Mod:Goal,  M, G) --> !,
   {call(M, Entry, Expls)},
   foldl(mr(M,foldl(mr(M,new_children(M,G)))),Expls).
children(_,  _, _) --> !.

new_children(M, G, F) -->
   rb_get(F,_) -> []; children(F,M,G).
```

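
The reachability idea behind prune_graph/4 can be sketched on an unannotated graph, the case where P would be `(=)`. This is only an illustration under stated assumptions: `reachable/3` and `visit/4` are hypothetical names, explicit accumulator arguments stand in for the module's DCG-threaded state, and only SWI-Prolog's library(rbtrees) and library(apply) are used.

```
:- use_module(library(rbtrees)).   % list_to_rbtree/2, rb_lookup/3, rb_insert_new/4, rb_visit/2
:- use_module(library(apply)).     % foldl/4
:- op(990, xfx, :=).               % already standard in SWI-Prolog; declared for portability

% reachable(+Top, +Graph, -Pruned): hypothetical helper that keeps only the
% Goal-Explanations entries of Graph that can be reached from Top.
reachable(Top, GraphList, Pruned) :-
   list_to_rbtree(GraphList, Graph),
   rb_empty(Seen0),
   visit(Graph, Top, Seen0, Seen),
   rb_visit(Seen, Pruned).

% visit(+Graph, +Factor, +Seen0, -Seen): depth-first traversal. Factors that
% are not keys of Graph (switch outcomes, for instance) are treated as leaves.
visit(Graph, Factor, Seen0, Seen) :-
   (  rb_lookup(Factor, _, Seen0)      -> Seen = Seen0     % already visited
   ;  rb_lookup(Factor, Expls, Graph)  ->
      rb_insert_new(Seen0, Factor, Expls, Seen1),
      foldl(foldl(visit(Graph)), Expls, Seen1, Seen)       % every factor of every explanation
   ;  Seen = Seen0                                         % leaf factor
   ).
```

Calling `reachable('^top':top, G, P)` on a graph containing a goal that no explanation refers to should return the graph without that entry, in standard order of keys.
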
graph_switches(+G:graph, -SWs:list(switch(_))) is det
Extract list of switches referenced in an explanation graph.
```
graph_switches(G,SWs) :- (setof(SW, graph_sw(G,SW), SWs) -> true; SWs=[]).
graph_sw(G,SW)        :- member(_-Es,G), member(E,Es), member(SW:=_,E).
```

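
As a usage illustration, the two clauses above are reproduced here so that the snippet runs on its own, together with a hypothetical toy graph (`example_graph/1`, the goal `user:s` and the switches `coin` and `die` are invented for the example).

```
:- use_module(library(lists)).     % member/2
:- op(990, xfx, :=).               % already standard in SWI-Prolog; declared for portability

% The two clauses from the module, repeated so the example is self-contained.
graph_switches(G,SWs) :- (setof(SW, graph_sw(G,SW), SWs) -> true; SWs=[]).
graph_sw(G,SW)        :- member(_-Es,G), member(E,Es), member(SW:=_,E).

% example_graph(-G): hypothetical graph with two explanations for user:s.
example_graph([ ('^top':top) - [[user:s]]
              , (user:s)     - [ [(coin := heads)]
                               , [(die := 3), (coin := tails)]
                               ]
              ]).
```

Because setof/3 sorts and removes duplicates, each switch appears once in the result however many explanations mention it, and a graph with no switches yields the empty list rather than failure.
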
graph_fold(+SR:sr(A,B,C,T), ?P:params(T), +G:graph, -R:list(pair(goal,W))) is det
Folds the semiring SR over the explanation graph G. Produces R, a list of pairs of goals in the original graph with the result of the fold for that goal. Different semirings produce different kinds of parsing analysis. The algebra is not strictly a semiring: the times and plus operators have types `A, B -> B` and `B, C -> C` respectively, which makes it easier to avoid unnecessary operations such as list appending.

An algebra of type `sr(A,B,C,T)` must provide 4 operators and 2 values:

```
inject  : T, factor -> A
times   : A, B -> B
plus    : B, C -> C
project : goal, C -> A, W
unit    : B
zero    : C
```

Semirings are extensible using multifile predicates sr_inj/4, sr_proj/5, sr_times/4, sr_plus/4, sr_unit/2 and sr_zero/2.

Available semirings in this module:

- `r(pred(+T, -A), pred(+C, -C), pred(+A, +B, -B), pred(+B, +C, -C))`: A term containing the operators in restricted forms as callable terms. The unit and zero for the times and plus operators respectively are looked up in m_zero/2.
- `best`: Finds the best single explanation for each goal. Parameters are assumed to be probabilities.
- `ann(sr(A, B, C, T))`: Annotates the original hypergraph with the results of any semiring analysis.
- `sr(A1, B1, C1, T)-sr(A1, B1, C1, T)`: For each goal, return a pair of results from any two semiring analyses.

Various standard analyses can be obtained by using the appropriate semiring:

- `r(=, =, mul, add)`: Inside algorithm from linear probabilities.
- `r(=, lse, add, cons)`: Inside algorithm with log-scaling from log probabilities.
- `r(=, =, mul, max)`: Viterbi probabilities.

```
graph_fold(SR, Params, Graph, GoalSums) :-
   rb_empty(E),
   foldl(sr_sum(SR), Graph, GoalSums, E, Map),
   fmap_sws(Map, SWs),
   maplist(fmap_collate_sw(sr_param(SR),true1,Map),SWs,Params).

sr_sum(SR, Goal-Expls, Goal-Result) -->
   fmap(Goal,Proj), {sr_zero(SR,Zero)},
   run_right(foldr(sr_add_prod(SR),Expls), Zero, Sum),
   {sr_proj(SR,Goal,Sum,Proj,Result)}.

sr_add_prod(SR, Expl) -->
   {sr_unit(SR,Unit)},
   run_right(foldr(sr_factor(SR), Expl), Unit, Prod) <\> sr_plus(SR,Prod).

sr_factor(SR, M:Head)  --> !, fmap(M:Head,X) <\> sr_times(SR,X).
sr_factor(SR, SW:=Val) --> !, fmap(SW:=Val,X) <\> sr_times(SR,X).
sr_factor(SR, \P)      --> {sr_inj(SR,P,\P,X)}, \> sr_times(SR,X).

sr_param(SR,F,X,P) :- sr_inj(SR,P,F,X), !.

% --------- semirings ---------
sr_inj(id,        _, F, F).
sr_inj(r(I,_,_,_),  P, _, X)   :- call(I,P,X).
sr_inj(best,      P, F, Q-F)   :- log(P,Q).
sr_inj(ann(SR),   P, F, Q-F)   :- sr_inj(SR,P,F,Q).
sr_inj(R1-R2,     P, F, Q1-Q2) :- sr_inj(R1,P,F,Q1), sr_inj(R2,P,F,Q2).

sr_proj(id,       G, Z,   G, Z).
sr_proj(r(_,P,_,_), _, X, Y, Y) :- call(P,X,Y).
sr_proj(best,     G, X-E, X-(G-E), X-E).
sr_proj(ann(SR),  G, X-Z, Y-G, W-Z)       :- sr_proj(SR,G,X,Y,W).
sr_proj(R1-R2,    G, X1-X2, Y1-Y2, Z1-Z2) :- sr_proj(R1,G,X1,Y1,Z1), sr_proj(R2,G,X2,Y2,Z2).

sr_plus(id,       Expl) --> cons(Expl).
sr_plus(r(_,_,_,O), X) --> call(O,X).
sr_plus(best,     X) --> v_max(X).
sr_plus(ann(SR),  X-Expl) --> sr_plus(SR,X) <\> cons(X-Expl).
sr_plus(R1-R2,    X1-X2) --> sr_plus(R1,X1) <\> sr_plus(R2,X2).

sr_times(id,       F)   --> cons(F).
sr_times(r(_,_,O,_), X) --> call(O,X).
sr_times(best,     X-F) --> add(X) <\> cons(F).
sr_times(ann(SR),  X-F) --> sr_times(SR,X) <\> cons(X-F).
sr_times(R1-R2,    X1-X2) --> sr_times(R1,X1) <\> sr_times(R2,X2).

sr_zero(id,       []).
sr_zero(r(_,_,_,O), I) :- m_zero(O,I).
sr_zero(best,     Z-_)   :- m_zero(max,Z).
sr_zero(ann(SR),  Z-[])  :- sr_zero(SR,Z).
sr_zero(R1-R2,    Z1-Z2) :- sr_zero(R1,Z1), sr_zero(R2,Z2).

sr_unit(id,       []).
sr_unit(r(_,_,O,_), I) :- m_zero(O,I).
sr_unit(best,     0.0-[]).
sr_unit(ann(SR),  U-[])  :- sr_unit(SR,U).
sr_unit(R1-R2,    U1-U2) :- sr_unit(R1,U1), sr_unit(R2,U2).

m_zero(mul,1.0).
m_zero(max,-inf).
m_zero(cons,[]).
m_zero(autodiff2:mul,1.0).
m_zero(autodiff2:max,-inf).

v_max(LX-X,LY-Y,Z) :- when(ground(LX-LY),(LX>=LY -> Z=LX-X; Z=LY-Y)).
```

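
To make the fold concrete, here is a minimal sketch of what the `r(=,=,mul,add)` analysis computes on a small acyclic graph: products over the factors of each explanation (starting from the unit 1.0) and sums over explanations (starting from the zero 0.0). It is not the module's implementation: it recurses directly without the factor map or lazy arithmetic, so shared subgoals are recomputed, and `toy_graph/1`, `toy_param/2`, `inside/3`, `add_expl/4` and `mul_factor/4` are hypothetical names with made-up parameters.

```
:- use_module(library(apply)).     % foldl/4
:- use_module(library(lists)).     % memberchk/2
:- op(990, xfx, :=).               % already standard in SWI-Prolog; declared for portability

% toy_graph(-Graph): hypothetical acyclic graph in the `graph` format.
toy_graph([ ('^top':top) - [[user:coin]]
          , (user:coin)  - [ [(flip := heads)]
                           , [(flip := tails), (flip := heads)]
                           ]
          ]).

% toy_param(+Factor, -Prob): assumed switch probabilities.
toy_param(flip := heads, 0.25).
toy_param(flip := tails, 0.75).

% inside(+Graph, +Factor, -P): linear-domain inside probability of a factor.
inside(Graph, M:Goal, P) :- !,
   memberchk((M:Goal)-Expls, Graph),
   foldl(add_expl(Graph), Expls, 0.0, P).        % plus over explanations, from zero
inside(_, SW := Val, P) :- toy_param(SW := Val, P).

add_expl(Graph, Expl, S0, S) :-
   foldl(mul_factor(Graph), Expl, 1.0, Prod),    % times over factors, from unit
   S is S0 + Prod.

mul_factor(Graph, Factor, P0, P) :-
   inside(Graph, Factor, PF),
   P is P0 * PF.
```

With these parameters the query `toy_graph(G), inside(G, '^top':top, P)` gives `P = 0.4375`, that is 0.25 + 0.75 × 0.25. Replacing the addition in `add_expl/4` with `max` would give the Viterbi probability instead.
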
graph_inside(+G:graph, ?P:sw_params, -IG:igraph) is det
```
graph_inside(Graph, Params, IGraph)  :-
   graph_fold(ann(r(=,=,mul,add)), Params, Graph, IGraph).
```

igraph_sample_tree(+IG:igraph, +H:goal, -Ts:list(tree)) is det
Uses the prob effect to sample a tree from a graph annotated with inside probabilities, as produced by graph_inside/3.
```
igraph_sample_tree(Graph, Head, Subtrees) :-
   memberchk(Head-(_-Expls), Graph), % Head should be unique in graph
   zip(Ps,Es,Expls), stoch(Ps,Ps1,_), dist(Ps1,Es,Expl),
   maplist(sample_subexpl_tree(Graph), Expl, Subtrees).

sample_subexpl_tree(G, _-(M:Goal), (M:Goal)-Tree) :- !, igraph_sample_tree(G, M:Goal, Tree).
sample_subexpl_tree(_, _-Factor,   Factor).
```

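
The step that chooses one explanation in proportion to its inside weight can be sketched as follows. The module delegates this to stoch/3 and the `dist` effect; the version below is a stand-in that takes the uniform random number as an explicit argument so that it is deterministic, and `pick_weighted/3` and `pick_/3` are hypothetical names.

```
:- use_module(library(lists)).     % sum_list/2
:- use_module(library(pairs)).     % pairs_keys/2

% pick_weighted(+U, +Weighted, -Item): Weighted is a non-empty list of
% Weight-Item pairs with non-negative weights and positive total; U in [0,1)
% selects Item with probability proportional to its weight.
pick_weighted(U, Weighted, Item) :-
   pairs_keys(Weighted, Ws),
   sum_list(Ws, Total), Total > 0,
   Target is U * Total,
   pick_(Weighted, Target, Item).

pick_([_-X], _, X) :- !.                   % last item absorbs rounding error
pick_([W-X|Rest], T, Item) :-
   (  T < W -> Item = X
   ;  T1 is T - W, pick_(Rest, T1, Item)
   ).
```

In use, `U` would come from a random source, for example `random(U)` from SWI-Prolog's library(random). Sampling a whole tree then amounts to repeating this choice for every goal factor in the chosen explanation, as igraph_sample_tree/3 does.
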
igraph_entropy(+S:scaling, +IG:igraph, -Es:list(pair(goal,float))) is det
Explanation entropies from annotated explanation graph.
```
igraph_entropy(Scaling, IGraph, GoalEntropies) :-
   rb_empty(E),
   foldl(goal_entropy(Scaling), IGraph, GoalEntropies, E, Map),
   rb_visit(Map, GoalEntropies).

goal_entropy(Scaling, Goal-(_ - WeightedExpls), Goal-Entropy) -->
   fmap(Goal,Entropy),
   {zip(Ws, Es, WeightedExpls), scaling_stoch(Scaling, Ws, Ps)},
   run_right(foldl(expl_entropy(Scaling),Ps,Es), 0.0, Entropy).

scaling_stoch(lin,X,Y) :- stoch(X,Y).
scaling_stoch(log,X,Y) :- log_stoch(X,Y).

expl_entropy(Scaling, Pe, Expl) -->
   {when(ground(FactorsEntropy-Pe), expl_entropy(Scaling, Pe, FactorsEntropy, ExplEntropy))},
   run_right(foldl(mr(snd,factor_entropy),Expl), 0.0, FactorsEntropy) <\> add(ExplEntropy).

expl_entropy(lin, Pe, HFactors, HE) :- HE is Pe*(HFactors - log(Pe)).
expl_entropy(log, Pe, HFactors, HE) :- HE is exp(Pe)*(HFactors - Pe).

factor_entropy(_) --> [].
```

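
The per-goal recursion in the `lin` clause of expl_entropy/4 can be read as follows: each explanation contributes its posterior probability times the sum of its factors' entropies minus its own log probability. A minimal sketch with eager arithmetic and already-normalised probabilities follows; `expls_entropy/2` and `add_expl_entropy/3` are hypothetical names, and the module's lazy evaluation and factor map are left out.

```
:- use_module(library(apply)).     % foldl/4
:- use_module(library(lists)).     % sum_list/2

% expls_entropy(+Expls, -H): Expls is a list of P-SubHs pairs, where P > 0 is
% the normalised probability of an explanation and SubHs lists the entropies
% of its factors. H is the resulting entropy in nats.
expls_entropy(Expls, H) :- foldl(add_expl_entropy, Expls, 0.0, H).

add_expl_entropy(P-SubHs, H0, H) :-
   sum_list(SubHs, HFactors),
   H is H0 + P*(HFactors - log(P)).      % same form as the lin clause above
```

For two equally likely explanations with no uncertain factors, `expls_entropy([0.5-[], 0.5-[]], H)` gives the entropy of a fair coin, log 2.
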
graph_counts(+Meth:counts_method, +PSc:scaling, +G:graph, P:sw_params, C:sw_params, LP:float) is det
Compute expected switch counts C from explanation graph G with switch parameters P. Uses automatic differentiation of the expression for the log of the inside probability LP of the graph. Params can be unbound - binding them later triggers the computations required to yield numerical values in the result.
`counts_method ---> io(scaling); vit.`
```
graph_counts(Method, PSc, Graph, Params, Eta, LogProb) :-
   method_scaling_semiring(Method, ISc, SR, ToLogProb),
   graph_fold(SR, P0, Graph, IG), autodiff2:expand_wsums,
   call(ToLogProb*top_value, IG, LogProb),
   scaling_log_params(ISc, PSc, P0, Params0, LogP0),
   map_swc(deriv(LogProb), LogP0, Eta),
   back(LogProb), Params=Params0.

method_scaling_semiring(vit,     log, r(=,=,autodiff2:add,autodiff2:max), =).
method_scaling_semiring(io(lin), lin, r(=,=,autodiff2:mul,autodiff2:add), autodiff2:log).
method_scaling_semiring(io(log), log, r(=,autodiff2:lse, autodiff2:add,cons), =).
method_scaling_semiring(io(log_wsum), log, r(=,autodiff2:lse, autodiff2:add_to_wsum,cons), =).

scaling_log_params(lin, lin, P0,    P0,    LogP0) :- map_swc(autodiff2:llog, P0, LogP0).
scaling_log_params(lin, log, P0,    LogP0, LogP0) :- map_swc(autodiff2:exp, LogP0, P0).
scaling_log_params(log, lin, LogP0, P0,    LogP0) :- map_swc(autodiff2:log, P0, LogP0).
scaling_log_params(log, log, LogP0, LogP0, LogP0).
```

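
The identity being used is that the expected count of a switch outcome equals the derivative of the log inside probability with respect to the log of that outcome's parameter. The sketch below checks this numerically on a hypothetical two-explanation graph whose inside probability is `PH + PT*PH`. It substitutes central finite differences for the module's reverse-mode differentiation via autodiff2, and `log_inside/3` and `expected_counts/4` are invented names.

```
% log_inside(+LH, +LT, -LP): log inside probability of the toy graph as a
% function of the log-parameters LH (heads) and LT (tails).
log_inside(LH, LT, LP) :- LP is log(exp(LH) + exp(LT)*exp(LH)).

% expected_counts(+PH, +PT, -EH, -ET): approximate d(LP)/d(LH) and d(LP)/d(LT)
% by central finite differences; these are the expected counts of the two
% outcomes under the posterior over explanations.
expected_counts(PH, PT, EH, ET) :-
   LH is log(PH), LT is log(PT),
   D = 1.0e-6,
   LHp is LH + D, LHm is LH - D,
   LTp is LT + D, LTm is LT - D,
   log_inside(LHp, LT, A1), log_inside(LHm, LT, A0),
   log_inside(LH, LTp, B1), log_inside(LH, LTm, B0),
   EH is (A1 - A0) / (2*D),
   ET is (B1 - B0) / (2*D).
```

With `PH = 0.25` and `PT = 0.75`, the heads count comes out close to 1, since both explanations use heads exactly once, and the tails count close to 3/7, the posterior probability of the explanation that uses tails.
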
accum_stats(+Acc:pred(fmap(int),fmap(int)), +GSWs:pred(fmap(int),list(switch(_))), -Stats:sw_params) is det
```
:- meta_predicate accum_stats(//,2,-).
accum_stats(Acc, GetSWs, Stats) :-
   rb_empty(C0),
   call_dcg(Acc,C0,C1), call(GetSWs,C1,SWs),
   maplist(fmap_collate_sw(right,=(0),C1),SWs,Stats).
```

tree_stats(+T:tree, -C:sw_params) is det
```
tree_stats(Tree,Counts) :- accum_stats(tree_stats(Tree), fmap_sws, Counts).

sw_trees_stats(SWs,Trees,Stats) :- accum_stats(tree_stats(_-Trees),const(SWs),Stats).

tree_stats(_-Subtrees) --> foldl(subtree_stats,Subtrees).
subtree_stats(_-Trees) --> foldl(subtree_stats,Trees).
subtree_stats(SW:=Val) --> rb_app(SW:=Val,succ) -> []; rb_add(SW:=Val,1).
subtree_stats(\_)      --> [].
right(_,X,X).

% --- Factor-value map, used internally --------------
% =| fmap(A) == rbtree(factor, A).
```

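
The counting pass over a tree can be sketched with plain accumulator arguments and SWI-Prolog's library(rbtrees) in place of the `rb_app//2` and `rb_add//2` helpers from rbutils. `count_tree/2` and `count_subtree/3` are hypothetical names, and unlike tree_stats/2 this version returns raw `(SW:=Val)-Count` pairs rather than counts collated per switch.

```
:- use_module(library(rbtrees)).   % rb_empty/1, rb_apply/4, rb_insert_new/4, rb_visit/2
:- use_module(library(apply)).     % foldl/4
:- op(990, xfx, :=).               % already standard in SWI-Prolog; declared for portability

% count_tree(+Tree, -Counts): Tree is Goal-Subtrees; Counts is a list of
% (SW:=Val)-N pairs in standard order of keys.
count_tree(_-Subtrees, Counts) :-
   rb_empty(C0),
   foldl(count_subtree, Subtrees, C0, C),
   rb_visit(C, Counts).

count_subtree(_-Trees, C0, C) :- !,              % subgoal node: recurse
   foldl(count_subtree, Trees, C0, C).
count_subtree(SW:=Val, C0, C) :- !,              % switch outcome: increment or start at 1
   (  rb_apply(C0, SW:=Val, succ, C1) -> C = C1
   ;  rb_insert_new(C0, SW:=Val, 1, C)
   ).
count_subtree(_, C, C).                          % any other factor: ignored
```

A tree in which `flip := heads` occurs twice and `flip := tails` once yields `[(flip:=heads)-2, (flip:=tails)-1]`.
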
fmap(+F:factor, ?X:A, +M1:fmap(A), -M2:fmap(A)) is det
Unify X with the value stored under F in M1 if present; otherwise add X to the map under F.
`fmap(X,Y) --> rb_add(X,Y) -> []; rb_get(X,Y).`
fmap_sws(+M:fmap(A), -SWs:list(switch(_))) is det
Collect sorted list of switches from keys in M.
```
fmap_sws(Map,SWs) :- rb_fold(emit_if_sw,Map,SWs1,[]), sort(SWs1,SWs).
emit_if_sw(F-_) --> {F=(SW:=_)} -> [SW]; [].
```

fmap_collate_sw(+Conv:pred(factor,+A,-B), +Def:pred(-B), +M:fmap(A), +SW:switch(_), -SWX:pair(switch(_),list(B))) is det
Collect parameter data for each value of a switch. Either the data is extracted from the map and converted using Conv, or created using Def.
```
:- meta_predicate fmap_collate_sw(3,1,+,+,?).
fmap_collate_sw(Conv,Def,Map,SW,SW-XX) :-
   call(SW,_,Vals,[]), maplist(sw_val_or_default(Conv,Def,Map,SW),Vals,XX).

sw_val_or_default(Conv,Def,Map,SW,Val,X) :-
   rb_lookup(SW:=Val, P, Map) -> call(Conv,SW:=Val,P,X); call(Def,X).
```