Skip to content

Commit 12aa51a

Browse files
committed
added setting to reason again with restart of time
1 parent 8eca8b6 commit 12aa51a

File tree

6 files changed

+58
-26
lines changed

6 files changed

+58
-26
lines changed

docs/source/user_guide/8_advanced_usage.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,14 @@ PyReason allows you to reason over the graph multiple times. This can be useful
2020
and add facts that were not available before. To reason over the graph multiple times, you can set ``again=True`` in ``pr.reason(again=True)``.
2121
To specify additional facts or rules, you can add them as you normally would using ``pr.add_fact`` and ``pr.add_rule``.
2222

23+
You can also clear the rules to use completely different ones with ``pr.clear_rules()``. This can be useful when you
24+
want to reason over the graph with a new set of rules.
25+
26+
When reasoning multiple times, the time is reset to zero. Therefore any facts that are added should take this into account.
27+
It is also possible to continue incrementing the time by running ``pr.reason(again=True, restart=False)``
28+
2329
.. note::
24-
When reasoning multiple times, the time continues to increment. Therefore any facts that are added should take this into account.
30+
When reasoning multiple times with ``restart=False``, the time continues to increment. Therefore any facts that are added should take this into account.
2531
The timestep parameter specifies how many additional timesteps to reason. For example, if the initial reasoning converges at
2632
timestep 5, and you want to reason for 3 more timesteps, you can set ``timestep=3`` in ``pr.reason(timestep=3, again=True)``.
2733
If you are specifying new facts, take this into account when setting their ``start_time`` and ``end_time``.

pyreason/pyreason.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -459,9 +459,21 @@ def reset():
459459
"""Resets certain variables to None to be able to do pr.reason() multiple times in a program
460460
without memory blowing up
461461
"""
462-
global __node_facts, __edge_facts
462+
global __node_facts, __edge_facts, __graph
463+
464+
# Facts
463465
__node_facts = None
464466
__edge_facts = None
467+
if __program is not None:
468+
__program.reset_facts()
469+
470+
# Graph
471+
__graph = None
472+
if __program is not None:
473+
__program.reset_graph()
474+
475+
# Rules
476+
reset_rules()
465477

466478

