Skip to content

Commit ef83201

Browse files
fix the saving of TypePredictor (#1651)
1 parent 9a952fe commit ef83201

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

dspy/functional/functional.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,20 @@ def __init__(self, signature, instructions=None, *, max_retries=3, wrap_json=Fal
9898
explain_errors: If True, the model will try to explain the errors it encounters.
9999
"""
100100
super().__init__()
101-
self.signature = ensure_signature(signature, instructions)
101+
signature = ensure_signature(signature, instructions)
102102
self.predictor = dspy.Predict(signature, _parse_values=False)
103103
self.max_retries = max_retries
104104
self.wrap_json = wrap_json
105105
self.explain_errors = explain_errors
106106

107+
@property
108+
def signature(self) -> dspy.Signature:
109+
return self.predictor.signature
110+
111+
@signature.setter
112+
def signature(self, value: dspy.Signature):
113+
self.predictor.signature = value
114+
107115
def copy(self) -> "TypedPredictor":
108116
return TypedPredictor(
109117
self.signature,

tests/functional/test_functional.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,3 +822,23 @@ def check_cateogry(self):
822822

823823
pred = predictor(input_data="What is the best animal?", allowed_categories=["cat", "dog"])
824824
assert pred.category == "dog"
825+
826+
def test_save_type_predictor(tmp_path):
827+
class MySignature(dspy.Signature):
828+
"""I am a benigh signature."""
829+
question: str = dspy.InputField()
830+
answer: str = dspy.OutputField()
831+
832+
class CustomModel(dspy.Module):
833+
def __init__(self):
834+
self.predictor = dspy.TypedPredictor(MySignature)
835+
836+
save_path = tmp_path / "state.json"
837+
model = CustomModel()
838+
model.predictor.signature = MySignature.with_instructions("I am a malicious signature.")
839+
model.save(save_path)
840+
841+
loaded = CustomModel()
842+
assert loaded.predictor.signature.instructions == "I am a benigh signature."
843+
loaded.load(save_path)
844+
assert loaded.predictor.signature.instructions == "I am a malicious signature."

0 commit comments

Comments
 (0)