1:- module(egraph, [add_term//2, add_terms//2, union//2, saturate//1,
    2                   saturate//2, extract//2, extract_all//2, lookup/2, query//1]).

/** <module> E-graph implementation for term rewriting and saturation

This module implements an E-graph (Equivalence Graph) data structure, commonly used for efficient term rewriting, congruence closure, and e-matching. The E-graph state is typically threaded through operations using DCG notation.

Rewrite rules are automatically compiled into efficient DCG predicates via term expansion. See the egraph_compile module for full details, including the list of supported rule declarations.

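Main predicates (as exported by this module): add_term//2, add_terms//2, union//2, saturate//1, saturate//2, extract//2, extract_all//2, lookup/2 and query//1.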

*/

   40:- use_module(library(dcg/high_order)).   41:- use_module(library(ordsets)).   42:- use_module(library(rbtrees)).   43:- use_module(library(heaps)).   44
   45:- use_module(egraph/compile).
 lookup(+Pair, +SortedPairs) is semidet
Retrieves a value from a sorted list of pairs using standard term comparison. The search is unrolled for performance. Adapted from ord_memberchk/2.
Arguments:
Pair- A Key-Value pair where Key is the target key to find, and Value is unified with the associated value.
SortedPairs- A list of Key-Value pairs sorted by Key.
   55lookup(Item-V, [X1-V1, X2-V2, X3-V3, X4-V4|Xs]) :-
   56   !,
   57   compare(R4, Item, X4),
   58   (   R4=(>)
   59   ->  lookup(Item-V, Xs)
   60   ;   R4=(<)
   61   ->  compare(R2, Item, X2),
   62      (   R2=(>)
   63      ->  Item==X3, V = V3
   64      ;   R2=(<)
   65      ->  Item==X1, V = V1
   66      ;   V = V2
   67      )
   68   ;   V = V4
   69   ).
   70lookup(Item-V, [X1-V1, X2-V2|Xs]) :-
   71   !,
   72   compare(R2, Item, X2),
   73   (   R2=(>)
   74   ->  lookup(Item-V, Xs)
   75   ;   R2=(<)
   76   ->  Item==X1, V = V1
   77   ;   V = V2
   78   ).
   79lookup(Item-V, [X1-V1]) :-
   80   Item==X1, V = V1.
 add_term(+Term, -Id)// is det
Adds a term to the E-graph, returning its e-class ID. Compound terms are recursively traversed and their arguments are added to the E-graph first.
Arguments:
Term- The term to be added.
Id- The e-class ID representing the added term.
   91add_term(Term, Node) -->
   92   add_term(Term, Node, [cost(1)]).
   93add_term(Var, Id, Opt), var(Var) ==>
   94   { option(var(What), Opt, node) },
   95   (  { What == node }
   96   -> add_node(Var, Id, Opt)
   97   ;  { What == class }
   98   -> {  (  option(mark(true), Opt)
   99         -> put_attr(Id, egraph, true)
  100         ;  true
  101         ),
  102         Var=Id
  103      }
  104   ;  { domain_error(node-class, What) }
  105   ).
  106add_term('$NODE'(Node), Id, Opt) ==>
  107   add_node(Node, Id, Opt).
  108add_term(Term, Id, Opt), is_dict(Term, Tag) ==>
  109   {
  110      dict_pairs(Term, Tag, Pairs),
  111      pairs_keys_values(Pairs, Keys, Values),
  112      pairs_keys_values(Data, Keys, Ids),
  113      dict_create(Node, Tag, Data)
  114   },
  115   add_terms(Values, Ids, Opt),
  116   add_node(Node, Id, Opt).
  117add_term(Term, Id, Opt), compound(Term) ==>
  118   { Term =.. [F | Args] },
  119   add_terms(Args, Ids, Opt),
  120   { Node =.. [F | Ids] },
  121   add_node(Node, Id, Opt).
  122add_term(Term, Id, Opt) ==>
  123   add_node(Term, Id, Opt).
  124
  %   add_terms(+Terms, -Ids, +Opt)//
  %
  %   Adds each element of Terms in list order through add_term//3,
  %   pairing it with the element of Ids at the same position.
  add_terms([], [], _) --> [].
  add_terms([First | Others], [FirstId | OtherIds], Options) -->
     add_term(First, FirstId, Options),
     add_terms(Others, OtherIds, Options).
  129
  %!  add_terms(+TermIdPairs, +Opt)//
  %
  %   Adds every Term of a list of Term-Id pairs through add_term//3,
  %   using the Id of the same pair as its e-class ID.
  add_terms([], _) --> [].
  add_terms([First-FirstId | Others], Options) -->
     add_term(First, FirstId, Options),
     add_terms(Others, Options).
  134
  %   add_node(+NodeIdPair, +Opt, +In, -Out)
  %   add_node(+Node, -Id, +Opt, +In, -Out)
  %
  %   Makes sure Node has an entry Node-node(Id, Cost) in the sorted
  %   node list. If Node is already in In, or in the list given by the
  %   option nodes(Nodes), Id is taken from that entry and Out = In.
  %   Otherwise a new entry is inserted with ord_add_element/3. Its cost
  %   is taken from the first cost(N, Cost) option whose N == Node, else
  %   from cost(Cost), else 1. With option mark(true) the attribute
  %   egraph=true is put on Id before insertion.
  add_node(Node-Id, Opt, In, Out) :-
     add_node(Node, Id, Opt, In, Out).
  add_node(Node, Id, Opt, In, Out) :-
     (  lookup(Node-node(Id, _), In)
     -> Out = In
     ;  option(nodes(Known), Opt),
        lookup(Node-node(Id, _), Known)
     -> Out = In
     ;  % Commit to the first applicable cost source.
        (  member(cost(Key, Cost), Opt),
           Key == Node
        -> true
        ;  option(cost(Cost), Opt, 1)
        ),
        must_be(number, Cost),
        (  option(mark(true), Opt)
        -> put_attr(Id, egraph, true)
        ;  true
        ),
        ord_add_element(In, Node-node(Id, Cost), Out)
     ).
  153
  %   rules(+Rules, +M, +EGraph, +Index, +Parents, +MinId, ?UnifsIn, ?UnifsOut)//
  %
  %   Calls every rule of Rules in order. Each rule is resolved against
  %   module M with strip_module/3 and called as a DCG non-terminal with
  %   the extra arguments '$empty'-node(_, 0),
  %   state([], EGraph, Index, Parents, MinId) and a pair of
  %   unification-list arguments threaded from UnifsIn to UnifsOut.
  rules([Rule | Others], M, EGraph, Index, Parents, MinId, Unifs0, Unifs) -->
     { strip_module(M:Rule, RuleModule, RuleName) },
     call(RuleModule:RuleName, '$empty'-node(_, 0), state([], EGraph, Index, Parents, MinId), Unifs0, Unifs1),
     rules(Others, M, EGraph, Index, Parents, MinId, Unifs1, Unifs).
  rules([], _, _, _, _, _, Unifs, Unifs) --> [].
  159
  %   make_index(+In, -Index)
  %
  %   Builds a red-black tree that maps each e-class Id to the list of
  %   Node-node(Id, Cost) entries of In carrying that Id.
  make_index(In, Index) :-
     index_pairs(In, Keyed),
     sort(Keyed, SortedKeyed),
     group_pairs_by_key(SortedKeyed, ByClass),
     ord_list_to_rbtree(ByClass, Index).
  165
  %   index_pairs(+Entries, -Keyed)
  %
  %   Maps each Node-node(Id, Cost) entry to Id-Entry, keeping list order.
  index_pairs([], []).
  index_pairs([Entry|Entries], [Id-Entry|Keyed]) :-
     Entry = _-node(Id, _),
     index_pairs(Entries, Keyed).
  169
  %   make_parents(+In, -Parents)
  %
  %   Builds a red-black tree that maps each child e-class (an argument
  %   of a compound node or a value of a dict node) to the list of
  %   Node-node(Id, Cost) entries in which it occurs.
  make_parents(In, Parents) :-
     phrase(parent_pairs(In), ChildParent),
     sort(ChildParent, SortedChildParent),
     group_pairs_by_key(SortedChildParent, ByChild),
     ord_list_to_rbtree(ByChild, Parents).
  175
  %   parent_pairs(+Entries)//
  %
  %   Emits one Child-Entry pair for every child of every dict or
  %   compound node in Entries. Entries whose node is neither a dict nor
  %   a compound contribute nothing.
  parent_pairs([]) ==> [].
  parent_pairs([Node-node(Id, Cost) | Rest]), is_dict(Node) ==>
     { dict_pairs(Node, Tag, KeyChildren) },
     dict_pairs(KeyChildren, Node-node(Id, Cost), Tag),
     parent_pairs(Rest).
  parent_pairs([Node-node(Id, Cost) | Rest]), compound(Node) ==>
     { compound_name_arguments(Node, Functor, Children) },
     arg_pairs(Children, Node-node(Id, Cost), Functor, _Arity, 0),
     parent_pairs(Rest).
  parent_pairs([_ | Rest]) ==>
     parent_pairs(Rest).
  187
  %   dict_pairs(+KeyValuePairs, +Parent, +Tag)//
  %
  %   Emits Value-Parent for the value of every Key-Value pair.
  %   Tag is passed along but not otherwise used.
  dict_pairs([], _, _) --> [].
  dict_pairs([_-Child | Rest], Parent, Tag) -->
     [Child-Parent],
     dict_pairs(Rest, Parent, Tag).
  192
  %   arg_pairs(+Args, +Parent, +Name, -Arity, +Count0)//
  %
  %   Emits Arg-Parent for every argument and counts the arguments:
  %   Arity is Count0 plus the length of Args. Name is passed along but
  %   not otherwise used.
  arg_pairs([], _, _, Count, Count) --> [].
  arg_pairs([Child | Children], Parent, Name, Arity, Count0) -->
     [Child-Parent],
     { Count is Count0+1 },
     arg_pairs(Children, Parent, Name, Arity, Count).
 union(+Id1, +Id2)// is det
Merges two e-classes by unifying their IDs and merging their underlying nodes.
Arguments:
Id1- The first e-class ID.
Id2- The second e-class ID.
  206union(A, A) -->
  207   merge_nodes.
  208
  %   merge_nodes(+In, -Out)
  %
  %   Sorts In (dropping duplicates) and groups the entries by node.
  %   Groups holding more than one entry are collapsed by merge_groups/4.
  %   Because collapsing unifies e-class IDs, the whole pass is repeated
  %   on the collapsed list until no group had more than one entry.
  %   Out is then the sorted list of that last pass.
  merge_nodes(In, Out) :-
     sort(In, Sorted),
     group_pairs_by_key(Sorted, Groups),
     merge_groups(Groups, Collapsed, false, DidMerge),
     (  DidMerge \== true
     -> Out = Sorted
     ;  merge_nodes(Collapsed, Out)
     ).
  217
  %   merge_groups(+Groups, -Entries, +Flag0, -Flag)
  %
  %   Collapses each Sig-[node(..)|..] group into one Sig-Node entry with
  %   merge_group/3. Flag becomes true as soon as some group had more
  %   than one member; otherwise it keeps the value of Flag0.
  merge_groups([Sig-[First | Others] | Groups], [Sig-Merged | Entries], Flag0, Flag) :-
     merge_group(Others, First, Merged),
     (  Others == []
     -> Flag1 = Flag0
     ;  Flag1 = true
     ),
     merge_groups(Groups, Entries, Flag1, Flag).
  merge_groups([], [], Flag, Flag).
  226
  %   merge_group(+Nodes, +Node0, -Node)
  %
  %   Folds a list of node(Id, Cost) terms into Node0. The head
  %   unification forces all Ids to be the same e-class, and the lowest
  %   cost is kept. Whenever the running cost drops, the backtrackable
  %   global variable egraph_changed is set to true.
  merge_group([], Node, Node).
  merge_group([node(Id, Cost) | Rest], node(Id, Best0), Merged) :-
     Best is min(Cost, Best0),
     (  Best < Best0
     -> b_setval(egraph_changed, true)
     ;  true
     ),
     merge_group(Rest, node(Id, Best), Merged).
  235
  %   apply_unifs(+Unifs)
  %
  %   Performs every A=B unification of the list, in order. Fails if any
  %   pair does not unify.
  apply_unifs([]).
  apply_unifs([Same=Same | Rest]) :-
     apply_unifs(Rest).
  240
  %   rebuild(+Matches, +Unifs, -Out)
  %
  %   Applies the pending unifications first, then normalises the node
  %   list Matches with merge_nodes/2.
  rebuild(Nodes, Pending, Out) :-
     apply_unifs(Pending),
     merge_nodes(Nodes, Out).
  244
  :- meta_predicate saturate(:, ?, ?).
  :- meta_predicate saturate(:, +, ?, ?).

  %!  saturate(+Rules)// is det.
  %
  %   Applies a list of compiled rewrite rules to the E-graph until
  %   saturation is reached. Same as saturate(Rules, inf).
  %
  %   @arg Rules  A list of compiled rewrite rule names to apply.

  %!  saturate(+Rules, +N)// is det.
  %
  %   Applies a list of compiled rewrite rules to the E-graph up to N
  %   times or until saturation is reached. An iteration counts as
  %   progress when the node list changed length or the global variable
  %   egraph_changed was set to true while rebuilding.
  %
  %   @arg Rules  A list of compiled rewrite rule names to apply.
  %   @arg N      The maximum number of iterations (or inf for no limit).
  saturate(M:Rules) -->
     saturate(M:Rules, inf).
  saturate(M:Rules, N, In, Out) :-
     (  N > 0
     -> b_setval(egraph_changed, false),
        run_rules(Rules, M, In, Matches, In, Unifs),
        rebuild(Matches, Unifs, Next),
        length(In, SizeBefore),
        length(Next, SizeAfter),
        b_getval(egraph_changed, Changed),
        debug(saturate, "~p", [SizeBefore-SizeAfter-Changed]),
        (  SizeBefore == SizeAfter,
           Changed \== true
        -> % Nothing new: saturated.
           Out = Next
        ;  (  N == inf
           -> Remaining = N
           ;  Remaining is N - 1
           ),
           saturate(M:Rules, Remaining, Next, Out)
        )
     ;  Out = In
     ).
  284
  %   run_rules(+Rules, +M, +In, -Matches, ?Tail, -Unifs)
  %
  %   Indexes the node list In by e-class (make_index/2) and by child
  %   (make_parents/2), then runs Rules through rules//8. Matches is the
  %   list produced by the rules, ending in Tail. Unifs is the list of
  %   unifications they collected. MinId is the smallest key of the
  %   class index, or 0 when the index is empty.
  run_rules(Rules, M, In, Matches, Tail, Unifs) :-
     make_index(In, Index),
     make_parents(In, Parents),
     (  rb_min(Index, Smallest, _)
     -> MinId = Smallest
     ;  MinId = 0
     ),
     phrase(rules(Rules, M, In, Index, Parents, MinId, Unifs, []), Matches, Tail).
 query(?Pattern)// is multi
Queries the E-graph and dynamically binds pattern variables. Upon success, the variables in the query are bound and the Pattern is unified with the complete extracted matching term from the E-graph. On backtracking, Pattern will be bound to all possible representations of the matched equivalence class in increasing order of cost.
Arguments:
Pattern- The term pattern to search for, which is unified with the fully extracted match.
  301query(Pattern, In, Out) :-
  302   copy_term(Pattern, PatternCopy),
  303   variant_sha1(PatternCopy, Sha),
  304   atom_concat('query_', Sha, RuleName),
  305   (  current_predicate(RuleName/_)
  306   -> true
  307   ;  quote_vars(PatternCopy, QuotedPat),
  308      Rule = rule(RuleName, [QuotedPat-Id], [], [query(Id)-_], []),
  309      phrase(egraph_compile:compile(Rule :- true), Clauses),
  310      maplist(maplist(assert_ssu), Clauses)
  311   ),
  312   run_rules([RuleName], egraph, In, Matches, [], _Unifs),
  313   pairs_keys(Matches, Queries),
  314   sort(Queries, SortedQueries),
  315   member(query(MatchId), SortedQueries),
  316   extract_all(MatchId, Pattern, In, Out).
  317
  %   assert_ssu(+Clause)
  %
  %   Asserts a compiled single-sided-unification clause. A guarded rule
  %   (Head, Guard => Body) is stored as Head ?=> Guard, !, Body; an
  %   unguarded rule Head => Body is stored unchanged. Any other term is
  %   silently ignored.
  assert_ssu((RuleHead, Condition) => RuleBody) =>
     assertz(?=>(RuleHead, (Condition, !, RuleBody))).
  assert_ssu(RuleHead => RuleBody) =>
     assertz(RuleHead => RuleBody).
  assert_ssu(_) => true.
  323
  %   quote_vars(+Term, -Quoted)
  %
  %   Copies Term with every variable X replaced by '$NODE'(X),
  %   descending into dict values and compound arguments. Atomic terms
  %   are returned unchanged.
  quote_vars(Free, '$NODE'(Free)) :-
     var(Free),
     !.
  quote_vars(Dict, QuotedDict) :-
     is_dict(Dict),
     !,
     dict_pairs(Dict, Tag, KeyValues),
     maplist(quote_pair, KeyValues, QuotedKeyValues),
     dict_pairs(QuotedDict, Tag, QuotedKeyValues).
  quote_vars(Term, QuotedTerm) :-
     compound(Term),
     !,
     compound_name_arguments(Term, Functor, Args),
     maplist(quote_vars, Args, QuotedArgs),
     compound_name_arguments(QuotedTerm, Functor, QuotedArgs).
  quote_vars(Constant, Constant).

  %   quote_pair(+KeyValue, -KeyQuoted)
  %
  %   Applies quote_vars/2 to the value of a Key-Value pair.
  quote_pair(Key-Value, Key-Quoted) :-
     quote_vars(Value, Quoted).
 extract(Id, Extracted)// is det
Extracts the optimal term from the E-graph based on term costs.
Arguments:
Id- The eclass Id to be extracted as returned by add_term
Extracted- the extracted term
  344extract(Target, Extracted, EGraph, EGraph) :-
  345   current_prolog_flag(float_overflow, Flag),
  346   setup_call_cleanup(
  347      set_prolog_flag(float_overflow, infinity),
  348      (  dijkstra(Target, EGraph, Costs),
  349         extract_class(Costs, Target, Extracted)
  350      ),
  351      set_prolog_flag(float_overflow, Flag)
  352   ).
  %   extract_class(+Costs, +Class, -Term)
  %
  %   Looks up the Cost-Node entry recorded for Class in the Costs tree
  %   and rebuilds the term rooted at that node. Fails if Class has no
  %   entry.
  extract_class(Costs, Class, Term) :-
     rb_lookup(Class, _-BestNode, Costs),
     extract_node(Costs, BestNode, Term).
  %   extract_node(+Costs, +Node, -Term)
  %
  %   Rebuilds a term from Node: dict values and compound arguments are
  %   e-classes, each replaced by the term extract_class/3 yields for
  %   it. Any other node is returned unchanged.
  extract_node(Costs, Node, Term), is_dict(Node) =>
     dict_pairs(Node, Tag, KeyClasses),
     pairs_keys_values(KeyClasses, Keys, Classes),
     pairs_keys_values(KeyTerms, Keys, Terms),
     dict_pairs(Term, Tag, KeyTerms),
     maplist(extract_class(Costs), Classes, Terms).
  extract_node(Costs, Node, Term), compound(Node) =>
     compound_name_arguments(Node, Functor, Classes),
     same_length(Classes, Terms),
     compound_name_arguments(Term, Functor, Terms),
     maplist(extract_class(Costs), Classes, Terms).
  extract_node(_, Node, Term) =>
     Term = Node.
  369
  %   dijkstra(+Target, +EGraph, -Costs)
  %
  %   Computes a red-black tree that maps e-classes to BestCost-BestNode.
  %   setup/6 seeds the tree and the priority heap with the childless
  %   nodes and collects Child-ParentEntry pairs, which are grouped into
  %   a child -> parents tree. The search stops once Target is taken off
  %   the heap, or when the heap is empty.
  dijkstra(Target, EGraph, Costs) :-
     rb_new(Costs0),
     empty_heap(Heap0),
     setup(EGraph, ParentPairs, Costs0, Costs1, Heap0, Heap1),
     keysort(ParentPairs, SortedPairs),
     group_pairs_by_key(SortedPairs, GroupedPairs),
     ord_list_to_rbtree(GroupedPairs, Parents),
     dijkstra(Target, Parents, Heap1, Costs1, Costs).

  %   dijkstra(+Target, +Parents, +Heap, +Costs0, -Costs)
  %
  %   Main loop: take the cheapest class off the heap and relax the
  %   nodes that have it as a child.
  dijkstra(Target, Parents, Heap0, Costs0, Costs) :-
     (  get_from_heap(Heap0, Priority, Class, Heap1)
     -> (  Class == Target
        -> Costs = Costs0
        ;  rb_lookup(Class, Best-_, Costs0),
           (  Priority > Best
           -> % Stale heap entry: a cheaper cost was recorded since.
              Heap2 = Heap1,
              Costs1 = Costs0
           ;  (  rb_lookup(Class, ClassParents, Parents)
              -> true
              ;  ClassParents = []
              ),
              update_parents(ClassParents, Costs0, Costs1, Heap1, Heap2)
           ),
           dijkstra(Target, Parents, Heap2, Costs1, Costs)
        )
     ;  Costs = Costs0
     ).
  %   update_parents(+ParentEntries, +Costs0, -Costs, +Heap0, -Heap)
  %
  %   For every Node-node(Class, NodeCost) entry, computes the node's
  %   total cost as NodeCost plus the recorded costs of its children (a
  %   child without a recorded cost counts as inf). When the total beats
  %   the cost currently recorded for Class (inf if none), the tree
  %   entry is replaced and Class is pushed on the heap with the new
  %   cost.
  update_parents([], Costs, Costs, Heap, Heap).
  update_parents([Node-node(Class, NodeCost) | Rest], Costs0, Costs, Heap0, Heap) :-
     (  is_dict(Node)
     -> dict_pairs(Node, _, KeyChildren),
        pairs_values(KeyChildren, Children)
     ;  compound(Node)
     -> compound_name_arguments(Node, _, Children)
     ;  Children = []
     ),
     compute_cost(Children, Costs0, NodeCost, Total),
     (  rb_lookup(Class, Known-_, Costs0)
     -> Best = Known
     ;  Best = inf
     ),
     (  Total < Best
     -> rb_insert(Costs0, Class, Total-Node, Costs1),
        add_to_heap(Heap0, Total, Class, Heap1)
     ;  Costs1 = Costs0,
        Heap1 = Heap0
     ),
     update_parents(Rest, Costs1, Costs, Heap1, Heap).
  415
  416   
  %   compute_cost(+Classes, +Costs, +Cost0, -Cost)
  %
  %   Adds to Cost0 the cost recorded in the Costs tree for every class
  %   of Classes. A class without an entry contributes inf.
  compute_cost([], _, Total, Total).
  compute_cost([Class | Classes], Costs, Acc0, Total) :-
     (  rb_lookup(Class, Known-_, Costs)
     -> ClassCost = Known
     ;  ClassCost = inf
     ),
     Acc is Acc0 + ClassCost,
     compute_cost(Classes, Costs, Acc, Total).
  425
  %   setup(+Entries, -ParentPairs, +Costs0, -Costs, +Heap0, -Heap)
  %
  %   Prepares the shortest-path search over the node list Entries:
  %     * a childless node seeds the Costs tree with Cost-Node for its
  %       class and pushes the class on the heap, but only when it is
  %       strictly cheaper than the cost already recorded for the class;
  %     * a dict or compound node contributes one Child-Entry pair per
  %       child to ParentPairs (via insert_parent/4).
  %
  %   The recorded cost is fetched with a committed if-then-else. A plain
  %   disjunction `( rb_lookup(..) ; Best = inf )` followed by the
  %   comparison would be retried with inf after a failed comparison, so
  %   a more expensive leaf would overwrite a cheaper one.
  setup([], [], Costs, Costs, Heap, Heap).
  setup([Node-node(ClassId, NodeCost) | Nodes], ParentsIn, CostIn, CostOut, HeapIn, HeapOut) :-
     (  is_dict(Node)
     -> dict_pairs(Node, _, KeysValues),
        pairs_values(KeysValues, ChildClasses)
     ;  compound(Node)
     -> compound_name_arguments(Node, _, ChildClasses)
     ;  ChildClasses = []
     ),
     (  ChildClasses == []
     -> ParentsOut = ParentsIn,
        (  rb_lookup(ClassId, Known-_, CostIn)
        -> CurCost = Known
        ;  CurCost = inf
        ),
        (  NodeCost < CurCost
        -> rb_insert(CostIn, ClassId, NodeCost-Node, CostTmp),
           add_to_heap(HeapIn, NodeCost, ClassId, HeapTmp)
        ;  CostTmp = CostIn,
           HeapTmp = HeapIn
        )
     ;  % ParentsIn is an open list; ParentsOut is its remaining tail.
        insert_parent(ChildClasses, Node-node(ClassId, NodeCost), ParentsIn, ParentsOut),
        CostTmp = CostIn,
        HeapTmp = HeapIn
     ),
     setup(Nodes, ParentsOut, CostTmp, CostOut, HeapTmp, HeapOut).
  446
  %   insert_parent(+Children, +Entry, -Pairs, ?Tail)
  %
  %   Builds an open list with one Child-Entry pair per child,
  %   ending in Tail.
  insert_parent([], _, Tail, Tail).
  insert_parent([Child | Children], Entry, [Child-Entry | Pairs], Tail) :-
     insert_parent(Children, Entry, Pairs, Tail).
  450
  %!  extract_all(+Target, -Extracted)//
  %
  %   Enumerates on backtracking the terms that the e-class Target can
  %   be extracted as, leaving the E-graph unchanged. Class costs are
  %   first computed for the whole graph (dijkstra/3 is called with an
  %   unbound target), then extract_all_/4 produces the terms. The
  %   Prolog flag float_overflow is set to infinity for the duration and
  %   restored by setup_call_cleanup/3.
  extract_all(Target, Extracted, EGraph, EGraph) :-
     current_prolog_flag(float_overflow, Saved),
     setup_call_cleanup(
        set_prolog_flag(float_overflow, infinity),
        (  dijkstra(_, EGraph, Costs),
           extract_all_(EGraph, Costs, Target, Extracted)
        ),
        set_prolog_flag(float_overflow, Saved)
     ).
  460
  %   extract_all_index(+EGraph, -Index)
  %
  %   Builds a red-black tree that maps each e-class Id to the list of
  %   Node-Cost pairs belonging to it.
  extract_all_index(EGraph, Index) :-
     extract_pairs(EGraph, Keyed),
     keysort(Keyed, SortedKeyed),
     group_pairs_by_key(SortedKeyed, ByClass),
     ord_list_to_rbtree(ByClass, Index).
  466
  %   extract_pairs(+Entries, -Keyed)
  %
  %   Maps each Node-node(Id, Cost) entry to Id-(Node-Cost),
  %   keeping list order.
  extract_pairs([], []).
  extract_pairs([Node-node(Class, Cost) | Entries], [Class-(Node-Cost) | Keyed]) :-
     extract_pairs(Entries, Keyed).
  470
  %   extract_all_(+EGraph, +Costs, +Target, -Extracted)
  %
  %   Seeds a heap with the state state(0, [], [Target]), whose priority
  %   is the cost recorded for Target in Costs, and lets extract_all__/4
  %   enumerate complete node choices. Each choice is a list of
  %   Hole-Node pairs in reverse order of selection; it is reversed and
  %   turned into a term by build_term/3.
  extract_all_(EGraph, Costs, Target, Extracted) :-
     empty_heap(Heap0),
     rb_lookup(Target, Estimate-_, Costs),
     Start = state(0, [], [Target]),
     add_to_heap(Heap0, Estimate, Start, Heap1),
     extract_all_index(EGraph, Index),
     extract_all__(Index, Costs, Heap1, Choices),
     reverse(Choices, PreOrder),
     build_term(PreOrder, [], Extracted).
  480
  %   build_term(+Unifs0, -Unifs, -Term)
  %
  %   Consumes Hole-Node pairs from the front of Unifs0 and builds the
  %   term rooted at the first pair. A dict or compound node takes one
  %   sub-term per child from the pairs that follow; any other node
  %   (including an unbound one) is the term itself. Unifs is what
  %   remains.
  build_term([_-Node | Unifs0], Unifs, Term) :-
     (  is_dict(Node)
     -> dict_pairs(Node, Tag, KeyClasses),
        pairs_keys_values(KeyClasses, Keys, Classes),
        build_terms(Classes, Unifs0, Unifs, SubTerms),
        pairs_keys_values(KeyTerms, Keys, SubTerms),
        dict_pairs(Term, Tag, KeyTerms)
     ;  compound(Node)
     -> compound_name_arguments(Node, Functor, Classes),
        build_terms(Classes, Unifs0, Unifs, SubTerms),
        compound_name_arguments(Term, Functor, SubTerms)
     ;  Term = Node,
        Unifs = Unifs0
     ).

  %   build_terms(+Classes, +Unifs0, -Unifs, -Terms)
  %
  %   Builds one term per element of Classes with build_term/3. Only the
  %   length of Classes matters here.
  build_terms([], Unifs, Unifs, []).
  build_terms([_ | Classes], Unifs0, Unifs, [Term | Terms]) :-
     build_term(Unifs0, Unifs1, Term),
     build_terms(Classes, Unifs1, Unifs, Terms).
  503
  %   extract_all__(+Index, +Costs, +Heap, -Unifs)
  %
  %   Best-first search over partial terms. The state with the lowest
  %   priority is taken off the heap. A state without pending holes is
  %   a solution (further solutions are found on backtracking);
  %   otherwise its first hole is expanded with every node of that class
  %   by extract_all_childs/8. Fails once the heap is exhausted.
  extract_all__(Index, Costs, Heap0, Unifs) :-
     (  get_from_heap(Heap0, _, state(G, Partial, Holes), Heap1)
     -> (  Holes == []
        -> (  Unifs = Partial
           ;  extract_all__(Index, Costs, Heap1, Unifs)
           )
        ;  Holes = [Hole | Others],
           rb_lookup(Hole, Nodes, Index),
           extract_all_childs(Nodes, Costs, G, Partial, Hole, Others, Heap1, Heap2),
           extract_all__(Index, Costs, Heap2, Unifs)
        )
     ).
  516
  %   extract_all_childs(+Nodes, +Costs, +G, +PartialTerm, +Hole,
  %                      +RestHoles, +HeapIn, -HeapOut)
  %
  %   Expands Hole with every Node-NodeCost candidate of its class. For
  %   each candidate a successor state is pushed on the heap:
  %     * its partial term is [Hole-Node|PartialTerm];
  %     * its pending holes are the node's children followed by
  %       RestHoles;
  %     * its accumulated cost is NewG = G + NodeCost;
  %     * its priority is F = NewG + H, where H is the sum of the costs
  %       recorded in Costs for the pending holes.
  %
  %   The successor must carry NewG, not G. Storing G dropped the cost
  %   of every node already chosen, so deeper states looked cheaper than
  %   they are and solutions could come out in the wrong cost order.
  extract_all_childs([], _, _, _, _, _, Heap, Heap).
  extract_all_childs([Node-NodeCost | Nodes], Costs, G, PartialTerm, Hole, RestHoles, HeapIn, HeapOut) :-
     (  is_dict(Node)
     -> dict_pairs(Node, _, KeysValues),
        pairs_values(KeysValues, ChildClasses)
     ;  compound(Node)
     -> compound_name_arguments(Node, _, ChildClasses)
     ;  ChildClasses = []
     ),
     NewPartialTerm = [Hole-Node | PartialTerm],
     append(ChildClasses, RestHoles, NewHoles),
     NewG is G + NodeCost,
     foldl(sum_costs(Costs), NewHoles, 0, H),
     F is NewG + H,
     add_to_heap(HeapIn, F, state(NewG, NewPartialTerm, NewHoles), HeapTmp),
     % Siblings are alternatives for the same hole: they start from G.
     extract_all_childs(Nodes, Costs, G, PartialTerm, Hole, RestHoles, HeapTmp, HeapOut).

  %   sum_costs(+Costs, +Hole, +In, -Out)
  %
  %   Adds the cost recorded for Hole in Costs to In. Fails if Hole has
  %   no entry.
  sum_costs(Costs, Hole, In, Out) :-
     rb_lookup(Hole, Cost-_, Costs),
     Out is In + Cost.
  538
  :- begin_tests(egraph_add_terms).

  % test_term(?Term, ?ExpectedNodes)
  %
  % Pairs an input term with the node list expected after adding it to
  % an empty E-graph with the option var(node).
  test_term(V, [V-node(_, 1)]).
  test_term('$NODE'(V), [V-node(_, 1)]).
  test_term(tag{k1: v1, k2: v2}, [v1-node(C1, 1), v2-node(C2, 1), tag{k1: C1, k2: C2}-node(_, 1)]).
  test_term(f(arg1, arg2), [arg1-node(C1, 1), arg2-node(C2, 1), f(C1, C2)-node(_, 1)]).
  test_term(simple_atom, [simple_atom-node(_, 1)]).

  % The resulting node list must be a variant of the expected one.
  test(add_term, [forall(test_term(Term, Expected)), OutNodes =@= Expected]) :-
     phrase(add_term(Term, _Id, [var(node)]), [], OutNodes).

  :- end_tests(egraph_add_terms).