Skip to content

Commit 89770d3

Browse files
authored
feat: ProgramOfThought supports multiple output fields and more (#8004)
* feat: ProgramOfThought supports multiple return and more
* chore: prompt fixes
* chore: add missing codes
* chore: comment wording fixes
1 parent e7505a1 commit 89770d3

File tree

3 files changed

+99
-56
lines changed

3 files changed

+99
-56
lines changed

dspy/predict/program_of_thought.py

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import re
33
from typing import Union, Type
4+
import json
45

56
import dspy
67
from dspy.signatures.signature import ensure_signature, Signature
@@ -10,6 +11,7 @@
1011

1112
logger = logging.getLogger(__name__)
1213

14+
1315
class ProgramOfThought(Module):
1416
"""
1517
A DSPy module that runs Python programs to solve a problem.
@@ -39,28 +41,6 @@ def __init__(self, signature: Union[str, Type[Signature]], max_iters=3):
3941
self.input_fields = signature.input_fields
4042
self.output_fields = signature.output_fields
4143

42-
assert len(self.output_fields) == 1, "PoT only supports one output field."
43-
44-
self.output_field_name = next(iter(self.output_fields))
45-
inputs_ = ", ".join(
46-
[f"`{field_name}`" for field_name in self.input_fields.keys()],
47-
)
48-
outputs_ = f"`{self.output_field_name}`"
49-
50-
assert len(self.output_fields) == 1, "PoT only supports one output field."
51-
52-
instr = []
53-
instr.append(
54-
f"You will be given {inputs_} and you will respond with {outputs_}.",
55-
)
56-
instr.append(
57-
f"Generating executable Python code that programmatically computes the correct {outputs_}.",
58-
)
59-
instr.append(
60-
f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {outputs_}.",
61-
)
62-
instr = "\n".join(instr)
63-
6444
self.code_generate = dspy.ChainOfThought(
6545
dspy.Signature(
6646
self._generate_signature("generate").fields,
@@ -79,8 +59,7 @@ def __init__(self, signature: Union[str, Type[Signature]], max_iters=3):
7959
self._generate_instruction("answer"),
8060
),
8161
)
82-
# Currently, the interpreter class checks the deno availability at execution time.
83-
# We may consider checking it at the initialization time for better instruction.
62+
# For now, this raises an exception when dspy cannot find an available deno instance.
8463
self.interpreter = PythonInterpreter()
8564

8665
def _generate_signature(self, mode):
@@ -119,25 +98,28 @@ def _generate_signature(self, mode):
11998
prefix="Code Output:",
12099
desc="output of previously-generated python code",
121100
),
122-
self.output_field_name: self.signature.fields[self.output_field_name],
123-
},
101+
}
102+
| self.signature.output_fields,
124103
}
125104
signature_dict.update(fields_for_mode[mode])
126105
return dspy.Signature(signature_dict)
127106

128107
def _generate_instruction(self, mode):
129108
mode_inputs = ", ".join(
130-
[
131-
f"`{field_name}`"
132-
for field_name in self._generate_signature(mode).input_fields
133-
],
109+
[f"`{field_name}`" for field_name in self._generate_signature(mode).input_fields],
110+
)
111+
mode_outputs = ", ".join(
112+
[f"`{field_name}`" for field_name in self._generate_signature(mode).output_fields],
113+
)
114+
final_outputs = ", ".join(
115+
[f"`{field_name}`" for field_name in self.output_fields],
134116
)
135-
mode_outputs = f"`{self.output_field_name}`"
136117
if mode == "generate":
137118
instr = [
138119
f"You will be given {mode_inputs} and you will respond with {mode_outputs}.",
139120
f"Generating executable Python code that programmatically computes the correct {mode_outputs}.",
140-
f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}.",
121+
"After you're done with the computation and think you have the answer, make sure to provide your answer by calling the preloaded function `final_answer()`.",
122+
f'You should structure your answer in a dict object, like {{"field_a": answer_a, ...}}, evaluates to the correct value mapping for {final_outputs}.',
141123
]
142124
elif mode == "regenerate":
143125
instr = [
@@ -151,11 +133,8 @@ def _generate_instruction(self, mode):
151133

152134
return "\n".join(instr)
153135

154-
155136
def _parse_code(self, code_data):
156-
code = (
157-
code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0]
158-
)
137+
code = code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0]
159138
code_match = re.search(r"```python[ \n](.*?)[ \n]```?", code, re.DOTALL)
160139
code_block = (code_match.group(1) if code_match else code).replace("\\n", "\n")
161140
if not code_block:
@@ -168,10 +147,14 @@ def _parse_code(self, code_data):
168147
code_block += "\n" + last_line_match.group(1)
169148
else:
170149
code_block = re.sub(
171-
r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)", r"\1\n", code_block,
150+
r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)",
151+
r"\1\n",
152+
code_block,
172153
)
173154
code_block = re.sub(
174-
r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$", r"\1\n\2", code_block,
155+
r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$",
156+
r"\1\n\2",
157+
code_block,
175158
)
176159
return code_block, None
177160

@@ -181,17 +164,16 @@ def _execute_code(self, code):
181164
"""
182165
if not code:
183166
return None, "Error: Empty code before execution."
184-
167+
185168
try:
186-
output = str(self.interpreter.execute(code))
169+
# Since the result is a more complex structure now, just blindly use JSON to represent all of it.
170+
output = json.dumps(self.interpreter.execute(code))
187171
return output, None
188172
except Exception as e:
189173
return None, str(e)
190174

191175
def forward(self, **kwargs):
192-
input_kwargs = {
193-
field_name: kwargs[field_name] for field_name in self.input_fields
194-
}
176+
input_kwargs = {field_name: kwargs[field_name] for field_name in self.input_fields}
195177
code_data = self.code_generate(**input_kwargs)
196178
output = None
197179
code, error = self._parse_code(code_data)

dspy/primitives/python_interpreter.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,16 @@ def execute(
117117
if "error" in result:
118118
error_msg = result["error"]
119119
error_type = result.get("errorType", "Sandbox Error")
120-
if error_type == "SyntaxError":
120+
if error_type == "FinalAnswer":
121+
# The `FinalAnswer` trick is used to receive output from the sandbox interpreter:
122+
# simply replace the output with the error arguments.
123+
result["output"] = result.get("errorArgs", None)
124+
elif error_type == "SyntaxError":
121125
raise SyntaxError(f"Invalid Python syntax. message: {error_msg}")
122-
elif error_type == "FinalAnswer":
123-
return result.get("errorArgs")
124126
else:
125127
raise InterpreterError(f"{error_type}: {result.get('errorArgs') or error_msg}")
126128

127-
# If there's no error, return the "output" field
129+
# If there's no error, or we got `FinalAnswer`, return the "output" field
128130
return result.get("output", None)
129131

130132
def __enter__(self):
@@ -153,4 +155,4 @@ def shutdown(self) -> None:
153155
self.deno_process.stdin.flush()
154156
self.deno_process.stdin.close()
155157
self.deno_process.wait()
156-
self.deno_process = None
158+
self.deno_process = None

tests/predict/test_program_of_thought.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,33 @@
99
# This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/
1010
is_deno_available = shutil.which("deno") is not None
1111

12+
1213
class BasicQA(Signature):
1314
question = dspy.InputField()
1415
answer = dspy.OutputField(desc="often between 1 and 5 words")
1516

1617

1718
@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
1819
def test_pot_code_generation():
20+
lm = DummyLM(
21+
[
22+
{
23+
"reasoning": "Reason_A",
24+
"generated_code": "```python\nresult = 1+1\nfinal_answer({'answer': result})\n```",
25+
},
26+
{"reasoning": "Reason_B", "answer": "2"},
27+
]
28+
)
29+
dspy.settings.configure(lm=lm)
30+
pot = ProgramOfThought(BasicQA)
31+
res = pot(question="What is 1+1?")
32+
assert res.answer == "2"
33+
assert pot.interpreter.deno_process is None
34+
35+
36+
# This test ensures the old finetuned saved models still work
37+
@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
38+
def test_old_style_pot():
1939
lm = DummyLM(
2040
[
2141
{"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+1\n```"},
@@ -29,17 +49,47 @@ def test_pot_code_generation():
2949
assert pot.interpreter.deno_process is None
3050

3151

52+
class ExtremumFinder(Signature):
53+
input_list = dspy.InputField()
54+
maximum = dspy.OutputField(desc="The maximum of the given numbers")
55+
minimum = dspy.OutputField(desc="The minimum of the given numbers")
56+
57+
58+
@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
59+
def test_pot_support_multiple_fields():
60+
lm = DummyLM(
61+
[
62+
{
63+
"reasoning": "Reason_A",
64+
"generated_code": "```python\nmaximum = 6\nminimum = 2\nfinal_answer({'maximum': maximum, 'minimum': minimum})\n```",
65+
},
66+
{"reasoning": "Reason_B", "maximum": "6", "minimum": "2"},
67+
]
68+
)
69+
dspy.settings.configure(lm=lm)
70+
pot = ProgramOfThought(ExtremumFinder)
71+
res = pot(input_list="2, 3, 5, 6")
72+
assert res.maximum == "6"
73+
assert res.minimum == "2"
74+
assert pot.interpreter.deno_process is None
75+
76+
3277
@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
3378
def test_pot_code_generation_with_one_error():
3479
lm = DummyLM(
3580
[
36-
{"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"},
37-
{"reasoning": "Reason_B", "generated_code": "```python\nresult = 1+1\n```"},
81+
{
82+
"reasoning": "Reason_A",
83+
"generated_code": "```python\nresult = 1+0/0\nfinal_answer({'answer': result})\n```",
84+
},
85+
{
86+
"reasoning": "Reason_B",
87+
"generated_code": "```python\nresult = 1+1\nfinal_answer({'answer': result})\n```",
88+
},
3889
{"reasoning": "Reason_C", "answer": "2"},
3990
]
4091
)
4192
dspy.settings.configure(lm=lm)
42-
4393
pot = ProgramOfThought(BasicQA)
4494
res = pot(question="What is 1+1?")
4595
assert res.answer == "2"
@@ -51,8 +101,12 @@ def test_pot_code_generation_persistent_errors():
51101
max_iters = 3
52102
lm = DummyLM(
53103
[
54-
{"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"},
55-
] * max_iters
104+
{
105+
"reasoning": "Reason_A",
106+
"generated_code": "```python\nresult = 1+0/0\nfinal_answer({'answer': result})\n```",
107+
},
108+
]
109+
* max_iters
56110
)
57111
dspy.settings.configure(lm=lm)
58112

@@ -67,11 +121,16 @@ def test_pot_code_parse_error():
67121
lm = DummyLM(
68122
[
69123
{"reasoning": "Reason_A", "generated_code": "```python\ninvalid=python=code\n```"},
70-
] * max_iters
124+
]
125+
* max_iters
71126
)
72127
dspy.settings.configure(lm=lm)
73-
74128
pot = ProgramOfThought(BasicQA, max_iters=max_iters)
75-
with patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code, pytest.raises(RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct."):
129+
with (
130+
patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code,
131+
pytest.raises(
132+
RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct."
133+
),
134+
):
76135
pot(question="What is 1+1?")
77136
mock_execute_code.assert_not_called()

0 commit comments

Comments
 (0)