`    1:- module(ccp_learn, [converge/5, learn/4]).`

# Expectation-maximisation, variational Bayes and deterministic annealing.

*/

```    6:- use_module(library(data/pair),  [snd/2]).    7:- use_module(library(callutils),  [(*)/4, true2/2]).    8:- use_module(library(plrand),     [log_partition_dirichlet/2]).    9:- use_module(library(autodiff2),  [esc/3, add/3, mul/3, pow/3, max/3, gather_ops/3]).   10:- use_module(library(clambda),    [clambda/2, run_lambda_compiler/1]).   11:- use_module(library(plflow),     [ops_body/4, sub/3, stoch/2, mean_log_dir/2, log_part_dir/2, log_prob_dir/3]).   12:- use_module(graph,    [graph_counts/6]).   13:- use_module(switches, [map_sw/3, map_swc/3, map_swc/4, map_sum_sw/3]).   14
15
% map_sum_(+P, +Xs, -Sum): call P on every element of Xs (maplist/3) and
% combine the per-element results Zs into Sum via esc(sum_list, Zs, [Sum]).
% NOTE(review): esc/3 is imported from library(autodiff2); presumably it
% records sum_list/2 as a deferred operation rather than evaluating it
% immediately -- confirm against that library.
map_sum_(P,X,Sum)   :- maplist(P,X,Z),   esc(sum_list,Z,[Sum]).

% map_sum_(+P, +Xs, +Ys, -Sum): as map_sum_/3, but P is called on
% corresponding elements of the two lists Xs and Ys (maplist/4).
map_sum_(P,X,Y,Sum) :- maplist(P,X,Y,Z), esc(sum_list,Z,[Sum]).

% map_sum_sw_(+P, +Pairs, -Sum): as map_sum_/3 over a list of pairs, with P
% composed after snd/2, so P presumably receives only the second component
% of each pair (snd/2 comes from library(data/pair) -- confirm).
map_sum_sw_(P,X,Sum)   :- map_sum_(P*snd,X,Sum).

% map_sum_sw_(+P, +PairsX, +PairsY, -Sum): two-list variant; f2sw1/4 checks
% that corresponding pairs carry the same key before calling P on the values.
map_sum_sw_(P,X,Y,Sum) :- map_sum_(f2sw1(P),X,Y,Sum).

% f2sw1(+P, +SW-X, +SW-Y, -Z): call P on the values X and Y of two pairs that
% share the key SW. Fails if the two keys do not unify.
f2sw1(P,SW-X,SW-Y,Z) :- call(P,X,Y,Z).
```
learn(+Method:learn_method, +Stats:stats_method, +ITemp:number, +G:graph, -U:learner) is det
learn(+Method:learn_method, +Stats:stats_method, +G:graph, -U:learner) is det
Get update predicate for several EM-based parameter learning methods. learn/4 uses ITemp=1.0 (the listing below shows it calling learn/7 with that value rather than going through learn/5).
```learn_method ---> ml; map(sw_params); vb(sw_params).
stats_method ---> io(scaling); vit.
scaling      ---> lin; log.
learner == pred(-float, +sw_params, -sw_params).```
```
% learn(+Method, +StatsMethod, +Graph, -Step): build the update predicate
% Step for learning method Method, with the ITemp argument of learn/7 fixed
% at 1.0.
%
% learn/7 relates an objective Obj to two parameter terms P1 and P2. The
% variables of P1 are taken as inputs and those of P2 as outputs when
% gather_ops/3 collects the operation list Ops; ops_body/4 turns Ops into a
% goal Body, and clambda/2 wraps it so that Step can be called as
% call(Step, Obj, P1, P2). The number of gathered operations is reported on
% the debug topic learn(setup).
learn(Method, StatsMethod, Graph, Step) :-
   learn(Method, StatsMethod, 1.0, Graph, Obj, P1, P2),
   maplist(term_variables, [P1,P2], [Ins,Outs]),
   gather_ops(Ins, [Obj|Outs], Ops), length(Ops, NumOps),
   debug(learn(setup), 'Compiled ~d operations.', [NumOps]),
   ops_body(Ins, [Obj|Outs], Ops, Body),
   clambda(lambda([Obj,P1,P2], Body), Step).
40
% learn(ml, +Stats, +ITemp, +Graph, -LL, ?P1, ?P2): learn/7 clause for the
% `ml` method (presumably maximum likelihood; no prior is involved).
%
% graph_counts/6 is run once with `lin` scaling and relates the parameters
% PP to the statistics Eta and the objective LL. PP is obtained from the
% input parameters P1 via pow(ITemp), and the output parameters P2 from Eta
% via stoch/2, each applied through the switch-mapping helpers.
% NOTE(review): pow/3 and stoch/2 are imported autodiff/plflow operations;
% presumably PP = P1^ITemp and stoch/2 normalises to a distribution --
% confirm the argument order against those libraries.
learn(ml, Stats, ITemp, Graph, LL, P1, P2) :-
   once(graph_counts(Stats, lin, Graph, PP, Eta, LL)),
   map_swc(pow(ITemp), P1, PP),
   map_sw(stoch, Eta, P2).
45
% learn/7 clause for the map(Prior) method (presumably maximum a posteriori
% learning under the prior Prior).
% Visible goals: graph_counts/6 is run once with `lin` scaling; LP0 sums
% log_prob_dir over Prior and P1; P2 is derived from Post; PP is obtained
% from P1 via pow(ITemp).
% NOTE(review): this listing skips its own lines 49 and 51, so no goal shown
% here binds Post or Obj, and LP0, Eta and LL are not consumed. Recover the
% missing goals from the original source before relying on this clause.
46learn(map(Prior), Stats, ITemp, Graph, Obj, P1, P2) :-
47   once(graph_counts(Stats, lin, Graph, PP, Eta, LL)),
48   map_sum_sw_(log_prob_dir, Prior, P1, LP0),
% Composition with * presumably applies the right-hand goal first: add -1.0,
% clamp at 0.0, then pass through stoch -- confirm against callutils.
50   map_sw(stoch*maplist(max(0.0)*add(-1.0)), Post, P2), % mode
52   map_swc(pow(ITemp), P1, PP).
53
% learn/7 clause for the vb(Prior) method (presumably variational Bayes with
% the prior Prior). A1 and A2 take the places that P1 and P2 occupy in the
% other learn/7 clauses.
% Visible goals: A1 and Pi are given the same shape as Prior; LogZPrior sums
% log_partition_dirichlet over Prior; vb_helper/6 relates A1 to Pi and Div;
% graph_counts/6 is run once with `log` scaling on Pi; Obj is obtained from
% Div and LL by sub/3.
% NOTE(review): this listing skips its own lines 56 and 60, so no goal shown
% here binds EffPrior or A2, and Eta is not consumed. Recover the missing
% goals from the original source before relying on this clause.
54learn(vb(Prior), Stats, ITemp, Graph, Obj, A1, A2) :-
55   maplist(map_swc(true2,Prior), [A1,Pi]), % establish same shape as prior
57   map_sum_sw(log_partition_dirichlet, Prior, LogZPrior),
58   vb_helper(ITemp, LogZPrior, EffPrior, A1, Pi, Div),
59   once(graph_counts(Stats, log, Graph, Pi, Eta, LL)),
% sub/3 is imported from plflow; presumably Obj = Div - LL -- confirm order.
61   sub(Div, LL, Obj).
62
% vb_helper(+ITemp, +LogZPrior, +EffPrior, +A, -Pi, -Div): helper called by
% the vb(Prior) clause of learn/7.
% Visible goals: PsiA is mean_log_dir of A per switch; Delta combines
% EffPrior and A through sub/3; Pi is PsiA scaled by ITemp through mul/3;
% LogZA sums log_part_dir over A; Diff sums the mul/3 products of PsiA and
% Delta over all switches.
% NOTE(review): the listing is truncated here -- the last visible goal ends
% with a comma and the remaining goals (listing lines 69 onwards) are
% missing, so how Div is built from LogZPrior, LogZA and Diff is not shown.
63vb_helper(ITemp, LogZPrior, EffPrior, A, Pi, Div) :-
64   map_sw(mean_log_dir, A, PsiA),
65   map_swc(sub, EffPrior, A, Delta),
66   map_swc(mul(ITemp), PsiA, Pi),
67   map_sum_sw_(log_part_dir, A, LogZA),
68   map_sum_sw_(map_sum_(mul), PsiA, Delta, Diff),
converge(+C:convergence, +L:pred(-learner), -LL:list(float), +P1:sw_params, -P2:sw_params) is det
Use L to create a predicate to do one step of learning, and then iterate this until convergence, starting from P1 and ending with P2. History of objective function values is returned in LL. Convergence C is of type:
`convergence ---> abs(float); rel(float).`
```
:- meta_predicate converge(+,1,-,+,-).

