1:- module(ccp_mcmc, [mc_evidence/4, mh_machine/4, gibbs_machine/5]).

/** <module> Gibbs and Metropolis-Hastings explanation samplers */

    5:- use_module(library(insist)).    6:- use_module(library(callutils),   [(*)/4]).    7:- use_module(library(listutils),   [enumerate/2]).    8:- use_module(library(math),        [neg/2, add/3, sub/3, exp/2, map_sum/4]).    9:- use_module(library(data/pair),   [is_pair/1, pair/3, fst/2, fsnd/3, snd/2]).   10:- use_module(library(plrand),      [log_partition_dirichlet/2]).   11:- use_module(library(machines),    [unfold/2, unfolder/3, mapper/3, scan0/4, (:>)/3, mean/2, op(600,yfx,:>)]).   12
   13:- use_module(effects,    [dist/2, uniform/2]).   14:- use_module(learn,      [converge/5, learn/4]).   15:- use_module(switches,   [ map_sum_sw/3, map_sum_sw/4, map_swc/4
   16                          , sw_expectations/2, sw_log_prob/3, sw_posteriors/3, sw_samples/2
   17                          ]).   18:- use_module(graph,      [ top_goal/1, top_value/2, tree_stats/2, sw_trees_stats/3, graph_fold/4
   19                          , graph_inside/3, prune_graph/4, igraph_sample_tree/3
   20                          ]).   21
   % bernoulli(+P1, -X)
   % Hands the two weighted outcomes (1-P1)-0 and P1-1 to dist/2 (effects
   % module), which binds X; presumably X is 1 with probability P1.
   bernoulli(P1, X) :-
      P0 is 1 - P1,
      Outcomes = [P0-0, P1-1],
      dist(Outcomes, X).
   23
   % mc_evidence(+Method, +Graph, +Prior, -Stream)
   % Unfolds a machine pipeline into Stream:
   %  1. learn/4 with vb(Prior) is iterated by converge/5 (rel(1e-6)) starting
   %     from Prior; VBProbs are the switch expectations of the result.
   %  2. LogJoint is the top value of the r(log,lse,add,cons) graph fold under
   %     VBProbs plus sw_log_prob(Prior, VBProbs).
   %  3. The machine selected by Method is piped through a mapper applying
   %     Mapper and then p_params_given_post(VBProbs), then `mean`, then a
   %     mapper applying log, neg and add(LogJoint) in that order.
   % NOTE(review): log/2 is not in the visible import lists; presumably it is
   % resolved by autoloading - confirm.
   mc_evidence(Method, Graph, Prior, Stream) :-
      converge(rel(1e-6), learn(vb(Prior), io(log), Graph), _, Prior, VBPost),
      sw_expectations(VBPost, VBProbs),
      % fold the graph first, then read the value at the top goal
      graph_fold(r(log,lse,add,cons), VBProbs, Graph, Folded),
      top_value(Folded, LogLik),
      sw_log_prob(Prior, VBProbs, LogPrior),
      add(LogLik, LogPrior, LogJoint),
      method_machine_mapper(Method, Prior, Machine, Mapper),
      unfold(call(Machine, Graph, Prior, VBProbs)
             :> mapper(p_params_given_post(VBProbs)*Mapper) :> mean
             :> mapper(add(LogJoint)*neg*log), Stream).
   33
   % p_params_given_post(+Probs, +Post, -P)
   % P is exp of the log-probability that sw_log_prob/3 yields for Probs
   % under Post. Probs comes first so it can be fixed in a partial application.
   p_params_given_post(Probs, Post, P) :-
      sw_log_prob(Post, Probs, LogP),
      P is exp(LogP).
   35
   % method_machine_mapper(+Method, +Prior, -Machine, -Mapper)
   % Table from sampling method to the machine-building closure and to the
   % closure that mc_evidence/4 composes in front of p_params_given_post/3.
   %   gibbs: states are passed through unchanged (=).
   %   mh:    mcs_counts/2 is applied to the state, then sw_posteriors(Prior).
   method_machine_mapper(gibbs, _Prior, Machine, Mapper) :-
      Machine = ccp_mcmc:gibbs_machine(posterior),
      Mapper  = (=).
   method_machine_mapper(mh, Prior, Machine, Mapper) :-
      Machine = ccp_mcmc:mh_machine,
      Mapper  = ccp_mcmc:sw_posteriors(Prior)*mcs_counts.
   38
   % gibbs_machine(+Rot, +Graph, +Prior, +P1, -M)
   % M is an unfolder machine over scan0(Cycle) with initial state P1.
   % Cycle is the composition chosen by rotation/5 from three stages:
   % sw_posteriors(Prior), gstep/4 on the inside graph of Graph, and
   % sw_samples. Rot (posterior | counts | params) picks the order.
   gibbs_machine(Rot, Graph, Prior, P1, M) :-
      graph_inside(Graph, Params, InsideGraph),
      rotation(Rot, sw_posteriors(Prior), gstep(Params,InsideGraph), sw_samples, Cycle),
      unfolder(scan0(Cycle), P1, M).
   43
   % rotation(+Start, :Post, :Step, :Sample, -Cycle)
   % Cycle is the three binary closures composed with `*` in one of their
   % three cyclic orders, selected by Start. Elsewhere in this file `*`
   % applies its right operand first, so e.g. `posterior` runs Sample, then
   % Step, then Post.
   :- meta_predicate rotation(+,2,2,2,-).
   rotation(posterior, Post, Step, Sample, Cycle) :- Cycle = Post*Step*Sample.
   rotation(counts,    Post, Step, Sample, Cycle) :- Cycle = Step*Sample*Post.
   rotation(params,    Post, Step, Sample, Cycle) :- Cycle = Sample*Post*Step.
   48
   % gstep(+P0, +IG, +P1, -Counts)
   % One explanation-sampling step. P0-IG is copied and the copy of P0 is
   % unified with P1, giving a fresh inside graph FreshIG (presumably P0 holds
   % the parameter variables shared with IG - confirm). Trees are then
   % sampled for the top goal and summarised by tree_stats/2 into Counts.
   gstep(P0, IG, P1, Counts) :-
      top_goal(Top),
      copy_term(P0-IG, P1-FreshIG),
      igraph_sample_tree(FreshIG, Top, Trees),
      tree_stats(Top-Trees, Counts).
   54
   % mh_machine(+Graph, +Prior, +Probs0, -M)
   % Builds an unfolder machine M over Monte Carlo states (see mcs_init/4).
   % The initial state comes from the second component of the top value of
   % a `best` fold of the (conjunctive) graph under Probs0. If the state has
   % no selectable keys the machine just repeats it; otherwise each step is
   % mc_step/7, using the gibbs variant when mcs_unit_counts/1 holds for the
   % initial state and the mh variant when it does not.
   mh_machine(Graph, Prior, Probs0, M) :-
      graph_as_conjunction(Graph, ConjGraph),
      graph_fold(best, Probs0, ConjGraph, Folded),
      top_value(Folded, TopVal),
      snd(TopVal, VTrees),
      maplist(fst, Prior, SWs),
      mcs_init(SWs, VTrees, Keys, State),
      (  Keys = []
      -> unfolder(scan0(=), State, M)
      ;  make_tree_sampler(ConjGraph, SampleGoal),
         (  mcs_unit_counts(State)
         -> Stepper = gibbs
         ;  Stepper = mh
         ),
         unfolder(scan0(mc_step(Stepper, Keys, SampleGoal, SWs, Prior)), State, M)
      ).
   65
   % graph_as_conjunction(+Graph, -ConjGraph)
   % If top_value/2 already gives a one-element list for Graph it is returned
   % unchanged. Otherwise the top goal's entry is re-keyed under the dummy
   % goal '^mcmc':dummy and a new top entry with the single explanation
   % [Dummy] is placed in front.
   graph_as_conjunction(Graph, Graph) :-
      top_value(Graph, [_]),
      !.
   graph_as_conjunction(Graph, [Top-[[Dummy]], Dummy-Expls | Rest]) :-
      top_goal(Top),
      Dummy = '^mcmc':dummy,
      select(Top-Expls, Graph, Rest).
   70
   % mc_sample(+SampleGoal, +SWs, +Probs, +T1, -T2)
   % Re-samples the tree for the goal stored in T1: SampleGoal is called with
   % Probs and that goal to produce a new tree, which mct_make/4 wraps
   % (together with its statistics over SWs) as T2.
   mc_sample(SampleGoal, SWs, Probs, T1, T2) :-
      mct_goal(T1, Goal),
      call(SampleGoal, Probs, Goal, NewTree),
      mct_make(SWs, Goal, NewTree, T2).
   74
   % make_tree_sampler(+G, -SampleGoal)
   % SampleGoal is a sample_goal/4 closure carrying one sub-inside-graph per
   % distinct factor of the top goal's single explanation (sort/2 removes
   % duplicate factors).
   make_tree_sampler(G, ccp_mcmc:sample_goal(SubIGs)) :-
      top_value(G, [Factors]),
      sort(Factors, Distinct),
      maplist(sub_igraph(G), Distinct, SubIGs).
   79
   % sub_igraph(+G, +Goal, -Entry)
   % Entry is Goal-(IG-Ps): G pruned with respect to Goal by prune_graph/4,
   % then passed to graph_inside/3 which yields Ps and IG.
   sub_igraph(G, Goal, Goal-(IG-Ps)) :-
      prune_graph(=, Goal, G, Pruned),
      graph_inside(Pruned, Ps, IG).
   83
   % sample_goal(+IGs, +PP, +Goal, -Trees)
   % Looks up the entry for Goal in IGs (as built by sub_igraph/3), takes a
   % fresh copy of it, binds the copy's parameters from the matching keys of
   % PP via param_subset/2, and samples Trees for Goal from the copy.
   sample_goal(IGs, PP, Goal, Trees) :-
      memberchk(Goal-(IG0-P0), IGs), % TODO: use rbtree for faster lookup
      copy_term(P0-IG0, Params-FreshIG),
      param_subset(Params, PP),
      igraph_sample_tree(FreshIG, Goal, Trees).
   89
   % param_subset(?Sub, +Super)
   % Unifies the value of every Key-Value pair in Sub with the value stored
   % under the same key in Super. Both lists are walked in step using
   % compare/3, so this relies on them being ordered by key in standard
   % order. Fails if a key of Sub is missing from Super.
   param_subset([], _).
   param_subset([K-V|Sub], [K0-V0|Super]) :-
      compare(Order, K, K0),
      psub_aux(Order, K, V, V0, Sub, Super).

   % psub_aux(+Order, +K, ?V, ?V0, ?Sub, +Super)
   % Order is the comparison of K with the key at the head of the superset.
   %   = : same key, so the two values are unified and both lists advance.
   %   > : K sorts after that key, so skip it and retry against the rest.
   %   < : no clause, so a missing key makes the whole match fail.
   psub_aux(=, _, V, V, Sub, Super) :-
      param_subset(Sub, Super).
   psub_aux(>, K, V, _, Sub, Super) :-
      param_subset([K-V|Sub], Super).
  100
  % mc_step(+Stepper, +Keys, +SampleGoal, +SWs, +Prior, +State1, -State2)
  % One transition of the Monte Carlo state. Both variants first build a
  % proposal with mc_propose/6: pick a random key, remove its tree from the
  % state, and re-sample a tree for the same goal under the expectations of
  % the posterior formed from Prior and the remaining counts.
  %   gibbs: the proposed tree is always installed.
  %   mh:    the proposal is accepted outright when the weight difference D
  %          is at least -1e-13 (tolerance for rounding), otherwise with the
  %          outcome of bernoulli(exp(D)); on rejection State2 is State1.
  mc_step(gibbs, Keys, SampleGoal, SWs, Prior, State1, State2) :-
     mc_propose(Keys, SampleGoal, SWs, Prior, State1,
                proposal(StateExK, _, _, _, TK_P)),
     mcs_rebuild(TK_P, StateExK, State2).

  mc_step(mh, Keys, SampleGoal, SWs, Prior, State1, State2) :-
     mc_propose(Keys, SampleGoal, SWs, Prior, State1,
                proposal(StateExK, PostExK, ProbsExK, TK_O, TK_P)),
     maplist(tree_acceptance_weight(PostExK, ProbsExK), [TK_O, TK_P], [W_O, W_P]),
     D is W_P-W_O,
     (  D >= -1e-13 -> Accept=1
     ;  call(bernoulli*exp, D, Accept)
     ),
     (  Accept=0 -> State2=State1
     ;  mcs_rebuild(TK_P, StateExK, State2)
     ).

  % mc_propose(+Keys, +SampleGoal, +SWs, +Prior, +State1, -Proposal)
  % Shared first half of both mc_step/7 variants. Proposal is
  % proposal(StateExK, PostExK, ProbsExK, TK_O, TK_P): the state with the
  % selected tree removed, the posterior and expectations computed without
  % that tree, the old tree and the newly sampled tree.
  mc_propose(Keys, SampleGoal, SWs, Prior, State1,
             proposal(StateExK, PostExK, ProbsExK, TK_O, TK_P)) :-
     mcs_random_select(Keys, TK_O, State1, StateExK),
     mcs_dcounts(StateExK, CountsExK),
     sw_posteriors(Prior, CountsExK, PostExK),
     sw_expectations(PostExK, ProbsExK),
     mc_sample(SampleGoal, SWs, ProbsExK, TK_O, TK_P).
  118
  % tree_acceptance_weight(+PostExTree, +PProbs, +Tree, -W)
  % W is LogZ - LogProposal, where LogZ sums log_partition_dirichlet over
  % the posterior obtained by combining PostExTree with the tree's counts,
  % and LogProposal sums Count*log(Prob) over PProbs and those same counts.
  % mc_step/7 compares this weight for the old and the proposed tree.
  tree_acceptance_weight(PostExTree, PProbs, Tree, W) :-
     mct_counts(Tree, TreeCounts),
     sw_posteriors(PostExTree, TreeCounts, PostWithTree),
     map_sum_sw(log_partition_dirichlet, PostWithTree, LogZ),
     map_sum_sw(map_sum(log_mul), PProbs, TreeCounts, LogProposal),
     W is LogZ - LogProposal.
  % log_mul(+Prob, +N, -X)
  % X is N*log(Prob), the log of Prob raised to the count N.
  % A zero count contributes exactly 0.0 whatever Prob is; without this
  % guard a zero probability paired with a zero count raised an evaluation
  % error from log(0) instead of contributing nothing to the sum.
  log_mul(Prob, N, X) :-
     (  N =:= 0
     -> X = 0.0
     ;  X is N*log(Prob)
     ).
  126
  % MCS: Monte Carlo state: rbtree to map K to tree, stash counts
  %
  % mcs_init(+SWs, +VTrees, -Ks, -State)
  % State is Totals-Map. Totals are the statistics of all of VTrees over SWs.
  % Map is an rbtree built from the elements of VTrees that satisfy
  % is_pair/1, each with its trees replaced by their statistics
  % (map_stats/3) and keyed by enumerate/2. Ks are the keys of Map.
  mcs_init(SWs, VTrees, Ks, Totals-Map) :-
     sw_trees_stats(SWs, VTrees, Totals),
     include(is_pair, VTrees, GoalTrees),
     map_stats(SWs, GoalTrees, GoalStats),
     enumerate(GoalStats, Indexed),
     list_to_rbtree(Indexed, Map),
     rb_keys(Map, Ks).
  132
  % mcs_random_select(+Ks, -Tree, +State, -StateExK)
  % Picks a key K from Ks with uniform/2, removes its entry Goal-Counts from
  % the map, and applies `sub` through map_swc/4 to the entry's counts and
  % the totals (presumably totals minus the entry's counts - confirm).
  % The result is dmhs(K, CountsExK, MapExK), the state without that entry.
  mcs_random_select(Ks, Goal-Counts, Totals-Map, dmhs(K,CountsExK,MapExK)) :-
     uniform(Ks, K),
     rb_delete(Map, K, Goal-Counts, MapExK),
     map_swc(sub, Counts, Totals, CountsExK).
  137
  % mcs_rebuild(+Tree, +StateExK, -State)
  % Inverse of mcs_random_select/4: combines the tree's counts with the
  % remaining counts (via sw_posteriors/3, presumably an element-wise sum)
  % to give the new totals, and re-inserts the tree under the removed key K.
  mcs_rebuild(Goal-Counts, dmhs(K,CountsExK,MapExK), Totals-Map) :-
     sw_posteriors(Counts, CountsExK, Totals),
     rb_insert_new(MapExK, K, Goal-Counts, Map).
  141
  % mcs_dcounts(+StateExK, -CountsExK)
  % Counts stored in a state from which one tree has been removed.
  mcs_dcounts(dmhs(_Key, CountsExK, _Map), CountsExK).

  % mcs_counts(+State, -Counts)
  % Total counts stored in a complete Totals-Map state.
  mcs_counts(Counts-_Map, Counts).
  % mcs_unit_counts(+State)
  % True when, for every entry in the state's map, each count list in that
  % entry sums to exactly 1. mh_machine/4 uses this to choose the gibbs
  % stepper over the mh stepper.
  % Uses sum_list/2 from library(lists); sumlist/2 is its deprecated alias.
  mcs_unit_counts(_-Map) :-
     forall(rb_in(_, _-GCs, Map),
            forall(member(_-C, GCs), sum_list(C, 1))).
  146
  % MCT: a Monte Carlo tree entry is Goal-Counts.

  % mct_goal(+Entry, -Goal): goal component of an entry.
  mct_goal(Goal-_Counts, Goal).

  % mct_make(+SWs, +Goal, +Trees, -Entry): pairs Goal with the statistics
  % that sw_trees_stats/3 computes for Trees over SWs.
  mct_make(SWs, Goal, Trees, Goal-Counts) :-
     sw_trees_stats(SWs, Trees, Counts).

  % mct_counts(+Entry, -Counts): counts component of an entry.
  mct_counts(_Goal-Counts, Counts).
  150
  % map_stats(+SWs, +Trees, -Stats)
  % Maps fsnd(sw_trees_stats(SWs)) over Trees; presumably this replaces the
  % second component of each Goal-Trees pair by its statistics over SWs.
  % The clause is now closed with its end token.
  map_stats(SWs, Trees, Stats) :-
     maplist(fsnd(sw_trees_stats(SWs)), Trees, Stats).