Skip to content

Commit dce853e

Browse files
authored
Merge pull request #29 from oleks-dev/refactor_output_filters
output filters refactored, added regex_replace for sanitization
2 parents 137d1c5 + 8760e57 commit dce853e

File tree

11 files changed

+300
-135
lines changed

11 files changed

+300
-135
lines changed

docs/reference/config/providers.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ Send prompts to a command via STDIN and read STDOUT (e.g., q chat, mlx_lm.genera
6767
qchat:
6868
provider_type: "stdin_consumer"
6969
mode: "flat"
70-
call: "q" # executable file
70+
call: "q" # executable file
7171
args: ["chat", "--no-interactive"] # optional arguments
72-
strip_output_prefix: "> " # optional [str] - strip prefix text from step output
73-
slice_output_start: 0 # optional [int] - slice step output text starting from N
74-
slice_output_end: -1 # optional [int] - slice step output text ending at N (negative counts from the back)
72+
filter:
73+
strip: true # strip leading and ending spaces
74+
strip_prefix: "> " # remove q char initial character and space
7575
```
7676

7777
#### MLX LM Generate
@@ -85,5 +85,6 @@ qchat:
8585
- "/Users/guest/.cache/huggingface/hub/models--mlx-community--Mistral-7B-Instruct-v0.3-4bit/snapshots/a4b8f870474b0eb527f466a03fbc187830d271f5"
8686
- "--prompt"
8787
- "-"
88-
output_regex: "^==========\\n((?:.|\\n)+)\\n\\=\\=\\=\\=\\=\\=\\=\\=\\=\\=(?:.|\\n)+$"
88+
filter:
89+
regex_extract: "^==========\\n((?:.|\\n)+)\\n\\=\\=\\=\\=\\=\\=\\=\\=\\=\\=(?:.|\\n)+$"
8990
```

docs/reference/template/steps.md

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,27 @@ steps:
2121
2222
##### Output text transformations
2323
```yaml
24-
# strip spaces from beginning and end of the output
25-
strip_output: true # optional [bool]
24+
filter:
25+
# strip spaces from beginning and end of the output
26+
strip: true # optional [bool]
2627

27-
# strip prefix characters from the output
28-
strip_output_prefix: "> " # optional [str]
28+
# strip prefix characters from the output
29+
strip_prefix: "> " # optional [str]
2930

30-
# slice output from N character
31-
slice_output_start: 3 # optional [int]
31+
# slice output from N character
32+
slice_start: 3 # optional [int]
3233

33-
# slice output till N character
34-
# (use negative number for backwards count)
35-
slice_output_end: -1 # optional [int]
34+
# slice output till N character
35+
# (use negative number for backwards count)
36+
slice_end: -1 # optional [int]
3637

37-
# filter output using regex
38-
# (take 1st group if groups are used otherwise take matching regex)
39-
output_regex: ".*" # optional [str]
38+
# filter output using regex
39+
# (take 1st group if groups are used otherwise take matching regex)
40+
regex_extract: ".*" # optional [str]
41+
42+
# replace output text using regex patterns
43+
regex_replace: # optional [list(tuple(str,str))] - regex pattern, replace
44+
- ["(?i)(\"password\"\s*:\s*\")[^\"]+(\")", "\\1*****\\2"] # (ex. for json passwords sanitization)
4045
```
4146
4247
> **Note:** Output text transformation params applied one by one in order as they mentioned in the example, each next one uses result of the previous.
@@ -48,10 +53,12 @@ steps:
4853
output_variable: "out_var" # optional [str]
4954

5055
# save output to file
51-
output_file: "out.txt" # optional [str]
52-
53-
# file mode overwrite or append
54-
output_file_mode: "write" # optional ["write"|"append"]
56+
output_file: "out.txt" # optional [str|dict] - write mode by default
57+
# OR
58+
output_file:
59+
name: "out.txt"
60+
# file mode write or append
61+
mode: "write" # optional ["write"|"append"]
5562

5663
# print output to console during normal execution
5764
# (mostly for user reference when such information is needed)

prich/core/engine.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -95,37 +95,36 @@ def run_template(template_id, **kwargs):
9595
raise click.ClickException(f"Step {step.type} type is not supported.")
9696

9797
if is_verbose():
98-
if step.extract_vars or step.output_regex or step.strip_output_prefix or step.slice_output_start or step.slice_output_end:
99-
console_print(f"[dim]Output: '{step_output}'[/dim]")
100-
step.postprocess_extract_vars(output=step_output, variables=variables)
101-
step_output = step.postprocess_output(output=step_output)
98+
if step.extract_variables or step.filter:
99+
console_print(f"[dim]Output:\n{step_output}[/dim]")
100+
step.postprocess_extract_vars(out=step_output, variables=variables)
101+
step_output = step.postprocess_filter(out=step_output)
102102
if is_verbose():
103-
if step.extract_vars:
104-
for spec in step.extract_vars:
105-
console_print(f"[dim]Inject \"{spec.regex}\" {f'({len(variables.get(spec.variable))} matches) ' if spec.multiple else ''}{spec.variable}: {f'{variables.get(spec.variable)}' if type(variables.get(spec.variable) == str) else variables.get(spec.variable)}[/dim]")
106-
if step.strip_output is not None:
107-
console_print(f"[dim]Strip output spaces: {step.strip_output}[/dim]")
108-
if step.strip_output_prefix:
109-
console_print(f"[dim]Strip output prefix: \"{step.strip_output_prefix}\"[/dim]")
110-
if step.slice_output_start or step.slice_output_end:
111-
console_print(f"[dim]Slice output text{f' from {step.slice_output_start}' if step.slice_output_start else ''}{f' to {step.slice_output_end}' if step.slice_output_end else ''}[/dim]")
112-
if step.output_regex:
113-
console_print(f"[dim]Apply regex: \"{step.output_regex}\"[/dim]")
103+
if step.extract_variables:
104+
for spec in step.extract_variables:
105+
console_print(f"""[dim]Inject \"{spec.regex}\" {f'({len(variables.get(spec.variable))} matches) ' if spec.multiple else ''}{spec.variable}: {f'"{variables.get(spec.variable)}"' if type(variables.get(spec.variable)) == str else variables.get(spec.variable)}[/dim]""")
106+
if step.filter:
107+
if step.filter.strip is not None:
108+
console_print(f"[dim]Strip output spaces: {step.filter.strip}[/dim]")
109+
if step.filter.strip_prefix:
110+
console_print(f"[dim]Strip output prefix: \"{step.filter.strip_prefix}\"[/dim]")
111+
if step.filter.slice_start or step.filter.slice_end:
112+
console_print(f"[dim]Slice output text{f' from {step.filter.slice_start}' if step.filter.slice_start else ''}{f' to {step.filter.slice_end}' if step.filter.slice_end else ''}[/dim]")
113+
if step.filter.regex_extract:
114+
console_print(f"[dim]Apply regex: \"{step.filter.regex_extract}\"[/dim]")
115+
if step.filter.regex_replace:
116+
replace_details = '\n '.join([f"\"{x}\"\"{y}\"" for x,y in step.filter.regex_replace])
117+
console_print(f"[dim]Apply regex replace: {replace_details}[/dim]")
114118

115119
# Store last output
116120
last_output = step_output
117121

118122
if output_var:
119123
variables[output_var] = step_output
120124
if step.output_file:
121-
if step.output_file.startswith('.'):
122-
save_to_file = step.output_file.replace('.', str(get_cwd_dir()), 1)
123-
elif step.output_file.startswith('~'):
124-
save_to_file = step.output_file.replace('~', str(get_home_dir()), 1)
125-
else:
126-
save_to_file = step.output_file
125+
save_to_file = step.output_file.name
127126
try:
128-
write_mode = step.output_file_mode[:1] if step.output_file_mode else 'w'
127+
write_mode = step.output_file.mode[:1] if step.output_file.mode else 'w'
129128
with open(save_to_file, write_mode) as step_output_file:
130129
if is_verbose():
131130
console_print(f"[dim]{'Save' if write_mode == 'w' else 'Append'} output to file: {save_to_file}[/dim]")

prich/core/steps/step_sent_to_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def send_to_llm(template: TemplateModel, step: LLMStep, provider: str, config: C
7575
instructions=step.rendered_instructions,
7676
input_=step.rendered_input
7777
)
78-
step_output = selected_provider.postprocess_output(response)
78+
step_output = selected_provider.postprocess_filter(response)
7979
if (is_verbose() or step.output_console) and not llm_provider.show_response and not is_quiet():
8080
console_print(step_output, markup=False)
8181
except Exception as e:

prich/models/config_providers.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1-
from pydantic import Field, ConfigDict
1+
from pydantic import Field, ConfigDict, BaseModel
22
from typing import Literal, Optional, List, Tuple
3-
from prich.models.output_shaping import BaseOutputShapingModel
3+
from prich.models.text_filter_model import TextFilterModel
44

55

6-
class BaseProviderModel(BaseOutputShapingModel):
6+
class BaseProviderModel(BaseModel):
77
model_config = ConfigDict(extra='forbid')
88
name: str | None = Field(default=None, exclude=True) # will be injected
99

1010
mode: Optional[str] = None
1111

12+
# transforms
13+
filter: Optional[TextFilterModel] = None
14+
15+
def postprocess_filter(self, out):
16+
if self.filter:
17+
out = self.filter.apply(out)
18+
return out
19+
1220
def model_post_init(self, __context):
1321
if self.name is None and __context and "__name" in __context:
1422
self.name = __context["__name"]

prich/models/output_shaping.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

prich/models/template.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from pathlib import Path
2-
1+
import re
32
import click
4-
from pydantic import BaseModel, Field, model_validator, ConfigDict
3+
from pydantic import BaseModel, Field, model_validator, ConfigDict, field_validator
54
from typing import List, Optional, Literal, Annotated, Union
6-
from prich.models.output_shaping import BaseOutputShapingModel
7-
from prich.constants import RESERVED_RUN_TEMPLATE_CLI_OPTIONS
5+
from prich.models.text_filter_model import TextFilterModel
86
from prich.models.file_scope import FileScope
7+
from prich.constants import RESERVED_RUN_TEMPLATE_CLI_OPTIONS
98
from prich.core.utils import is_valid_variable_name, is_cli_option_name, get_prich_dir
109
from prich.version import TEMPLATE_SCHEMA_VERSION
1110

@@ -45,46 +44,70 @@ class ExtractVarModel(BaseModel):
4544
variable: str
4645
multiple: Optional[bool] = False # default: single match
4746

47+
def extract(self, text: str) -> str | list:
48+
pattern = re.compile(self.regex)
49+
if self.multiple:
50+
matches = pattern.findall(text)
51+
if matches:
52+
# if regex has groups, findall returns tuples
53+
values = [m if isinstance(m, str) else m[0] for m in matches]
54+
return values
55+
return []
56+
else:
57+
m = pattern.search(text)
58+
if m:
59+
return m.group(1) if m.groups() else m.group(0)
60+
return ""
61+
62+
63+
class OutputFileModel(BaseModel):
64+
name: Optional[str | None] = None
65+
mode: Optional[Literal["write", "append"] | None] = None
4866

49-
class BaseStepModel(BaseOutputShapingModel):
67+
68+
class BaseStepModel(BaseModel):
5069
model_config = ConfigDict(extra='forbid')
5170

5271
name: str
5372

5473
# regex transforms
55-
extract_vars: Optional[list[ExtractVarModel]] = None # enrichment
74+
extract_variables: Optional[list[ExtractVarModel]] = None # enrichment
75+
76+
# output transforms
77+
filter: Optional[TextFilterModel] = None
5678

5779
# persistence
5880
output_variable: Optional[str | None] = None
59-
output_file: Optional[str | None] = None
60-
output_file_mode: Optional[Literal["write", "append"]] = None
81+
output_file: Optional[str | OutputFileModel | None] = None
6182
output_console: Optional[bool | None] = None
6283

6384
# execution control
6485
when: Optional[str | None] = None
6586
validate_: Optional[ValidateStepOutput | list[ValidateStepOutput]] = Field(alias="validate", default=None)
6687

67-
def postprocess_extract_vars(self, output: str, variables: dict):
68-
import re
69-
88+
# normalize output_file
89+
@field_validator("output_file")
90+
def normalize_output_file(cls, v):
91+
if v is None:
92+
return None
93+
if isinstance(v, OutputFileModel):
94+
return v
95+
if isinstance(v, str):
96+
return OutputFileModel(name=v)
97+
if isinstance(v, dict):
98+
return OutputFileModel(name=v.get("name"), mode=v.get("mode", None))
99+
raise ValueError("Invalid format for variable field in TransformStep")
100+
101+
def postprocess_extract_vars(self, out: str, variables: dict):
70102
# extract side variables
71-
if self.extract_vars:
72-
for spec in self.extract_vars:
73-
pattern = re.compile(spec.regex)
74-
if spec.multiple:
75-
matches = pattern.findall(output)
76-
if matches:
77-
# if regex has groups, findall returns tuples
78-
values = [m if isinstance(m, str) else m[0] for m in matches]
79-
variables[spec.variable] = values
80-
else:
81-
variables[spec.variable] = []
82-
else:
83-
m = pattern.search(output)
84-
if m:
85-
variables[spec.variable] = m.group(1) if m.groups() else m.group(0)
86-
else:
87-
variables[spec.variable] = ""
103+
if self.extract_variables:
104+
for spec in self.extract_variables:
105+
variables[spec.variable] = spec.extract(out)
106+
107+
def postprocess_filter(self, out):
108+
if self.filter:
109+
out = self.filter.apply(out)
110+
return out
88111

89112

90113
class LLMStep(BaseStepModel):

prich/models/text_filter_model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import re
2+
from typing import Optional, List, Tuple
3+
from pydantic import BaseModel
4+
5+
class TextFilterModel(BaseModel):
6+
strip: Optional[bool] = True
7+
strip_prefix: Optional[str] = None
8+
slice_start: Optional[int] = None
9+
slice_end: Optional[int] = None
10+
regex_extract: Optional[str] = None
11+
regex_replace: Optional[List[Tuple[str, str]]] = None # [(pattern, replacement), ...]
12+
13+
def apply(self, text: str) -> str:
14+
out = text
15+
16+
if self.strip:
17+
out = out.strip()
18+
19+
if self.strip_prefix and out.startswith(self.strip_prefix):
20+
out = out[len(self.strip_prefix):]
21+
22+
if self.slice_start is not None or self.slice_end is not None:
23+
out = out[self.slice_start:self.slice_end]
24+
25+
if self.regex_extract:
26+
m = re.search(self.regex_extract, out)
27+
out = m.group(1) if (m and m.groups()) else (m.group(0) if m else "")
28+
29+
if self.regex_replace:
30+
for pattern, repl in self.regex_replace:
31+
out = re.sub(pattern, repl, out)
32+
33+
return out

templates_guide.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@ Each step defines an action for the template pipeline workflow executed in order
2727
**Step can have next key parameters**:
2828
* `name`: name of the step (should be unique for each step in the template)
2929
* `output_variable`: save output of the execution into a variable for the following usage
30-
* `output_file`: save output of the execution into a file
31-
* `output_file_mode`: `write`/`append`
32-
* `strip_output_prefix`: strip prefix string from the step output
33-
* `slice_output_start`: slice step output from character number
34-
* `slice_output_end`: slice step output to character number
30+
* `output_file`: save output of the execution into a file (default would be `write` mode)
31+
* `output_file`:
32+
`name`: save output of the execution into a file
33+
`mode`: `write` or `append`
34+
* `filter`:
35+
`strip_prefix`: strip prefix string from the step output
36+
`slice_start`: slice step output from character number
37+
`slice_end`: slice step output to character number
3538
* `when`: execute step only when true - simple jinja evaluation like `working_vs_last_commit or (not working_vs_last_commit and not working_vs_remote and not committed_vs_remote and not remote_vs_local)` or `not remote_vs_local`, etc.
3639
* `validate`: validate step execution (see `step validate`)
3740
* plus additional keys based on the step type

0 commit comments

Comments
 (0)