1:- module(egraph, [add_term//2, add_terms//2, union//2, saturate//1,
2 saturate//2, extract//2, extract_all//2, lookup/2, query//1]).
40:- use_module(library(dcg/high_order)). 41:- use_module(library(ordsets)). 42:- use_module(library(rbtrees)). 43:- use_module(library(heaps)). 44
45:- use_module(egraph/compile).
%!  lookup(+Pair, +OrdPairs) is semidet.
%
%   Pair is Key-Value.  Search OrdPairs, a list of Key-Value pairs, for an
%   entry whose key is ==-identical to Key and unify Value with its value.
%   The search compares against every 4th (then 2nd) key and skips ahead,
%   so it assumes OrdPairs is sorted on the standard order of terms (as
%   produced by sort/2 or ord_add_element/3).  Fails if Key is absent.
lookup(Key-Value, [K1-W1, K2-W2, K3-W3, K4-W4|Tail]) :-
    !,
    compare(Order4, Key, K4),
    (   Order4 == (=)
    ->  Value = W4
    ;   Order4 == (>)
    ->  lookup(Key-Value, Tail)
    ;   compare(Order2, Key, K2),
        (   Order2 == (=)
        ->  Value = W2
        ;   Order2 == (<)
        ->  Key == K1,
            Value = W1
        ;   Key == K3,
            Value = W3
        )
    ).
lookup(Key-Value, [K1-W1, K2-W2|Tail]) :-
    !,
    compare(Order, Key, K2),
    (   Order == (=)
    ->  Value = W2
    ;   Order == (<)
    ->  Key == K1,
        Value = W1
    ;   lookup(Key-Value, Tail)
    ).
lookup(Key-Value, [K1-W1]) :-
    Key == K1,
    Value = W1.
%!  add_term(+Term, -Node)//
%
%   Add Term to the e-graph state using the default options [cost(1)];
%   Node is unified with the e-class of Term.
add_term(Term, Node, S0, S) :-
    add_term(Term, Node, [cost(1)], S0, S).
%!  add_term(+Term, -Id, +Options)//
%
%   Add Term and, recursively, all of its sub-terms to the e-graph state
%   (threaded as the DCG state) and unify Id with the e-class of Term.
%   Options are forwarded to add_node//3.  Options read here:
%
%     - var(What): how a variable Term is handled (default `node`).
%       `node` adds the variable itself as a node; `class` unifies the
%       variable with Id without adding a node.  Any other value raises
%       a domain_error.
%     - mark(true): in the `class` case, put the `egraph` attribute on Id.
%
%   '$NODE'(X) adds X as a single node, without decomposing it.
add_term(Var, Id, Opt), var(Var) ==>
    { option(var(What), Opt, node) },
    (   { What == node }
    ->  add_node(Var, Id, Opt)
    ;   { What == class }
    ->  {   (   option(mark(true), Opt)
            ->  put_attr(Id, egraph, true)
            ;   true
            ),
            Var=Id
        }
    ;   { domain_error(node-class, What) }
    ).
add_term('$NODE'(Node), Id, Opt) ==>
    add_node(Node, Id, Opt).
% Dict: add every value, then add a dict node of the same tag whose values
% are the class ids of those values.  Checked before the compound case.
add_term(Term, Id, Opt), is_dict(Term, Tag) ==>
    {
        dict_pairs(Term, Tag, Pairs),
        pairs_keys_values(Pairs, Keys, Values),
        pairs_keys_values(Data, Keys, Ids),
        % Ids are still unbound here; add_terms//3 below binds them.
        dict_create(Node, Tag, Data)
    },
    add_terms(Values, Ids, Opt),
    add_node(Node, Id, Opt).
% Compound: add the arguments, then a node F(Id1, ..., IdN) over their
% class ids.
add_term(Term, Id, Opt), compound(Term) ==>
    { Term =.. [F | Args] },
    add_terms(Args, Ids, Opt),
    { Node =.. [F | Ids] },
    add_node(Node, Id, Opt).
% Anything else (atomic) is added as a node as-is.
add_term(Term, Id, Opt) ==>
    add_node(Term, Id, Opt).
124
%   add_terms(+Terms, -Ids, +Options)//
%
%   Add each element of Terms, left to right, with add_term//3; Ids is the
%   list of the corresponding class ids.
add_terms([], [], _Options, S, S).
add_terms([T | Ts], [I | Is], Options, S0, S) :-
    add_term(T, I, Options, S0, S1),
    add_terms(Ts, Is, Options, S1, S).
129
%!  add_terms(+TermIdPairs, +Options)//
%
%   Add every Term of a list of Term-Id pairs with add_term//3, binding
%   each Id to the class of its Term.
add_terms([], _Options, S, S).
add_terms([Term-Id | Pairs], Options, S0, S) :-
    add_term(Term, Id, Options, S0, S1),
    add_terms(Pairs, Options, S1, S).
134
%   add_node(+NodeIdPair, +Options)//
%
%   Same as add_node//3 but takes the node and its class as Node-Id.
add_node(Node-Id, Opt, In, Out) :-
    add_node(Node, Id, Opt, In, Out).

%   add_node(+Node, ?Id, +Options)//
%
%   Ensure Node is in the e-graph state, an ordset of Node-node(Id, Cost)
%   entries.  If Node is already present, either in the state or in the
%   ordset given by option nodes(Nodes), Id is unified with its class and
%   the state is unchanged.  Otherwise Node-node(Id, Cost) is added.
%
%   Options:
%     - cost(N, Cost): cost for the node ==-identical to N (first match wins).
%     - cost(Cost): default cost; 1 if absent.  Must be a number.
%     - mark(true): put the `egraph` attribute on Id for a new node.
%     - nodes(Nodes): extra ordset of existing nodes to consult.
add_node(Node, Id, Opt, In, Out) :-
    (   lookup(Node-node(Id, _), In)
    ->  Out = In
    ;   option(nodes(Nodes), Opt),
        lookup(Node-node(Id, _), Nodes)
    ->  Out = In
    ;   % Commit to the first applicable cost: a per-node cost(N, Cost)
        % entry takes precedence over the global cost(Cost) default.
        once((   member(cost(N, Cost), Opt),
                 N == Node
             ;   option(cost(Cost), Opt, 1)
             )),
        must_be(number, Cost),
        (   option(mark(true), Opt)
        ->  put_attr(Id, egraph, true)
        ;   true
        ),
        ord_add_element(In, Node-node(Id, Cost), Out)
    ).
153
%   rules(+Rules, +M, +EGraph, +Index, +Parents, +MinId, -Unifs0, ?Unifs)//
%
%   Run each compiled rule (resolved relative to module M) once, in order.
%   Each rule receives the same read-only state term and threads both the
%   DCG output (new entries) and the Unifs difference list.
rules([], _M, _EGraph, _Index, _Parents, _MinId, Unifs, Unifs) -->
    [].
rules([Rule | Rest], M, EGraph, Index, Parents, MinId, Unifs0, Unifs) -->
    {   strip_module(M:Rule, RuleModule, RuleName),
        State = state([], EGraph, Index, Parents, MinId)
    },
    call(RuleModule:RuleName, '$empty'-node(_, 0), State, Unifs0, Unifs1),
    rules(Rest, M, EGraph, Index, Parents, MinId, Unifs1, Unifs).
159
%   make_index(+Nodes, -Index)
%
%   Index is an rbtree mapping each class id to the sorted list of
%   Node-node(Id, Cost) entries of that class.
make_index(Nodes, Index) :-
    index_pairs(Nodes, Keyed),
    sort(Keyed, SortedKeyed),
    group_pairs_by_key(SortedKeyed, ByClass),
    ord_list_to_rbtree(ByClass, Index).

%   index_pairs(+Nodes, -Keyed)
%
%   Key every Node-node(Id, Cost) entry by its class id.
index_pairs([], []).
index_pairs([Entry | Entries], [Id-Entry | Keyed]) :-
    Entry = _-node(Id, _),
    index_pairs(Entries, Keyed).
169
%   make_parents(+Nodes, -Parents)
%
%   Parents is an rbtree mapping each child class id to the sorted list of
%   Node-node(Id, Cost) entries that have that class as an argument or
%   dict value (see parent_pairs//1).
make_parents(Nodes, Parents) :-
    phrase(parent_pairs(Nodes), ChildParentPairs),
    sort(ChildParentPairs, SortedPairs),
    group_pairs_by_key(SortedPairs, ByChild),
    ord_list_to_rbtree(ByChild, Parents).
175
%   parent_pairs(+Nodes)//
%
%   For every entry Node-node(Id, Cost) of Nodes, emit Child-Entry once
%   per dict value / compound argument Child of Node, where Entry is the
%   whole Node-node(Id, Cost) term.  Atomic nodes, and list elements not
%   of the form _-node(_, _), emit nothing.
parent_pairs([]) ==> [].
parent_pairs([Node-node(Id, Cost) | L]), is_dict(Node) ==>
    { dict_pairs(Node, Tag, Pairs) },
    % Local dict_pairs//3 (dict_pairs/5), not the builtin dict_pairs/3.
    dict_pairs(Pairs, Node-node(Id, Cost), Tag),
    parent_pairs(L).
parent_pairs([Node-node(Id, Cost) | L]), compound(Node) ==>
    { compound_name_arguments(Node, Name, Args) },
    arg_pairs(Args, Node-node(Id, Cost), Name, _Arity, 0),
    parent_pairs(L).
parent_pairs([_ | L]) ==>
    parent_pairs(L).
187
%   dict_pairs(+KeyValues, +Node, +Tag)//
%
%   Emit Value-Node for every Key-Value pair.  Tag is not used.
dict_pairs([], _Node, _Tag, S, S).
dict_pairs([_-Value | Pairs], Node, Tag, [Value-Node | S1], S) :-
    dict_pairs(Pairs, Node, Tag, S1, S).
192
%   arg_pairs(+Args, +Node, +Name, -Arity, +I0)//
%
%   Emit Arg-Node for every argument, counting them from I0; Arity is the
%   final count.  Name is not used.
arg_pairs([], _Node, _Name, Arity, Arity, S, S).
arg_pairs([Arg | Args], Node, Name, Arity, I0, [Arg-Node | S1], S) :-
    I is I0 + 1,
    arg_pairs(Args, Node, Name, Arity, I, S1, S).
%!  union(?A, ?B)//
%
%   Merge e-classes A and B by unifying their ids, then collapse nodes of
%   the e-graph that have thereby become identical (merge_nodes//0).
union(A, B) -->
    { A = B },
    merge_nodes.
208
%   merge_nodes(+Nodes, -Merged)
%
%   Canonicalise a node list: sort it, merge entries whose Node part is
%   identical, and repeat until a pass merges nothing, because unifying
%   class ids during a merge can make further nodes identical.
merge_nodes(In, Out) :-
    sort(In, Sorted),
    group_pairs_by_key(Sorted, Groups),
    merge_groups(Groups, Next, false, Merged),
    (   Merged \== true
    ->  Out = Sorted
    ;   merge_nodes(Next, Out)
    ).

%   merge_groups(+Groups, -Merged, +Changed0, -Changed)
%
%   Collapse each Node-[node(Id, Cost), ...] group into one entry.
%   Changed becomes true as soon as some group has more than one member.
merge_groups([], [], Changed, Changed).
merge_groups([Key-[First | Others] | Groups], [Key-Node | Merged], Changed0, Changed) :-
    (   Others == []
    ->  Changed1 = Changed0
    ;   Changed1 = true
    ),
    merge_group(Others, First, Node),
    merge_groups(Groups, Merged, Changed1, Changed).

%   merge_group(+Nodes, +Acc, -Node)
%
%   Fold node(Id, Cost) terms into Acc, unifying all their ids and keeping
%   the minimum cost.  Sets the global variable egraph_changed to true
%   whenever the running minimum drops.
merge_group([], Node, Node).
merge_group([node(Id, Cost) | Nodes], node(Id, Best0), Node) :-
    Best is min(Cost, Best0),
    (   Best < Best0
    ->  b_setval(egraph_changed, true)
    ;   true
    ),
    merge_group(Nodes, node(Id, Best), Node).
235
%   apply_unifs(+Unifs)
%
%   Perform every Left=Right unification of the list, in order.
apply_unifs([]).
apply_unifs([Left=Right | Rest]) :-
    Left = Right,
    apply_unifs(Rest).
240
%   rebuild(+Matches, +Unifs, -EGraph)
%
%   Apply the class unifications Unifs, then merge the entries of Matches
%   that have become identical.
rebuild(Matches, Unifs, EGraph) :-
    apply_unifs(Unifs),
    merge_nodes(Matches, EGraph).
244
:- meta_predicate saturate(:, ?, ?).
:- meta_predicate saturate(:, +, ?, ?).

%!  saturate(:Rules)//
%
%   Same as saturate(Rules, inf): iterate until no round changes the
%   e-graph.
saturate(M:Rules) -->
    saturate(M:Rules, inf).

%!  saturate(:Rules, +N)//
%
%   Run at most N rounds (`inf` for no limit).  A round runs every rule
%   once against the current e-graph, applies the requested unifications,
%   and merges nodes (rebuild/3).  Iteration stops early when a round
%   neither changes the number of entries nor lowers any cost (the
%   egraph_changed global variable, set by merge_group/3).
%
%   NOTE(review): a round that merges classes without changing the entry
%   count or lowering a cost is treated as a fixpoint - confirm intended.
saturate(M:Rules, N, In, Out) :-
    (   N > 0
    ->  b_setval(egraph_changed, false),
        % Matches is a difference list with tail In: the rules' output
        % followed by the current e-graph.
        run_rules(Rules, M, In, Matches, In, Unifs),
        rebuild(Matches, Unifs, Tmp),
        length(In, Len1),
        length(Tmp, Len2),
        b_getval(egraph_changed, Changed),
        debug(saturate, "~p", [Len1-Len2-Changed]),
        (   (Len1 \== Len2 ; Changed == true)
        ->  (   N == inf
            ->  N1 = N
            ;   N1 is N - 1
            ),
            saturate(M:Rules, N1, Tmp, Out)
        ;   Out = Tmp
        )
    ;   Out = In
    ).
284
%   run_rules(+Rules, +M, +EGraph, -Matches, ?Tail, -Unifs)
%
%   Run every rule in Rules once against EGraph.  The rules' output is
%   returned as the difference list Matches-Tail and the unifications
%   they request as the closed list Unifs.  MinId is the smallest class id
%   in the index, or 0 for an empty e-graph.
run_rules(Rules, M, EGraph, Matches, Tail, Unifs) :-
    make_parents(EGraph, Parents),
    make_index(EGraph, Index),
    (   rb_min(Index, SmallestId, _)
    ->  true
    ;   SmallestId = 0
    ),
    phrase(rules(Rules, M, EGraph, Index, Parents, SmallestId, Unifs, []),
           Matches, Tail).
%!  query(+Pattern)//
%
%   On backtracking, enumerate terms extracted from e-classes matching
%   Pattern, unified with Pattern; the e-graph is left unchanged.  The
%   matching rule is compiled once per variant of Pattern, named
%   query_<sha1 of the variant>, and asserted into this module; later
%   calls with a variant pattern reuse it.  Variables of the pattern are
%   wrapped as '$NODE'(Var) by quote_vars/2 before compilation (presumably
%   so they match any class - see egraph_compile).
query(Pattern, In, Out) :-
    % Work on a copy of Pattern for hashing and compilation.
    copy_term(Pattern, PatternCopy),
    variant_sha1(PatternCopy, Sha),
    atom_concat('query_', Sha, RuleName),
    (   current_predicate(RuleName/_)
    ->  true
    ;   quote_vars(PatternCopy, QuotedPat),
        Rule = rule(RuleName, [QuotedPat-Id], [], [query(Id)-_], []),
        phrase(egraph_compile:compile(Rule :- true), Clauses),
        % Clauses is a list of lists of clauses.
        maplist(maplist(assert_ssu), Clauses)
    ),
    run_rules([RuleName], egraph, In, Matches, [], _Unifs),
    pairs_keys(Matches, Queries),
    % Deduplicate matching classes.
    sort(Queries, SortedQueries),
    member(query(MatchId), SortedQueries),
    extract_all(MatchId, Pattern, In, Out).
317
%   assert_ssu(+Clause)
%
%   assertz/1 an SSU clause produced by egraph_compile.  A guarded clause
%   (Head, Guard) => Body is asserted in the form Head ?=> Guard, !, Body;
%   a plain Head => Body clause is asserted as is.  Any other term is
%   silently skipped.
assert_ssu((Head, Guard) => Body) =>
    assertz(?=>(Head, (Guard, !, Body))).
assert_ssu(Head => Body) =>
    assertz(Head => Body).
assert_ssu(_) => true.
323
%   quote_vars(+Term, -Quoted)
%
%   Quoted is a copy of Term in which every variable V is replaced by
%   '$NODE'(V); dicts and compounds are traversed, atomic terms are kept.
%   The output is unified only after the clause has committed (steadfast),
%   so quote_vars(V, foo) fails instead of binding V = foo.
quote_vars(Var, Quoted) :-
    var(Var),
    !,
    Quoted = '$NODE'(Var).
quote_vars(Dict, Quoted) :-
    is_dict(Dict),
    !,
    dict_pairs(Dict, Tag, Pairs),
    maplist(quote_pair, Pairs, QuotedPairs),
    dict_pairs(Quoted, Tag, QuotedPairs).
quote_vars(Compound, Quoted) :-
    compound(Compound),
    !,
    compound_name_arguments(Compound, Name, Args),
    maplist(quote_vars, Args, QuotedArgs),
    compound_name_arguments(Quoted, Name, QuotedArgs).
quote_vars(Atomic, Quoted) :-
    Quoted = Atomic.
334
%   quote_pair(+KeyValue, -QuotedKeyValue)
%
%   Apply quote_vars/2 to the value of a Key-Value pair.
quote_pair(Pair, Quoted) :-
    Pair = Key-Value,
    Quoted = Key-QuotedValue,
    quote_vars(Value, QuotedValue).
%!  extract(+Target, -Extracted)//
%
%   Extracted is the cheapest term of e-class Target; the e-graph is left
%   unchanged.  The float_overflow flag is set to `infinity` for the
%   duration of the search and restored afterwards.
extract(Target, Extracted, EGraph, EGraph) :-
    current_prolog_flag(float_overflow, Saved),
    Extraction = (   dijkstra(Target, EGraph, Costs),
                     extract_class(Costs, Target, Extracted)
                 ),
    setup_call_cleanup(
        set_prolog_flag(float_overflow, infinity),
        Extraction,
        set_prolog_flag(float_overflow, Saved)).
%   extract_class(+Costs, +Class, -Term)
%
%   Build the term of Class by following, recursively, the node recorded
%   for each class in Costs (an rbtree ClassId -> Cost-Node, as built by
%   dijkstra/3).  Fails if Class has no entry in Costs.
extract_class(Costs, Target, Extracted) :-
    rb_lookup(Target, _-Node, Costs),
    extract_node(Costs, Node, Extracted).

%   extract_node(+Costs, +Node, -Term)
%
%   Replace every child class of Node (dict values or compound arguments)
%   by its extracted term; other nodes are returned unchanged.
extract_node(Costs, Dict, R), is_dict(Dict) =>
    dict_pairs(Dict, Tag, Pairs),
    pairs_keys_values(Pairs, Keys, Classes),
    pairs_keys_values(NewPairs, Keys, Values),
    dict_pairs(R, Tag, NewPairs),
    maplist(extract_class(Costs), Classes, Values).
extract_node(Costs, Compound, R), compound(Compound) =>
    compound_name_arguments(Compound, Name, Classes),
    same_length(Classes, Values),
    compound_name_arguments(R, Name, Values),
    maplist(extract_class(Costs), Classes, Values).
extract_node(_, Atomic, R) =>
    R = Atomic.
369
%   dijkstra(?Target, +EGraph, -Costs)
%
%   Bottom-up cheapest-cost search.  Costs is an rbtree mapping each class
%   whose cost could be determined to Cost-Node, where Node is the
%   cheapest node found for the class and Cost is the node's own cost
%   plus the costs of its child classes (compute_cost/4).  If Target is a
%   class id, the search stops as soon as Target is popped from the heap.
%   extract_all//2 passes a fresh variable, which is never ==-identical to
%   a class id, so in that case the search runs until the heap is empty.
dijkstra(Target, EGraph, CostsOut) :-
    empty_heap(HeapIn),
    rb_new(EmptyCosts),
    % Seed costs and heap with leaf nodes; collect ChildClass-ParentEntry.
    setup(EGraph, ParentPairs, EmptyCosts, CostsIn, HeapIn, HeapOut),
    keysort(ParentPairs, SortedParentPairs),
    group_pairs_by_key(SortedParentPairs, GroupedParentPairs),
    % Parents: ChildClass -> [Node-node(ParentClass, NodeCost), ...].
    ord_list_to_rbtree(GroupedParentPairs, Parents),
    dijkstra(Target, Parents, HeapOut, CostsIn, CostsOut).

%   dijkstra(?Target, +Parents, +Heap, +Costs0, -Costs)
%
%   Main loop: pop the cheapest class, skip stale heap entries, and relax
%   the parent nodes that use the popped class as a child.
dijkstra(Target, Parents, HeapIn, CostsIn, CostsOut) :-
    (   get_from_heap(HeapIn, CurrentCost, Class, HeapTmp)
    ->  (   Class == Target
        ->  CostsOut = CostsIn
        ;   rb_lookup(Class, ClassCost-_, CostsIn),
            (   CurrentCost > ClassCost
            ->  % Stale entry: a cheaper cost was recorded after this push.
                dijkstra(Target, Parents, HeapTmp, CostsIn, CostsOut)
            ;   (   rb_lookup(Class, ClassParents, Parents)
                ->  true
                ;   ClassParents = []
                ),
                update_parents(ClassParents, CostsIn, CostsTmp, HeapTmp, HeapOut),
                dijkstra(Target, Parents, HeapOut, CostsTmp, CostsOut)
            )
        )
    ;   CostsOut = CostsIn
    ).
%   update_parents(+ParentEntries, +Costs0, -Costs, +Heap0, -Heap)
%
%   Relaxation step of dijkstra/5.  For each entry
%   ParentNode-node(ParentClass, ParentCost), compute the node's cost as
%   ParentCost plus the current costs of all its child classes (`inf` for
%   classes without a cost yet).  If that is lower than ParentClass's
%   recorded cost (`inf` if none), record Cost-ParentNode for the class
%   and push the class on the heap.
update_parents([], Costs, Costs, Heap, Heap).
update_parents([ParentNode-node(ParentClass, ParentCost) | Parents], CostsIn, CostsOut, HeapIn, HeapOut) :-
    % Child classes: dict values or compound arguments; none for atomic nodes.
    (   is_dict(ParentNode)
    ->  dict_pairs(ParentNode, _, KeysValues),
        pairs_values(KeysValues, ChildClasses)
    ;   compound(ParentNode)
    ->  compound_name_arguments(ParentNode, _, ChildClasses)
    ;   ChildClasses = []
    ),
    compute_cost(ChildClasses, CostsIn, ParentCost, Cost),
    (   rb_lookup(ParentClass, CurrentCost-_, CostsIn)
    ->  true
    ;   CurrentCost = inf
    ),
    (   Cost < CurrentCost
    ->  rb_insert(CostsIn, ParentClass, Cost-ParentNode, CostsTmp),
        add_to_heap(HeapIn, Cost, ParentClass, HeapTmp)
    ;   CostsTmp = CostsIn, HeapTmp = HeapIn
    ),
    update_parents(Parents, CostsTmp, CostsOut, HeapTmp, HeapOut).
415
416
%   compute_cost(+ChildClasses, +Costs, +Cost0, -Cost)
%
%   Cost is Cost0 plus the recorded cost of every child class; a class
%   with no entry in Costs contributes `inf`.
compute_cost([], _Costs, Total, Total).
compute_cost([Class | Classes], Costs, Acc0, Total) :-
    (   rb_lookup(Class, ClassCost-_, Costs)
    ->  true
    ;   ClassCost = inf
    ),
    Acc is Acc0 + ClassCost,
    compute_cost(Classes, Costs, Acc, Total).
425
%   setup(+Nodes, -ParentPairs, +Costs0, -Costs, +Heap0, -Heap)
%
%   Initialise dijkstra/3.  Leaf nodes (no child classes) seed Costs
%   (ClassId -> Cost-Node, keeping the cheapest leaf per class) and Heap.
%   Every other node is recorded, once per child class, as a pair
%   ChildClass-(Node-node(ClassId, NodeCost)) in the list ParentPairs.
%   Only called via dijkstra/3, which extract//2 and extract_all//2 run
%   with float_overflow=infinity, so `inf` is usable in the comparison.
setup([], [], Cost, Cost, Heap, Heap).
setup([Node-node(ClassId, NodeCost) | Nodes], ParentsIn, CostIn, CostOut, HeapIn, HeapOut) :-
    (   is_dict(Node)
    ->  dict_pairs(Node, _, KeysValues),
        pairs_values(KeysValues, ChildClasses)
    ;   compound(Node)
    ->  compound_name_arguments(Node, _, ChildClasses)
    ;   ChildClasses = []
    ),
    (   ChildClasses == []
    ->  ParentsOut = ParentsIn,
        % Look up the current best cost deterministically.  A plain
        % disjunction here would backtrack into `CurCost = inf` when the
        % comparison fails, overwriting a cheaper leaf with a dearer one.
        (   rb_lookup(ClassId, CurCost-_, CostIn)
        ->  true
        ;   CurCost = inf
        ),
        (   NodeCost < CurCost
        ->  rb_insert(CostIn, ClassId, NodeCost-Node, CostTmp),
            add_to_heap(HeapIn, NodeCost, ClassId, HeapTmp)
        ;   CostTmp = CostIn,
            HeapTmp = HeapIn
        )
    ;   insert_parent(ChildClasses, Node-node(ClassId, NodeCost), ParentsIn, ParentsOut),
        CostTmp = CostIn,
        HeapTmp = HeapIn
    ),
    setup(Nodes, ParentsOut, CostTmp, CostOut, HeapTmp, HeapOut).
446
%   insert_parent(+ChildClasses, +Entry, -Pairs, ?Tail)
%
%   Pairs-Tail is the difference list [Child1-Entry, Child2-Entry, ...].
insert_parent([], _Entry, Tail, Tail).
insert_parent([Child | Children], Entry, [Child-Entry | Pairs], Tail) :-
    insert_parent(Children, Entry, Pairs, Tail).
450
%!  extract_all(+Target, -Extracted)//
%
%   On backtracking, enumerate terms Extracted of e-class Target; the
%   e-graph is left unchanged.  Costs of all classes are computed first
%   (dijkstra/3 with an unbound target), with float_overflow set to
%   `infinity` and restored by setup_call_cleanup/3.
%
%   NOTE(review): extraction is nondeterministic, so the cleanup (flag
%   restore) only runs once all solutions are exhausted or the choice
%   point is cut - until then float_overflow stays `infinity`.
extract_all(Target, Extracted, EGraph, EGraph) :-
    current_prolog_flag(float_overflow, Flag),
    setup_call_cleanup(
        set_prolog_flag(float_overflow, infinity),
        (   dijkstra(_, EGraph, Costs),
            extract_all_(EGraph, Costs, Target, Extracted)
        ),
        set_prolog_flag(float_overflow, Flag)).
460
%   extract_all_index(+EGraph, -Index)
%
%   Index is an rbtree mapping each class id to the list of Node-Cost
%   pairs of that class, in e-graph order (keysort/2 is stable).
extract_all_index(EGraph, Index) :-
    extract_pairs(EGraph, UnsortedPairs),
    keysort(UnsortedPairs, IdPairs),
    group_pairs_by_key(IdPairs, Groups),
    ord_list_to_rbtree(Groups, Index).
466
%   extract_pairs(+EGraph, -Pairs)
%
%   Map every entry Node-node(Id, Cost) to Id-(Node-Cost).
extract_pairs([], []).
extract_pairs([Node-node(Id, Cost)|T0], [Id-(Node-Cost)|T1]) :-
    extract_pairs(T0, T1).
470
%   extract_all_(+EGraph, +Costs, +Target, -Extracted)
%
%   Enumerate, on backtracking, the terms of class Target using a
%   best-first search (extract_all__/4).  The root state has no choices
%   yet, a single pending hole (Target), and priority equal to Target's
%   cheapest cost.  Fails if Target has no entry in Costs.
extract_all_(EGraph, Costs, Target, Extracted) :-
    empty_heap(HeapIn),
    rb_lookup(Target, H-_, Costs),
    State = state(0, [], [Target]),
    add_to_heap(HeapIn, H, State, HeapOut),
    extract_all_index(EGraph, Index),
    extract_all__(Index, Costs, HeapOut, Unifs),
    % Choices are accumulated most-recent-first; build_term/3 wants pre-order.
    reverse(Unifs, PreOrder),
    build_term(PreOrder, [], Extracted).
480
%   build_term(+Choices0, -Choices, -Term)
%
%   Consume the next Class-Node choice from a pre-order list and build its
%   term: child classes of dict and compound nodes are filled, in order,
%   by the following choices.  Variables and atomic nodes are used as is.
build_term([_-Node | Rest0], Rest, Term) :-
    (   is_dict(Node)
    ->  dict_pairs(Node, Tag, Pairs),
        pairs_keys_values(Pairs, Keys, Children),
        build_terms(Children, Rest0, Rest, Values),
        pairs_keys_values(TermPairs, Keys, Values),
        dict_pairs(Term, Tag, TermPairs)
    ;   compound(Node)
    ->  compound_name_arguments(Node, Name, Children),
        build_terms(Children, Rest0, Rest, Values),
        compound_name_arguments(Term, Name, Values)
    ;   Term = Node,
        Rest = Rest0
    ).

%   build_terms(+Children, +Choices0, -Choices, -Values)
%
%   Build one term per child, threading the remaining choices.
build_terms([], Rest, Rest, []).
build_terms([_ | Children], Rest0, Rest, [Value | Values]) :-
    build_term(Rest0, Rest1, Value),
    build_terms(Children, Rest1, Rest, Values).
503
%   extract_all__(+Index, +Costs, +Heap, -Choices)
%
%   Best-first search loop.  Pop the state with the lowest priority.  A
%   state with no pending holes is a complete term and is returned as its
%   list of Class-Node choices (most recent first); on backtracking the
%   search continues with the rest of the heap.  Otherwise the first
%   pending hole is expanded with every node of its class.  Fails once
%   the heap is empty.
extract_all__(Index, Costs, HeapIn, Unifs) :-
    (   get_from_heap(HeapIn, _, state(G, PartialTerm, PendingHoles), HeapTmp)
    ->  (   PendingHoles == []
        ->  (   Unifs = PartialTerm
            ;   extract_all__(Index, Costs, HeapTmp, Unifs)
            )
        ;   [Hole | RestHoles] = PendingHoles,
            rb_lookup(Hole, Nodes, Index),
            extract_all_childs(Nodes, Costs, G, PartialTerm, Hole, RestHoles, HeapTmp, HeapOut),
            extract_all__(Index, Costs, HeapOut, Unifs)
        )
    ).
516
%   extract_all_childs(+Nodes, +Costs, +G, +PartialTerm, +Hole, +RestHoles,
%                      +Heap0, -Heap)
%
%   Expand a search state whose first pending hole is Hole.  For each
%   candidate Node-NodeCost of Hole's class, push a state in which Hole is
%   filled by Node and Node's child classes become pending holes in front
%   of RestHoles.  The new state's accumulated cost is G + NodeCost.  Its
%   priority adds the recorded best cost of every pending hole (used as
%   an A*-style estimate).  A node with a child class that has no entry
%   in Costs can never be completed and is skipped.
extract_all_childs([], _, _, _, _, _, Heap, Heap).
extract_all_childs([Node-NodeCost | Nodes], Costs, G, PartialTerm, Hole, RestHoles, HeapIn, HeapOut) :-
    (   is_dict(Node)
    ->  dict_pairs(Node, _, KeysValues),
        pairs_values(KeysValues, ChildClasses)
    ;   compound(Node)
    ->  compound_name_arguments(Node, _, ChildClasses)
    ;   ChildClasses = []
    ),
    NewPartialTerm = [Hole-Node | PartialTerm],
    append(ChildClasses, RestHoles, NewHoles),
    NewG is G + NodeCost,
    (   foldl(sum_costs(Costs), NewHoles, 0, H)
    ->  F is NewG + H,
        % Store the accumulated cost NewG so that descendants keep it.
        add_to_heap(HeapIn, F, state(NewG, NewPartialTerm, NewHoles), HeapTmp)
    ;   HeapTmp = HeapIn
    ),
    extract_all_childs(Nodes, Costs, G, PartialTerm, Hole, RestHoles, HeapTmp, HeapOut).
534
%   sum_costs(+Costs, +Class, +Acc0, -Acc)
%
%   Acc is Acc0 plus the recorded cost of Class; fails if Class has no
%   entry in Costs.
sum_costs(Costs, Class, Acc0, Acc) :-
    rb_lookup(Class, ClassCost-_Node, Costs),
    Acc is Acc0 + ClassCost.
538
:- begin_tests(egraph_add_terms).

%   test_term(?Term, ?ExpectedNodes)
%
%   Term to add, and the node list expected after adding it to an empty
%   e-graph with options [var(node)].  Each node gets the default cost 1.
%   Lists are compared up to variable renaming (=@=) and must be in
%   standard order, as add_node//3 keeps the state an ordset.
test_term(Var, [Var-node(_, 1)]).
test_term('$NODE'(Var), [Var-node(_, 1)]).
test_term(tag{k1: v1, k2: v2}, [v1-node(I1, 1), v2-node(I2, 1), tag{k1: I1, k2: I2}-node(_, 1)]).
test_term(f(arg1, arg2), [arg1-node(I1, 1), arg2-node(I2, 1), f(I1, I2)-node(_, 1)]).
test_term(simple_atom, [simple_atom-node(_, 1)]).

% One test instance per test_term/2 fact.
test(add_term, [forall(test_term(Term, Expected)), OutNodes =@= Expected]) :-
    phrase(add_term(Term, _Id, [var(node)]), [], OutNodes).

:- end_tests(egraph_add_terms).
E-graph implementation for term rewriting and saturation
This module implements an E-graph (Equivalence Graph) data structure, commonly used for efficient term rewriting, congruence closure, and e-matching. The E-graph state is typically threaded through operations using DCG notation.
Rewrite rules are automatically compiled into efficient DCG predicates via term expansion. See the
`egraph_compile` module for full details. The supported rule declarations are:

  - `rewrite(Name, Lhs, Rhs)`
  - `rewrite(Name, Lhs, Rhs, RhsOptions)`
  - `rewrite(Name, Lhs, LhsOptions, Rhs, RhsOptions)`
  - `rewrite(Name, Lhs, LhsOptions, Rhs, RhsOptions) :- Body`
  - `analyze(Name, Lhs, RhsOptions)`
  - `analyze(Name, Lhs, LhsOptions, RhsOptions)`
  - `analyze(Name, Lhs, LhsOptions, RhsOptions) :- Body`
  - `merge_property(Name, V1, V2, Merged)`
  - `merge_property(Name, V1, V2, Merged) :- Body`
  - `rule(Name, Lhs, Rhs)`
  - `rule(Name, Lhs, Rhs) :- Body`
  - `rule(Name, Lhs, Rhs, RhsOptions)`
  - `rule(Name, Lhs, Rhs, RhsOptions) :- Body`
  - `rule(Name, Lhs, LhsOptions, Rhs, RhsOptions)`
  - `rule(Name, Lhs, LhsOptions, Rhs, RhsOptions) :- Body`

Main predicates:
*/