1:- module(ccp_learn, [converge/5, learn/4]).

/** <module> Expectation-maximisation, variational Bayes and deterministic annealing.
*/


    6:- use_module(library(data/pair),  [snd/2]).    7:- use_module(library(callutils),  [(*)/4, true2/2]).    8:- use_module(library(plrand),     [log_partition_dirichlet/2]).    9:- use_module(library(autodiff2),  [esc/3, add/3, mul/3, pow/3, max/3, gather_ops/3]).   10:- use_module(library(clambda),    [clambda/2, run_lambda_compiler/1]).   11:- use_module(library(plflow),     [ops_body/4, sub/3, stoch/2, mean_log_dir/2, log_part_dir/2, log_prob_dir/3]).   12:- use_module(graph,    [graph_counts/6]).   13:- use_module(switches, [map_sw/3, map_swc/3, map_swc/4, map_sum_sw/3]).   14
   % mul_add(+K, +X, +Y, -Z): build Z = X + K*Y with the autodiff2
% primitives mul/3 and add/3. Note that K scales the THIRD argument (Y),
% not the second, which is what lets callers partially apply mul_add(K, X).
mul_add(K, X, Y, Z) :-
    mul(K, Y, Scaled),
    add(X, Scaled, Z).
   % map_sum_(+P, +Xs, -Sum) and map_sum_(+P, +Xs, +Ys, -Sum):
% apply P element-wise over one (or two) lists, then sum the per-element
% results by escaping to sum_list/2 through autodiff2's esc/3.
map_sum_(P, Xs, Sum) :-
    maplist(P, Xs, Terms),
    esc(sum_list, Terms, [Sum]).
map_sum_(P, Xs, Ys, Sum) :-
    maplist(P, Xs, Ys, Terms),
    esc(sum_list, Terms, [Sum]).
   % map_sum_sw_(+P, +SWXs, -Sum) and map_sum_sw_(+P, +SWXs, +SWYs, -Sum):
% like map_sum_/3,4 but over lists of SW-Value pairs.
%  - 3-arg form: each pair goes through the callutils composition P*snd,
%    so P only sees the value part (snd/2 from library(data/pair)).
%  - 4-arg form: pairs are matched position-wise by f2sw1/4, which requires
%    the same switch name in both lists and hands only the values to P.
map_sum_sw_(P, SWXs, Sum) :-
    maplist(P*snd, SWXs, Terms),
    esc(sum_list, Terms, [Sum]).
map_sum_sw_(P, SWXs, SWYs, Sum) :-
    maplist(f2sw1(P), SWXs, SWYs, Terms),
    esc(sum_list, Terms, [Sum]).

% f2sw1(+P, +SW-X, +SW-Y, -Z): call P(X, Y, Z) on the values of two pairs;
% fails (by head unification) if the two switch names differ.
f2sw1(P, SW-X, SW-Y, Z) :-
    call(P, X, Y, Z).
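%! learn(+Method:learn_method, +Stats:stats_method, +ITemp:number, +G:graph, -U:learner) is det.
%! learn(+Method:learn_method, +Stats:stats_method, +G:graph, -U:learner) is det.
%
%  Get an update predicate U for one of several EM-based parameter learning
%  methods on graph G. ITemp is an inverse temperature for deterministic
%  annealing; learn/4 behaves as learn/5 would with ITemp=1.0.
%  NOTE(review): only learn/4 is exported by this module. It calls the
%  internal worker learn/7 with ITemp=1.0, and no learn/5 clause is visible
%  here. Confirm whether the learn/5 mode line above is stale.
%
%  Types:
%  ==
%  learn_method ---> ml; map(sw_params); vb(sw_params).
%  stats_method ---> io(scaling); vit.
%  scaling      ---> lin; log.
%  learner == pred(-float, +sw_params, -sw_params).
%  ==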
   % learn/4 implementation: build the symbolic objective and update for
% Method at inverse temperature 1.0 (via learn/7), gather the operations
% that connect the input-parameter variables to [Obj|output variables],
% turn them into a clause body, and compile it with clambda/2. The
% resulting Step is invoked as call(Step, Obj, ParamsIn, ParamsOut).
learn(Method, StatsMethod, Graph, Step) :-
    learn(Method, StatsMethod, 1.0, Graph, Obj, ParamsIn, ParamsOut),
    term_variables(ParamsIn, InVars),
    term_variables(ParamsOut, OutVars),
    gather_ops(InVars, [Obj|OutVars], Ops),
    length(Ops, NumOps),
    debug(learn(setup), 'Compiled ~d operations.', [NumOps]),
    ops_body(InVars, [Obj|OutVars], Ops, Body),
    clambda(lambda([Obj,ParamsIn,ParamsOut], Body), Step).
   % learn(+Method, +Stats, +ITemp, +Graph, -Obj, -ParamsIn, -ParamsOut)
% Internal worker behind learn/4. It relates the (symbolic) input
% parameters ParamsIn to the updated parameters ParamsOut and to the
% objective value Obj for the given Method. Goal order matters (e.g. the vb
% clause fixes parameter shapes before vb_helper/6), so it is preserved.

% Maximum likelihood: Obj is the log-likelihood from graph_counts/6, whose
% branch probabilities are ParamsIn raised to the power ITemp; ParamsOut
% is stoch/2 applied to the expected counts.
learn(ml, StatsMethod, InvTemp, G, LogLik, ParamsIn, ParamsOut) :-
    once(graph_counts(StatsMethod, lin, G, TemperedP, Counts, LogLik)),
    map_swc(pow(InvTemp), ParamsIn, TemperedP),
    map_sw(stoch, Counts, ParamsOut).

% MAP with Dirichlet prior Prior: Obj = LogLik + ITemp * log prior density
% of ParamsIn; ParamsOut is the mode of the posterior Counts + Prior.
learn(map(Prior), StatsMethod, InvTemp, G, Obj, ParamsIn, ParamsOut) :-
    once(graph_counts(StatsMethod, lin, G, TemperedP, Counts, LogLik)),
    map_sum_sw_(log_prob_dir, Prior, ParamsIn, LogPrior),
    map_swc(add, Counts, Prior, Posterior),
    % mode (presumably: subtract 1, clip at 0.0, then normalise with stoch/2)
    map_sw(stoch*maplist(max(0.0)*add(-1.0)), Posterior, ParamsOut),
    mul_add(InvTemp, LogLik, LogPrior, Obj),
    map_swc(pow(InvTemp), ParamsIn, TemperedP).

% Variational Bayes with Dirichlet prior Prior: parameters are Dirichlet
% hyperparameters; Obj = Div - LogLik, where Div comes from vb_helper/6.
learn(vb(Prior), StatsMethod, InvTemp, G, Obj, AlphaIn, AlphaOut) :-
    % establish same shape as prior
    maplist(map_swc(true2, Prior), [AlphaIn, TemperedLogP]),
    % effective prior: (1.0-ITemp) + ITemp*Prior, element-wise
    map_swc(mul_add(InvTemp, 1.0-InvTemp), Prior, EffPrior),
    map_sum_sw(log_partition_dirichlet, Prior, LogZPrior),
    vb_helper(InvTemp, LogZPrior, EffPrior, AlphaIn, TemperedLogP, Div),
    once(graph_counts(StatsMethod, log, G, TemperedLogP, Counts, LogLik)),
    map_swc(mul_add(InvTemp), EffPrior, Counts, AlphaOut),
    sub(Div, LogLik, Obj).
   % vb_helper(+ITemp, +LogZPrior, +EffPrior, +Alpha, -TemperedLogP, -Div)
% Builds the quantities the VB clause of learn/7 needs from the variational
% Dirichlet hyperparameters Alpha:
%  - TemperedLogP: ITemp times mean_log_dir/2 of Alpha (fed to
%    graph_counts/6 with log scaling by the caller);
%  - Div: combination of LogZPrior, log_part_dir/2 of Alpha, and the sum of
%    mean-log values weighted by (EffPrior - Alpha). NOTE(review): this
%    looks like a KL-type divergence term; confirm against the derivation.
vb_helper(InvTemp, LogZPrior, EffPrior, Alpha, TemperedLogP, Div) :-
    map_sw(mean_log_dir, Alpha, MeanLog),
    map_swc(sub, EffPrior, Alpha, Delta),
    map_swc(mul(InvTemp), MeanLog, TemperedLogP),
    map_sum_sw_(log_part_dir, Alpha, LogZAlpha),
    map_sum_sw_(map_sum_(mul), MeanLog, Delta, Cross),
    call(sub(LogZAlpha)*mul_add(InvTemp, Cross), LogZPrior, Div).
%! converge(+C:convergence, +L:pred(-learner), -LL:list(float), +P1:sw_params, -P2:sw_params) is det.
%
%  Use L to create a predicate that performs one step of learning, then
%  iterate it until convergence, starting from parameters P1 and ending
%  with P2. The history of objective function values is returned in LL.
%  Convergence C is of type:
%  ==
%  convergence ---> abs(float); rel(float).
%  ==
   % Setup is called with one extra argument (the compiled step predicate).
:- meta_predicate converge(+,1,-,+,-).

% converge/5 implementation: inside run_lambda_compiler/1, build the step
% predicate (timed), take one step from Params0 to get the first objective
% value Obj0, then hand over to converge_x/6 (also timed) for the rest.
converge(Test, Setup, [Obj0|History], Params0, ParamsFinal) :-
    debug(learn(setup), 'converge: Setting up...', []),
    run_lambda_compiler((
        time(call(Setup, Step)),
        call(Step, Obj0, Params0, Params1),
        time(converge_x(Test, Step, Obj0, History, Params1, ParamsFinal))
    )).
   % converge_x(+Test, +Step, +PrevObj, -Objs, +Params, -ParamsFinal)
% Take one step from Params, prepending its objective value to Objs. Stop
% (closing the list) as soon as converged/3 accepts the previous and
% current values; the final parameters are those produced by that step.
converge_x(Test, Step, PrevObj, [CurObj|Rest], Params, ParamsFinal) :-
    call(Step, CurObj, Params, NextParams),
    (   converged(Test, PrevObj, CurObj)
    ->  Rest = [],
        ParamsFinal = NextParams
    ;   converge_x(Test, Step, CurObj, Rest, NextParams, ParamsFinal)
    ).
   % converged(+Test:convergence, +X1:number, +X2:number) is semidet.
%
% True when two successive objective values satisfy Test:
%  - abs(Eps): |X1-X2| =< Eps
%  - rel(Del): |X1-X2| / |X1+X2| =< Del
% The relative test is written multiplicatively. When X1+X2 is non-zero the
% result is the same as dividing. When X1+X2 =:= 0 it avoids an arithmetic
% evaluation error: two zero values count as converged, while X1 = -X2
% with a non-zero value does not.
converged(abs(Eps), X1, X2) :- abs(X1-X2) =< Eps.
converged(rel(Del), X1, X2) :- abs(X1-X2) =< Del*abs(X1+X2).