% converge(+Test, :Setup, -Objectives, +S0, -SFinal): create a learning step
% and iterate it from state S0 until the convergence test Test is met.
%
% Setup is called with one extra argument to obtain Step, which is then
% called as call(Step, Objective, StateIn, StateOut). The first step yields
% X0 and S1 from S0; converge_x/6 performs the remaining iterations, so
% Objectives is [X0|History] and always holds at least two values. Both the
% setup and the iteration are wrapped in time/1, and everything runs inside
% run_lambda_compiler/1 because Step is a compiled lambda.
converge(Test, Setup, [X0|History], S0, SFinal) :-
   debug(learn(setup), 'converge: Setting up...',[]),
   run_lambda_compiler((
      time(call(Setup, Step)),
      call(Step, X0, S0, S1),
      time(converge_x(Test, Step, X0, History, S1, SFinal)))).
85
% converge_x(+Test, +Step, +X0, -History, +S1, -SFinal): iterate Step until
% two successive objective values satisfy converged/3.
%
% Each iteration calls Step once from state S1, producing the objective X1
% and the next state S2. If X0 and X1 pass Test, History ends with X1 and
% SFinal is S2; otherwise the loop continues from S2 with X1 as the previous
% objective value.
converge_x(Test, Step, X0, [X1|History], S1, SFinal) :-
   call(Step, X1, S1, S2),
   (  converged(Test, X0, X1) -> History=[], SFinal=S2
   ;  converge_x(Test, Step, X1, History, S2, SFinal)
   ).
91
% converged(+Test, +X1, +X2): succeeds when two successive objective values
% X1 and X2 are close enough according to Test.
%
%   abs(Eps): absolute difference |X1-X2| is at most Eps.
%   rel(Del): |X1-X2| is at most Del*|X1+X2|, i.e. the difference relative
%             to the sum is at most Del.
%
% The relative test is written as a product rather than a quotient so that
% X1+X2 =:= 0 does not raise a division-by-zero error: it then succeeds only
% if X1 and X2 are equal (both zero).
converged(abs(Eps), X1, X2) :- abs(X1-X2) =< Eps.
converged(rel(Del), X1, X2) :- abs(X1-X2) =< Del*abs(X1+X2).
```