467479
def get_rules():
@@ -478,14 +490,8 @@ def reset_rules():
478490
"""
479491
global __rules
480492
__rules = None
481-
482-
483-
def reset_graph():
484-
"""
485-
Resets graph to none
486-
"""
487-
global __graph
488-
__graph = None
493+
if __program is not None:
494+
__program.reset_rules()
489495

490496

491497
def reset_settings():
@@ -633,15 +639,15 @@ def add_annotation_function(function: Callable) -> None:
633639
__annotation_functions.append(function)
634640

635641

636-
def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False):
642+
def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False, restart: bool = True):
637643
"""Function to start the main reasoning process. Graph and rules must already be loaded.
638644
639645
:param timesteps: Max number of timesteps to run. -1 specifies run till convergence. If reasoning again, this is the number of timesteps to reason for extra (no zero timestep), defaults to -1
640646
:param convergence_threshold: Maximum number of interpretations that have changed between timesteps or fixed point operations until considered convergent. Program will end at convergence. -1 => no changes, perfect convergence, defaults to -1
641647
:param convergence_bound_threshold: Maximum change in any interpretation (bounds) between timesteps or fixed point operations until considered convergent, defaults to -1
642648
:param queries: A list of PyReason query objects that can be used to filter the ruleset based on the query. Default is None
643649
:param again: Whether to reason again on an existing interpretation, defaults to False
644-
:param facts: New facts to use during the next reasoning process when reasoning again. Other facts from file will be discarded, defaults to None
650+
:param restart: Whether to restart the program time from 0 when reasoning again, defaults to True
645651
:return: The final interpretation after reasoning.
646652
"""
647653
global settings, __timestamp
@@ -662,10 +668,10 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
662668
else:
663669
if settings.memory_profile:
664670
start_mem = mp.memory_usage(max_usage=True)
665-
mem_usage, interp = mp.memory_usage((_reason_again, [timesteps, convergence_threshold, convergence_bound_threshold]), max_usage=True, retval=True)
671+
mem_usage, interp = mp.memory_usage((_reason_again, [timesteps, restart, convergence_threshold, convergence_bound_threshold]), max_usage=True, retval=True)
666672
print(f"\nProgram used {mem_usage-start_mem} MB of memory")
667673
else:
668-
interp = _reason_again(timesteps, convergence_threshold, convergence_bound_threshold)
674+
interp = _reason_again(timesteps, restart, convergence_threshold, convergence_bound_threshold)
669675

670676
return interp
671677

@@ -758,7 +764,7 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
758764
return interpretation
759765

760766

761-
def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold):
767+
def _reason_again(timesteps, restart, convergence_threshold, convergence_bound_threshold):
762768
# Globals
763769
global __graph, __rules, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
764770
global settings, __timestamp, __program
@@ -772,7 +778,7 @@ def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold)
772778
all_edge_facts.extend(numba.typed.List(__edge_facts))
773779

774780
# Run Program and get final interpretation
775-
interpretation = __program.reason_again(timesteps, convergence_threshold, convergence_bound_threshold, all_node_facts, all_edge_facts, settings.verbose)
781+
interpretation = __program.reason_again(timesteps, restart, convergence_threshold, convergence_bound_threshold, all_node_facts, all_edge_facts, settings.verbose)
776782

777783
return interpretation
778784

pyreason/scripts/interpretation/interpretation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,11 @@ def _init_convergence(convergence_bound_threshold, convergence_threshold):
180180
convergence_delta = convergence_bound_threshold
181181
return convergence_mode, convergence_delta
182182

183-
def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_threshold, convergence_bound_threshold, again=False):
183+
def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_threshold, convergence_bound_threshold, again=False, restart=True):
184184
self.tmax = tmax
185185
self._convergence_mode, self._convergence_delta = self._init_convergence(convergence_bound_threshold, convergence_threshold)
186186
max_facts_time = self._init_facts(facts_node, facts_edge, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.atom_trace)
187-
self._start_fp(rules, max_facts_time, verbose, again)
187+
self._start_fp(rules, max_facts_time, verbose, again, restart)
188188

189189
@staticmethod
190190
@numba.njit(cache=True)
@@ -208,9 +208,12 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap
208208
facts_to_be_applied_edge_trace.append(fact.get_name())
209209
return max_time
210210

211-
def _start_fp(self, rules, max_facts_time, verbose, again):
211+
def _start_fp(self, rules, max_facts_time, verbose, again, restart):
212212
if again:
213213
self.num_ga.append(self.num_ga[-1])
214+
if restart:
215+
self.time = 0
216+
self.prev_reasoning_data[0] = 0
214217
fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.persistent, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, self.num_ga, verbose, again)
215218
self.time = t - 1
216219
# If we need to reason again, store the next timestep to start from

pyreason/scripts/interpretation/interpretation_parallel.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,11 @@ def _init_convergence(convergence_bound_threshold, convergence_threshold):
180180
convergence_delta = convergence_bound_threshold
181181
return convergence_mode, convergence_delta
182182

183-
def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_threshold, convergence_bound_threshold, again=False):
183+
def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_threshold, convergence_bound_threshold, again=False, restart=True):
184184
self.tmax = tmax
185185
self._convergence_mode, self._convergence_delta = self._init_convergence(convergence_bound_threshold, convergence_threshold)
186186
max_facts_time = self._init_facts(facts_node, facts_edge, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.atom_trace)
187-
self._start_fp(rules, max_facts_time, verbose, again)
187+
self._start_fp(rules, max_facts_time, verbose, again, restart)
188188

189189
@staticmethod
190190
@numba.njit(cache=True)
@@ -208,9 +208,12 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap
208208
facts_to_be_applied_edge_trace.append(fact.get_name())
209209
return max_time
210210

211-
def _start_fp(self, rules, max_facts_time, verbose, again):
211+
def _start_fp(self, rules, max_facts_time, verbose, again, restart):
212212
if again:
213213
self.num_ga.append(self.num_ga[-1])
214+
if restart:
215+
self.time = 0
216+
self.prev_reasoning_data[0] = 0
214217
fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.persistent, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, self.num_ga, verbose, again)
215218
self.time = t - 1
216219
# If we need to reason again, store the next timestep to start from

pyreason/scripts/program/program.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,23 @@ def reason(self, tmax, convergence_threshold, convergence_bound_threshold, verbo
3939

4040
return self.interp
4141

42-
def reason_again(self, tmax, convergence_threshold, convergence_bound_threshold, facts_node, facts_edge, verbose=True):
42+
def reason_again(self, tmax, restart, convergence_threshold, convergence_bound_threshold, facts_node, facts_edge, verbose=True):
4343
assert self.interp is not None, 'Call reason before calling reason again'
44-
self._tmax = self.interp.time + tmax
45-
self.interp.start_fp(self._tmax, facts_node, facts_edge, self._rules, verbose, convergence_threshold, convergence_bound_threshold, again=True)
44+
if restart:
45+
self._tmax = tmax
46+
else:
47+
self._tmax = self.interp.time + tmax
48+
self.interp.start_fp(self._tmax, facts_node, facts_edge, self._rules, verbose, convergence_threshold, convergence_bound_threshold, again=True, restart=restart)
4649

4750
return self.interp
51+
52+
def reset_graph(self):
53+
self._graph = None
54+
self.interp = None
55+
56+
def reset_rules(self):
57+
self._rules = None
58+
59+
def reset_facts(self):
60+
self._facts_node = None
61+
self._facts_edge = None

tests/test_reason_again.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_reason_again():
2929
# Now reason again
3030
new_fact = pr.Fact('popular(Mary)', 'popular_fact2', 2, 4)
3131
pr.add_fact(new_fact)
32-
interpretation = pr.reason(timesteps=3, again=True)
32+
interpretation = pr.reason(timesteps=3, again=True, restart=False)
3333

3434
# Display the changes in the interpretation for each timestep
3535
dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])

0 commit comments

Comments
 (0)