|
1 | | -from pathlib import Path |
2 | | - |
| 1 | +import re |
3 | 2 | import click |
4 | | -from pydantic import BaseModel, Field, model_validator, ConfigDict |
| 3 | +from pydantic import BaseModel, Field, model_validator, ConfigDict, field_validator |
5 | 4 | from typing import List, Optional, Literal, Annotated, Union |
6 | | -from prich.models.output_shaping import BaseOutputShapingModel |
7 | | -from prich.constants import RESERVED_RUN_TEMPLATE_CLI_OPTIONS |
| 5 | +from prich.models.text_filter_model import TextFilterModel |
8 | 6 | from prich.models.file_scope import FileScope |
| 7 | +from prich.constants import RESERVED_RUN_TEMPLATE_CLI_OPTIONS |
9 | 8 | from prich.core.utils import is_valid_variable_name, is_cli_option_name, get_prich_dir |
10 | 9 | from prich.version import TEMPLATE_SCHEMA_VERSION |
11 | 10 |
|
@@ -45,46 +44,70 @@ class ExtractVarModel(BaseModel): |
45 | 44 | variable: str |
46 | 45 | multiple: Optional[bool] = False # default: single match |
47 | 46 |
|
| 47 | + def extract(self, text: str) -> str | list: |
| 48 | + pattern = re.compile(self.regex) |
| 49 | + if self.multiple: |
| 50 | + matches = pattern.findall(text) |
| 51 | + if matches: |
| 52 | + # if regex has groups, findall returns tuples |
| 53 | + values = [m if isinstance(m, str) else m[0] for m in matches] |
| 54 | + return values |
| 55 | + return [] |
| 56 | + else: |
| 57 | + m = pattern.search(text) |
| 58 | + if m: |
| 59 | + return m.group(1) if m.groups() else m.group(0) |
| 60 | + return "" |
| 61 | + |
| 62 | + |
class OutputFileModel(BaseModel):
    """Structured form of a step's ``output_file`` setting."""

    # Name of the output file; None when not provided.
    name: Optional[str] = None
    # One of "write" or "append"; None when not provided.
    mode: Optional[Literal["write", "append"]] = None
48 | 66 |
|
class BaseStepModel(BaseModel):
    """Base model for step definitions (``LLMStep`` subclasses it).

    Declares the common step fields and the post-processing hooks that
    extract side variables from, and filter, a step's output. Unknown
    fields are rejected (``extra='forbid'``).
    """

    model_config = ConfigDict(extra='forbid')

    name: str

    # regex transforms
    extract_variables: Optional[list[ExtractVarModel]] = None  # enrichment

    # output transforms
    filter: Optional[TextFilterModel] = None

    # persistence
    output_variable: Optional[str] = None
    # Accepts a plain string or an OutputFileModel; a string is normalized
    # to OutputFileModel(name=...) by normalize_output_file below.
    output_file: Optional[str | OutputFileModel] = None
    output_console: Optional[bool] = None

    # execution control
    when: Optional[str] = None
    validate_: Optional[ValidateStepOutput | list[ValidateStepOutput]] = Field(alias="validate", default=None)

    # normalize output_file
    @field_validator("output_file")
    @classmethod
    def normalize_output_file(cls, v):
        """Normalize ``output_file`` to an ``OutputFileModel`` (or ``None``).

        Args:
            v: The field value: ``None``, an ``OutputFileModel``, a string
                (used as the file name) or a dict with ``name``/``mode`` keys.

        Returns:
            ``None`` when no output file is configured, otherwise an
            ``OutputFileModel``.

        Raises:
            ValueError: If ``v`` is of any other type.
        """
        if v is None:
            return None
        if isinstance(v, OutputFileModel):
            return v
        if isinstance(v, str):
            return OutputFileModel(name=v)
        # NOTE(review): with the validator's default mode pydantic has
        # presumably already coerced dict input before this runs, making the
        # dict branch and the final raise defensive only - confirm whether
        # mode="before" was intended.
        if isinstance(v, dict):
            return OutputFileModel(name=v.get("name"), mode=v.get("mode", None))
        raise ValueError(f"Invalid format for output_file field in {cls.__name__}")

    def postprocess_extract_vars(self, out: str, variables: dict) -> None:
        """Store every configured extraction result in *variables*.

        Args:
            out: Step output text that each extraction spec is applied to.
            variables: Mapping updated in place; each spec writes its result
                under ``spec.variable``.
        """
        # extract side variables
        if self.extract_variables:
            for spec in self.extract_variables:
                variables[spec.variable] = spec.extract(out)

    def postprocess_filter(self, out):
        """Return *out* passed through ``self.filter.apply`` when a filter is set.

        When no filter is configured, *out* is returned unchanged.
        """
        if self.filter:
            out = self.filter.apply(out)
        return out
88 | 111 |
|
89 | 112 |
|
90 | 113 | class LLMStep(BaseStepModel): |
|
0 commit comments