/*
 * Prolog part of random generator library
 * Samer Abdallah (2009)
*/
	  
:- module(plrand, [
		rv/2                 % -Spec, -Type
	,	sample/4             % +Dist, -Value, +StateIn, -StateOut
	,	sample/2             % +Dist, -Value
	,	get_rnd_state/1      % -State
	,	set_rnd_state/1      % +State
	,	is_rnd_state/1       % +State
	,	init_rnd_state/1     % -State
	,	with_rnd_state/1     % :Goal
	,	reset_rnd_state/0
	,	randomise/1          % -State
	,	init_jump/2          % +Int, -Jump
	,	double_jump/2        % +Jump, -Jump
	,	spawn/3              % -New, +Orig, -Next
	,	jump/3               % +Jump, +State, -State

	,	str_init/3           % +State, +E, -Stream
	,	str_split/3          % +Stream1, -Stream2, -Stream3
	,	str_sample/4         % +Dist, -Value, +Stream, -Stream

	,	sample_Single_/1     % -Float
	,	sample_Double_/1     % -Float
	]).

% use_foreign_library/1 is the directive SWI-Prolog recommends for modules:
% it loads the library now and also re-loads it when a saved state starts.
:-	use_foreign_library(foreign(plrand)).
% The argument of with_rnd_state/1 is run as a DCG body (two extra
% arguments threading the RNG state), hence the // specifier.
:- meta_predicate with_rnd_state(//).

%% rv(-Spec, -Type) is multi.
%
%  Unifies Spec and Type with specifications for all the distributions
%  known to the system. Each distribution is listed once.
%
%  @param Spec is a term representing a distribution that can be used
%         with sample/2 or sample/4. The head functor represents the
%         distribution to be sampled from and any arguments represent the
%         types of the corresponding arguments to be used when sampling.
%  @param Type represents the type of the sampled value. 
rv( raw, natural).
rv( uniform01, nonneg).
rv( normal,    real).
rv( exponential, nonneg).
rv( gamma(nonneg), nonneg).
rv( studentst(nonneg), real).
rv( poisson(nonneg), natural).
rv( invgamma(nonneg), nonneg).
rv( beta(nonneg,nonneg), nonneg).
rv( zeta(nonneg), nonneg).
rv( pareto(nonneg), nonneg).	% sample/4 computes (1-U)**(-1/A), which is >= 1
rv( binomial(nonneg,natural), natural).
rv( dirichlet(natural,list(nonneg)), list(nonneg)).
rv( dirichlet(list(nonneg)), list(nonneg)).
rv( stable(nonneg,real), real).
rv( bernoulli(nonneg), natural).	% sample/4 yields the integer 1 or 0
rv( discrete(natural,list(nonneg)),natural).
rv( discrete(list(nonneg)),natural).
rv( dirichlet_process(nonneg,rv(A)),rv(A)).
rv( dp(dpparam(A)),A). % well not quite but...

%% sample( +DistExpression, -Value) is det.
%% sample( +DistExpression, -Value, +StateIn, -StateOut) is det.
%
% 	sample/2 and sample/4 implement a small language for describing
%  and sampling from a distribution. 
%
%  sample/2 uses and modifies the global state.
%  sample/4 uses a given random generator state and returns the state
%  on completion, and is designed to be compatible with DCG syntax to hide
%  the threading of the random state through several consecutive calls.
%
%  DistExpression is an expression describing a distribution.  The head
%  functor can be a distribution name (as listed by rv/2) or one of a number
%  of arithmetic operators or term constructors. The arguments (in almost all
%  cases) can then be further DistExpressions which are evaluated recursively. 
%  Valid non-distributional terms are:
%
%     * X*Y
%       returns the product of samples from X and Y
%     * X/Y
%       returns the ratio of samples from X and Y
%     * X+Y
%       returns the sum of samples from X and Y
%     * X-Y
%       returns the difference of samples from X and Y
%     * -X
%       returns the negation of a sample from X
%     * sqrt(X)
%       returns the square root of a sample from X
%     * [X,Y,...]
%       returns samples from X, Y etc in a list
%     * rep(N,X)
%       returns N independent samples from X in a list (N must be a constant)
%     * factorial(N,X)
%       returns N independent samples from X in an N argument term \(X1,...,XN)
%     * tuple(X,Y,...), vec(X,Y,...) or \(X,Y,...)
%       returns a term with the same functor whose arguments are samples
%       from X, Y etc
%     * <any number>
%       returns itself as a constant
%
%     For example
%
%     ?- sample(rep(8,discrete(dirichlet(rep(4,0.5)))),X).
%     X = [3, 2, 3, 3, 3, 1, 2, 2] .

% Primitive distributions. Every parameter is itself a DistExpression and
% is sampled first (left to right) before the primitive sampler is invoked.
sample(raw, V)             --> sample_Raw(R), { V is integer(R) }.

sample(uniform01, V)       --> sample_Uniform01(V).
sample(normal, V)          --> sample_Normal(V).
sample(exponential, V)     --> sample_Exponential(V).
sample(gamma(A0), V)       --> sample(A0, A), sample_Gamma(A, V).
sample(poisson(A0), V)     --> sample(A0, A), sample_Poisson(A, V).
sample(invgamma(A0), V)    --> sample(A0, A), sample_Gamma(A, G), { V is 1/G }.
sample(beta(A0,B0), V)     --> sample(A0, A), sample(B0, B), sample_Beta(A, B, V).
sample(zeta(A0), V)        --> sample(A0, A), sample_Zeta(A, Z), { V is integer(Z) }.
% Pareto with unit scale, obtained by transforming a uniform draw U.
sample(pareto(A0), V)      --> sample(A0, A), sample_Uniform01(U), { V is (1-U)**(-1/A) }.
sample(binomial(P0,N0), V) --> sample(P0, P), sample(N0, N), sample_Binomial(P, N, V).
sample(dirichlet(K,A0), V) --> sample(A0, A), sample_Dirichlet(K, A, V).
sample(dirichlet(A0), V)   --> sample(A0, A), { length(A, K) }, sample_Dirichlet(K, A, V).
sample(stable(A0,B0), V)   --> sample(A0, A), sample(B0, B), sample_Stable(A, B, V).
sample(discrete(K,P0), V)  --> sample(P0, P), sample_Discrete(K, P, V).
sample(discrete(P0), V)    --> sample(P0, P), { length(P, K) }, sample_Discrete(K, P, V).
% Bernoulli: the integer 1 when the uniform draw U is below P, otherwise 0.
sample(bernoulli(P0), V)   --> sample(P0, P), sample_Uniform01(U), { U < P -> V = 1 ; V = 0 }.
% Student's t: with H = Nu/2, the value is Normal * sqrt(H / Gamma(H)).
sample(studentst(Nu0), V)  --> sample(Nu0/2, H), sample(normal*sqrt(H/gamma(H)), V).

% A draw from dirichlet_process(Alpha,Src) is represented as dps(Sticks):
% an unbounded list of Weight:Atom pairs that is materialised on demand by
% freeze/2 (see unfold_dp/4). Each Weight is a beta(1,Alpha) draw and each
% Atom is a draw from Src. The list is then used as an infinite discrete
% distribution.

sample( dirichlet_process(Alpha,Src), dps(Sticks), State0, State) :-
	% Give the lazy list its own generator so that later forcing of the
	% list does not depend on, or disturb, the caller's stream.
	spawn(Private, State0, State),
	freeze(Sticks, plrand:unfold_dp(Alpha,Src,Private,Sticks)).

% Walk the sticks: accept the current atom when a uniform draw falls below
% its weight, otherwise move on to the remaining sticks.
sample(dps([Weight:Atom|Rest]), V) -->
	sample(uniform01, U),
	(   { U < Weight }
	->  { V = Atom }
	;   sample(dps(Rest), V)
	).
	
	

% Arithmetic combinators: sample each operand, then combine the results.
sample(A*B, V)   --> sample(A, VA), sample(B, VB), { V is VA*VB }.
sample(A/B, V)   --> sample(A, VA), sample(B, VB), { V is VA/VB }.
sample(A+B, V)   --> sample(A, VA), sample(B, VB), { V is VA+VB }.
sample(A-B, V)   --> sample(A, VA), sample(B, VB), { V is VA-VB }.
sample(-A, V)    --> sample(A, VA), { V is -VA }.
sample(sqrt(A), V) --> sample(A, VA), { V is sqrt(VA) }.

% Structural combinators.
sample([], []) --> [].
sample([D|Ds], [V|Vs]) --> sample(D, V), sample(Ds, Vs).
sample(rep(Count,D), Vs) -->
	{ length(Vs, Count) },
	seqmap(sample(D), Vs).
sample(factorial(Count,D), Tup) -->
	{ functor(Tup,\,Count) },
	seqmapargs(sample(D), Count, Tup).
% Any term whose functor is a tuple_functor/1 is sampled argument-wise into
% a term with the same functor and arity.
sample(Tuple, Value) -->
	{ functor(Tuple, Name, Arity),
	  functor(Value, Name, Arity),
	  tuple_functor(Name)
	},
	seqmapargs(sample, Arity, Tuple, Value).

% A number denotes itself; the state passes through unchanged.
sample(Const, Const, State, State) :- number(Const), !.

% sample/2 threads the global RNG state through sample/4.
sample(Dist, Value) :-
	get_rnd_state(State0),
	sample(Dist, Value, State0, State),
	set_rnd_state(State).

% unfold_dp(+Alpha, +Src, +State, -Sticks) produces the next Weight:Atom
% element of a Dirichlet-process stick list: Atom is drawn from Src, then
% Weight from beta(1,Alpha), both using the private generator State. The
% tail is frozen, so it is only computed once something inspects it.
unfold_dp(Alpha, Src, State0, [Weight:Atom|More]) :-
	sample(Src, Atom, State0, State1),
	sample(beta(1,Alpha), Weight, State1, State2),
	freeze(More, plrand:unfold_dp(Alpha,Src,State2,More)).

%% get_rnd_state(-State) is det.
%
%  Unifies State with the current global RNG state.
%  @see set_rnd_state/1


%% set_rnd_state(+State) is det.
%
%  Sets the global RNG state to State.
%  @see get_rnd_state/1

%% init_rnd_state(-State) is det.
%
%  Unifies State with the state that was set at load time.
%  @see reset_rnd_state/0

%% reset_rnd_state is det.
%
%  Restores the global random state to the initial one set at load time,
%  i.e. the state returned by init_rnd_state/1.
%
%  @see init_rnd_state/1
%  @see set_rnd_state/1
reset_rnd_state :-
	init_rnd_state(Initial),
	set_rnd_state(Initial).

%% is_rnd_state(+State) is semidet.
%
%  Succeeds if State is a BLOB atom representing a random generator state.


%% with_rnd_state(:Callable) is nondet.
%
%  Runs DCG phrase Callable using the current global RNG state and setting
%  the global RNG state afterwards.
%
%  @param Callable must be a DCG phrase or callable that takes two more
%         arguments, ie, that can be used with call(:Callable,+S1,-S2)
with_rnd_state(P) :-
	get_rnd_state(S1),
	% call_dcg/3 rather than phrase/3: phrase/3 type-checks the threaded
	% state as a list and would raise a type error for the RNG state BLOB.
	call_dcg(P, S1, S2),
	set_rnd_state(S2).


%% randomise(-State) is det.
%
% Unifies State with a new and truly random state obtained by
% taking bits from /dev/random.


%% init_jump(+E:integer, -Jump) is det.
%
% Unifies Jump with a BLOB atom representing an operator to jump forwards
% 2^E steps in the stream of random numbers. The generator has a period
% of about 2^191. Operators for jumping ahead by 2^76 and 2^127 are
% precomputed and so can be returned faster. The resulting BLOB can be
% used with jump/3 to advance any given random generator state.
%
% @see double_jump/2
% @see jump/3

%% double_jump( +Jump1, -Jump2) is det.
%
% Unifies Jump2 with an operator that jumps ahead twice as far as Jump1,
% ie if Jump1 jumps by 2^E, then Jump2 jumps by 2^(E+1). Jump operators
% can be created by init_jump/2 and applied by jump/3.
%
% @see init_jump/2
% @see jump/3

%% jump( +Jump, +State1, -State2) is det.
%
% Advances random generator state represented in State1 by the number
% of steps represented by Jump, unifying State2 with the result.
%
% @see init_jump/2
% @see double_jump/2

%% spawn( -New, +Orig, -Next) is det.
%
% Samples enough data from Orig to create a new generator state, New,
% leaving the original generator in Next. If generator states represent
% streams of random numbers, then you can think of it as sampling a whole
% stream of values instead of just one value.
% Note: New is likely to point to a new point in the original stream
% far away from Orig and Next, simply because the period of the generator
% is so large (about 2^191) but there is no guarantee of this. Therefore, it's 
% possible (but unlikely) that New might produce a stream that overlaps
% significantly with samples drawn from Next. If you need to be sure
% that New is a long way from Orig and Next, then use jump/3 instead.
%
% @param Orig is the state of the source generator, the original stream.
% @param New is the state of the new generator.
% @param Next is the state of the source generator after extracting New.
%
% @see jump/3

% tuple_functor(?Name) holds for the functor names that sample/4 treats as
% tuple constructors (sampled argument by argument), in the order \, tuple, vec.
tuple_functor(Name) :-
	(   Name = (\)
	;   Name = tuple
	;   Name = vec
	).

% seqmapargs(:P, +N, +T)// calls P//1 on arguments N, N-1, ..., 1 of T in
% turn, threading the DCG state through each call. A non-positive N does
% nothing.
seqmapargs(P, N, T) -->
	{ N > 0 }, !,
	{ succ(M, N), arg(N, T, A) },
	call(P, A),
	seqmapargs(P, M, T).
seqmapargs(_, _, _) --> [].

% seqmapargs(:P, +N, +T1, ?T2)// is the two-term variant: it calls P//2 on
% the corresponding arguments N, N-1, ..., 1 of T1 and T2.
seqmapargs(P, N, T1, T2) -->
	{ N > 0 }, !,
	{ succ(M, N), arg(N, T1, A1), arg(N, T2, A2) },
	call(P, A1, A2),
	seqmapargs(P, M, T1, T2).
seqmapargs(_, _, _, _) --> [].

% seqmap(:P, ?List)// calls P//1 on each element of List, left to right,
% threading the state; seqmap(:P, ?L1, ?L2)// does the same pairwise with P//2.
seqmap(_, [], S, S).
seqmap(P, [A|As], S0, S) :-
	call(P, A, S0, S1),
	seqmap(P, As, S1, S).
seqmap(_, [], [], S, S).
seqmap(P, [A|As], [B|Bs], S0, S) :-
	call(P, A, B, S0, S1),
	seqmap(P, As, Bs, S1, S).



%% str_init(+State,+E:integer,-Stream) is det.
%
% [EXPERIMENTAL] Used to initialise a splittable stream. The idea is that
% a stream of random numbers can be recursively split into many substreams,
% each of which will not overlap unless very many samples are drawn.
%
% For example,
% ?- init_rnd_state(S), str_init(S,76,Z).
% produces a stream such that the first split jumps 2^76 steps along.
% Thus, 2^76 samples can be drawn from either substream before they overlap.
% Since the period of the generator is more than 2^190, there can be up to
% 114 levels of recursive splitting, resulting in up to 2^114 substreams.
%
% @param State is the initial state of the generator.
% @param E is a nonnegative integer such that, when the stream is first split,
%        the new substream begins 2^E steps along the sequence. Subsequent splits
%        will jump twice as far each time.
% @param Stream is the new stream (a term str(State,Jump))

str_init(State, E, Stream) :-
	init_jump(E, Jump),
	Stream = str(State, Jump).

%% str_split(+Stream1,-Stream2,-Stream3) is det.
%
%  [EXPERIMENTAL] Stream1 is split into independent streams Stream2 and Stream3.
%  Stream2 actually contains the same numbers as Stream1, but Stream3
%  starts a long way down the sequence. Both resulting streams carry a jump
%  operator twice as long as Stream1's, so the next split reaches further.
%
%  @see jump/3
str_split(str(State, Jump), Same, Far) :-
	jump(Jump, State, Ahead),
	double_jump(Jump, Longer),
	Same = str(State, Longer),
	Far  = str(Ahead, Longer).

%% str_sample(+Dist,-Value,+Stream,-Stream) is det.
%
% [EXPERIMENTAL] Much like sample/4 but for splittable streams. Overrides 
% the implementation of dirichlet_process to use splitting instead of
% spawning. The Dirichlet process result is dps(Vals), the same
% representation that sample/4 produces and can sample from.

str_sample( dirichlet_process(Alpha,Src),dps(Vals),str(S1,J1),str(S2,J2)) :- !,
	% The cut commits to the splitting implementation; without it,
	% backtracking would fall into the generic clause and re-sample the
	% process via sample/4 (spawning instead of splitting).
	str_split(str(S1,J1),str(S2,J2),str(S3,_)),
	freeze(Vals,plrand:unfold_dp(Alpha,Src,S3,Vals)).

str_sample(Dist,Value,str(S0,J),str(S1,J)) :- 
	sample(Dist,Value,S0,S1), !.



%% sample_Single_(-Float) is det.
%
% Samples a single precision floating point value in [0,1) using the 
% internal (global) random generator state. It consumes one 32 bit value
% from the generator. This is the fastest way to generate a random value 
% using this library.

%% sample_Double_(-Float) is det.
%
% Samples a double precision floating point value in [0,1) using the 
% internal (global) random generator state. It consumes two 32 bit values
% from the generator.
