1:- module(autodiff2, [max/3, mul/3, add/3, pow/3, exp/2, llog/2, log/2, lse/2, deriv/3, back/1, grad/1,
    2                      esc/3, expand_wsums/0, wsum/2, add_to_wsum/3, gather_ops/3]).

/* Reverse mode automatic differentiation using CHR.

Todo:

*/

    9:- use_module(library(chr)).   10:- use_module(library(rbutils)).   11:- use_module(library(listutils), [measure/2]).   12:- use_module(library(dcg_pair)).   13:- use_module(library(dcg_macros)).   14
   15:- chr_constraint expand_wsums, wsum(?,-), add_to_wsum(?,?,-), ops(-,+).
   16:- chr_constraint max(?,?,-), add(?,?,-), mul(?,?,-), llog(-,-), log(-,-), exp(-,-), pow(+,-,-),
   17                  lse(?,-), stoch_exp(?,-), stoch_exp(?,+,-), mes(?,-,-,-), chi(?,?,?,-),
   18                  deriv(?,-,?), agg(?,-), acc(?,-), acc(-), go, esc(+,?,-).
   19
   20add_to_wsum(X,0.0,S) <=> ord_list_to_rbtree([X-1], Terms), wsum(Terms, S).
   21add_to_wsum(X,S1,S2), wsum(Terms1, S1) <=> incr_term(X, Terms1, Terms2), wsum(Terms2, S2).
   22
   23incr_term(X) --> rb_app_or_new(X, succ, =(1)).
   24add_mul(X-N, S1, S2) :- K is float(N), mul(K,X,NX), add(NX,S1,S2).
   25expand_wsums \ wsum(Terms, Sum) <=> rb_fold(add_mul, Terms, 0.0, Sum).
   26expand_wsums <=> true.
   27
   28% operations interface with simplifications
   29mul(0.0,_,Y) <=> Y=0.0.
   30mul(_,0.0,Y) <=> Y=0.0.
   31mul(1.0,X,Y) <=> Y=X.
   32mul(X,1.0,Y) <=> Y=X.
   33mul(X,Y,Z1) \ mul(X,Y,Z2) <=> Z1=Z2.
   34pow(K,X,Y)   <=> K =:= 1 | Y=X. % guard to match floats and ints
   35pow(0,_,Y)   <=> Y=1.
   36add(0.0,X,Y) <=> Y=X.
   37add(X,0.0,Y) <=> Y=X.
   38add(X,Y,Z1) \ add(X,Y,Z2) <=> Z1=Z2.
   39
   40% lse: log(sum(map(exp,Xs))), stoch_exp: stoch(map(exp,Xs))
   41% mes: max, exp, sum - used to share computation of max(Xs), exp(Exs-Max) and sum
   42lse([X],Y) <=> X=Y.
   43lse(Xs,Y1) \ lse(Xs,Y2) <=> Y1=Y2.
   44lse(Xs,_) ==> mes(Xs,_,_,_).
   45stoch_exp(Xs,Ys1) \ stoch_exp(Xs,Ys2) <=> Ys1=Ys2.
   46stoch_exp(Xs,Ys) ==> mes(Xs,_,_,_), measure(Xs,Ns), maplist(stoch_exp(Xs),Ns,Ys).
   47mes(Xs,M1,Ws1,S1) \ mes(Xs,M2,Ws2,S2) <=> M1=M2, Ws1=Ws2, S1=S2.
   48
   49% propagate derivatives through unary and binary operators
   50deriv(L,X,DX) \ deriv(L,X,DX1) <=> DX=DX1.
   51deriv(L,_,DX) <=> ground(L) | DX=0.0.
   52deriv(L,L,DL) ==> DL=1.0.
   53deriv(_,_,DX) ==> var(DX) | acc(DX).
   54deriv(L,Y,DY), pow(K,X,Y)   ==> deriv(L,X,DX), dpow(K,X,Z), mul(DY,Z,T), agg(T,DX).
   55deriv(L,Y,DY), exp(X,Y)     ==> deriv(L,X,DX), mul(Y,DY,T), agg(T,DX).
   56deriv(L,Y,DY), llog(Y,X)    ==> deriv(L,X,DX), mul(Y,DY,T), agg(T,DX).
   57deriv(L,Y,DY), log(X,Y)     ==> deriv(L,X,DX), pow(-1,X,RX), mul(RX,DY,T), agg(T,DX).
   58deriv(L,Y,DY), add(X1,X2,Y) ==> maplist(agg_add(L,DY),[X1,X2]).
   59deriv(L,Y,DY), mul(X1,X2,Y) ==> maplist(agg_mul(L,DY),[X1,X2],[X2,X1]).
   60deriv(L,Y,DY), max(X1,X2,Y) ==> maplist(agg_max(L,DY),[X1,X2],[X2,X1]).
   61deriv(L,Y,DY), lse(Xs,Y)    ==> stoch_exp(Xs,Ps), maplist(agg_mul(L,DY),Xs,Ps).
   62deriv(L,Y,DY), stoch_exp(Xs,N,Y) ==>
   63   pow(2,Y,Y2), mul(-1.0,Y2,NY2),
   64   mul(DY,NY2,T1), mul(DY,Y,T2),
   65   maplist(deriv(L),Xs,DXs), % !!! NB the rest is wrong for any constants in Xs
   66   maplist(agg(T1),DXs),
   67   nth1(N,DXs,DXN),
   68   agg(T2,DXN).
   69
   70dpow(K,X,T) :- K1 is K - 1, KK is float(K), pow(K1,X,XpowK1), mul(KK,XpowK1,T).
   71agg_max(L,DY,X1,X2) :- var(X1) -> deriv(L,X1,DX1), chi(X1,X2,DY,T1), agg(T1,DX1); true.
   72agg_mul(L,DY,X1,X2) :- var(X1) -> deriv(L,X1,DX1), mul(X2,DY,T1), agg(T1,DX1); true.
   73agg_add(L,DY,X1)    :- var(X1) -> deriv(L,X1,DX1), agg(DY,DX1); true.
   74acc(X) \ acc(X) <=> true.
   75
   76% initiatiate back-propagation starting from Y
   77back(Y) :- var(Y) -> diff(Y), go; true.
   78diff(Y) :- deriv(Y,Y,1.0).
   79grad(Ys) :- maplist(diff,Ys), go.
   80
   81acc(X,S1), agg(Z,X) <=> add(Z,S1,S2), acc(X,S2).
   82acc(X,S) <=> S=X.
   83
   84go \ deriv(_,_,_) <=> true.
   85go \ acc(DX) <=> acc(DX,0.0).
   86go <=> true.
   87
   88:- meta_predicate upd_ops(//,?,?).   89upd_ops(Upd,G1,G3) :- call(Upd,G1,G2), ops(G2,G3).
   90op(Op, Ins, Outs) --> [op(Op,Ins,Outs)].
   91
   92ops(G1,G2), add(X,Y,Z) <=> upd_ops(op(add, [X,Y], [Z]), G1, G2).
   93ops(G1,G2), mul(X,Y,Z) <=> upd_ops(op(mul, [X,Y], [Z]), G1, G2).
   94ops(G1,G2), max(X,Y,Z) <=> upd_ops(op(max, [X,Y], [Z]), G1, G2).
   95ops(G1,G2), pow(X,Y,Z) <=> upd_ops(op(pow, [X,Y], [Z]), G1, G2).
   96ops(G1,G2), log(X,Y)   <=> upd_ops(op(log, [X], [Y]), G1, G2).
   97ops(G1,G2), exp(X,Y)   <=> upd_ops(op(exp, [X], [Y]), G1, G2).
   98ops(G1,G2), esc(Op,X,Y)<=> upd_ops(op(Op, X, Y), G1, G2).
   99ops(_,_) \ llog(_,_)   <=> true.
  100
  101ops(_,_) \ stoch_exp(_,_,_)  <=> true.
  102mes(Xs,M,_,S)  \ ops(G1,G2), lse(Xs,Y)        <=> mes(Xs,M,_,S), upd_ops(add_log(S,M,Y), G1, G2).
  103mes(Xs,_,Ws,S) \ ops(G1,G2), stoch_exp(Xs,Ys) <=> upd_ops(divby_list(S,Ws,Ys), G1, G2).
  104ops(G1,G2), mes(Xs,M,Ws,S)                    <=> upd_ops(max_exp_sum(Xs,M,Ws,S),G1,G2).
  105ops(G1,G2), chi(X,Y,Z,I)                      <=> upd_ops(op(chi, [X,Y,Z], [I]), G1, G2).
  106ops(G1,G2) <=> G1=G2.
  107
  108add_log(S,M,Y) --> op(add_log,[M,S],[Y]).
  109divby_list(S,Ws,Ys) --> foldl(divby(S), Ws, Ys).
  110divby(S,W,Y) --> op(div, [W,S], [Y]).
  111
  112max_exp_sum(Xs,M,Ws,Sum) -->
  113   op(max_list, Xs, [M]),
  114   foldl(exp_sub(M),Xs,Ws),
  115   op(sum_list, Ws, [Sum]).
  116exp_sub(M,X,Y) --> op(exp_sub, [M,X], [Y]).
  117
  118gather_ops(Ins, Outs, Sorted) :-
  119   ops(Ops,[]), rb_empty(E),
  120   foldl(back_links, Ops, E, BS),
  121   traverse(BS, Ins, Outs, Sorted-E, []-_).
  122
  123back_links(Edge) --> {Edge=op(_,_,Outs)}, foldl(back_link(Edge), Outs).
  124back_link(Edge, Out) --> rb_add(Out, Edge).
  125traverse(BS, Ins, Outs) --> \> foldl(insert, Ins), foldl(eval(BS), Outs).
  126insert(X) --> rb_add(X,t).
  127
  128eval(BS, Var) -->
  129   (  ({nonvar(Var)}; \>