Skip to content

Commit 96ae50e

Browse files
committed
temporal_classifier_integration_ex.py
1 parent d7c9ce5 commit 96ae50e

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import pyreason as pr
2+
import torch
3+
import torch.nn as nn
4+
import numpy as np
5+
import random
6+
7+
# Set a seed for reproducibility.
8+
seed_value = 65 # Good Gap Gap
9+
# seed_value = 47 # Good Gap Good
10+
# seed_value = 43 # Good Good Good
11+
random.seed(seed_value)
12+
np.random.seed(seed_value)
13+
torch.manual_seed(seed_value)
14+
15+
# --- Part 1: Weld Quality Model Integration ---
16+
17+
# Create a dummy PyTorch model for detecting weld quality.
18+
# Each weld is represented by 3 features and is classified as "good" or "gap".
19+
weld_model = nn.Linear(3, 2)
20+
class_names = ["good", "gap"]
21+
22+
# Define integration options:
23+
# Only consider probabilities above 0.5, adjust lower bound for high confidence, and use a snap value.
24+
interface_options = pr.ModelInterfaceOptions(
25+
threshold=0.5,
26+
set_lower_bound=True,
27+
set_upper_bound=False,
28+
snap_value=1.0
29+
)
30+
31+
# Wrap the model using LogicIntegratedClassifier.
32+
weld_quality_checker = pr.LogicIntegratedClassifier(
33+
weld_model,
34+
class_names,
35+
identifier="weld_object",
36+
interface_options=interface_options
37+
)
38+
39+
# --- Part 2: Simulate Weld Inspections Over Time ---
40+
pr.add_rule(pr.Rule("repair_attempted(weld_object) <-1 gap(weld_object)", "repair attempted rule"))
41+
pr.add_rule(pr.Rule("defective(weld_object) <-0 gap(weld_object), repair_attempted(weld_object)", "defective rule"))
42+
43+
# Time step 1: Initial inspection shows the weld is good.
44+
features_t0 = torch.rand(1, 3) # Values chosen to indicate a good weld.
45+
logits_t0, probs_t0, classifier_facts_t0 = weld_quality_checker(features_t0, t1=0, t2=0)
46+
print("=== Weld Inspection at Time 0 ===")
47+
print("Logits:", logits_t0)
48+
print("Probabilities:", probs_t0)
49+
for fact in classifier_facts_t0:
50+
pr.add_fact(fact)
51+
52+
# Time step 2: Second inspection detects a gap.
53+
features_t1 = torch.rand(1, 3) # Values chosen to simulate a gap.
54+
logits_t1, probs_t1, classifier_facts_t1 = weld_quality_checker(features_t1, t1=1, t2=1)
55+
print("\n=== Weld Inspection at Time 1 ===")
56+
print("Logits:", logits_t1)
57+
print("Probabilities:", probs_t1)
58+
for fact in classifier_facts_t1:
59+
pr.add_fact(fact)
60+
61+
62+
# Time step 3: Third inspection, the gap still persists.
63+
features_t2 = torch.rand(1, 3) # Values chosen to simulate persistent gap.
64+
logits_t2, probs_t2, classifier_facts_t2 = weld_quality_checker(features_t2, t1=2, t2=2)
65+
print("\n=== Weld Inspection at Time 2 ===")
66+
print("Logits:", logits_t2)
67+
print("Probabilities:", probs_t2)
68+
for fact in classifier_facts_t2:
69+
pr.add_fact(fact)
70+
71+
72+
# --- Part 3: Run the Reasoning Engine ---
73+
74+
# Enable atom tracing for debugging the rule application process.
75+
pr.settings.atom_trace = True
76+
interpretation = pr.reason(timesteps=2)
77+
trace = pr.get_rule_trace(interpretation)
78+
79+
print("\n=== Reasoning Rule Trace ===")
80+
print(trace[0])

0 commit comments

Comments
 (0)