Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions cleanlab_studio/studio/enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pandas as pd
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update the PR description with:

  1. User code if they just want to do some real-time data enrichment quickly.

  2. User code if they want to first run a data enrichment project over a big static dataset, and then later run some real-time data enrichment over additional data.

from tqdm import tqdm
from typing_extensions import NotRequired
from functools import lru_cache

from cleanlab_studio.errors import EnrichmentProjectError
from cleanlab_studio.internal.api import api
Expand Down Expand Up @@ -49,6 +50,13 @@ def _response_timestamp_to_datetime(timestamp_string: str) -> datetime:
return datetime.strptime(timestamp_string, response_timestamp_format_str)


@lru_cache(maxsize=None)
def _get_run_online():
    """Return the client-side ``run_online`` function from the enrichment utilities.

    The import is performed inside the function body, so it only happens on
    the first call; ``lru_cache`` then returns the same function object on
    every later call.

    NOTE(review): presumably the import is deferred to avoid a circular
    import (the enrich utilities appear to import ``EnrichmentOptions`` from
    this module) — confirm before replacing this with a top-level import.
    """
    from cleanlab_studio.utils.data_enrichment.enrich import run_online as _run_online_fn

    return _run_online_fn

Comment on lines +54 to +59
Copy link
Member

@jwmueller jwmueller Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get rid of this and do a standard import, unsure why you are using such an odd approach


class EnrichmentProject:
"""Represents an Enrichment Project instance, which is bound to a Cleanlab Studio account.

Expand Down Expand Up @@ -342,9 +350,11 @@ def list_all_jobs(self) -> List[EnrichmentJob]:
id=job["id"],
status=job["status"],
created_at=_response_timestamp_to_datetime(job["created_at"]),
updated_at=_response_timestamp_to_datetime(job["updated_at"])
if job["updated_at"]
else None,
updated_at=(
_response_timestamp_to_datetime(job["updated_at"])
if job["updated_at"]
else None
),
enrichment_options=EnrichmentOptions(**enrichment_options_dict), # type: ignore
average_trustworthiness_score=job["average_trustworthiness_score"],
job_type=job["type"],
Expand Down Expand Up @@ -399,6 +409,27 @@ def resume(self) -> JSONDict:
latest_job = self._get_latest_job()
return api.resume_enrichment_job(api_key=self._api_key, job_id=latest_job["id"])

def run_online(
    self,
    data: Union[pd.DataFrame, List[dict]],
    options: EnrichmentOptions,
    new_column_name: str,
) -> Dict[str, Any]:
    """
    Enrich data in real-time using the same logic as the run() method, but client-side.

    Args:
        data (Union[pd.DataFrame, List[dict]]): The dataset to enrich.
        options (EnrichmentOptions): Options for enriching the dataset. See the
            ``EnrichmentOptions`` docstring for the supported keys.
        new_column_name (str): The name of the new column to store the results.

    Returns:
        Dict[str, Any]: A dictionary containing information about the enrichment job and the enriched dataset.
    """
    # Imported locally rather than at module level.
    # NOTE(review): presumably a top-level import would be circular, since the
    # enrich utilities import EnrichmentOptions from this module — confirm.
    from cleanlab_studio.studio.studio import Studio

    run_online = _get_run_online()

    # The client-side helper names its last parameter `studio` and forwards it
    # to get_prompt_outputs, so it needs a Studio client, not the raw API key.
    studio = Studio(self._api_key)
    return run_online(data, options, new_column_name, studio)


class EnrichmentJob(TypedDict):
"""Represents an Enrichment Job instance.
Expand Down
209 changes: 114 additions & 95 deletions cleanlab_studio/utils/data_enrichment/enrich.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,71 @@
from typing import Any, List, Optional, Tuple, Union
import pandas as pd
from typing import Any, List, Tuple, Union, Dict
from functools import lru_cache
from cleanlab_studio.internal.enrichment_utils import (
extract_df_subset,
get_prompt_outputs,
get_regex_match_or_replacement,
get_constrain_outputs_match,
get_optimized_prompt,
Replacement,
)
from cleanlab_studio.studio.enrichment import EnrichmentOptions

from cleanlab_studio.studio.studio import Studio


def enrich_data(
studio: Studio,
data: pd.DataFrame,
prompt: str,
*,
regex: Optional[Union[str, Replacement, List[Replacement]]] = None,
constrain_outputs: Optional[List[str]] = None,
optimize_prompt: bool = True,
subset_indices: Optional[Union[Tuple[int, int], List[int]]] = (0, 3),
new_column_name: str = "metadata",
disable_warnings: bool = False,
**kwargs: Any,
) -> pd.DataFrame:

@lru_cache(maxsize=None)
def _get_pandas():
    """Return the ``pandas`` module, importing it on the first call.

    ``lru_cache`` memoises this zero-argument call, so every caller receives
    the same module object.
    """
    import pandas as _pandas_module

    return _pandas_module
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is going on here?

pandas is already a dependency of this package, there should be no special logic to lazy-import it

"pandas==2.*",



@lru_cache(maxsize=None)
def _get_tqdm():
    """Return the ``tqdm`` progress-bar callable, importing it on the first call.

    ``lru_cache`` memoises this zero-argument call, so the import statement
    runs only once.
    """
    from tqdm import tqdm as _tqdm_callable

    return _tqdm_callable

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is going on here?

tqdm is already a dependency of this package, there should be no special logic to lazy import it

"tqdm>=4.64.0",


def run_online(
data: Union["pd.DataFrame", List[dict]],
options: EnrichmentOptions,
new_column_name: str,
studio: Any,
) -> Dict[str, Any]:
"""
Generate a column of arbitrary metadata for your DataFrame, reliably at scale with Generative AI.
The metadata is separately generated for each row of your DataFrame, based on a prompt that specifies what information you need and what existing columns' data it should be derived from.
Each row of generated metadata is accompanied by a trustworthiness score, which helps you discover which metadata is most/least reliable.
You can optionally apply regular expressions to further reformat your metadata beyond raw LLM outputs, or specify that each row of the metadata must be constrained to a particular set of values.
Enrich data in real-time using the same logic as the run() method, but client-side.

Args:
studio (Studio): Cleanlab Studio client object, which you must instantiate before calling this method.
data (pd.DataFrame): A pandas DataFrame containing your data.
prompt (str): Formatted f-string, that contains both the prompt, and names of columns to embed.
**Example:** "Is this a numeric value, answer Yes or No only. Value: {column_name}"
regex (str | Replacement | List[Replacement], optional): A string, tuple, or list of tuples specifying regular expressions to apply for post-processing the raw LLM outputs.

If a string value is passed in, a regex match will be performed and the matched pattern will be returned (if the pattern cannot be matched, None will be returned).
Specifically the provided string will be passed into Python's `re.match()` method.
Pass in a tuple `(R1, R2)` instead if you wish to perform find and replace operations rather than matching/extraction.
`R1` should be a string containing the regex pattern to match, and `R2` should be a string to replace matches with.
Pass in a list of tuples instead if you wish to apply multiple replacements. Replacements will be applied in the order they appear in the list.
Note that you cannot pass in a list of strings (chaining of multiple regex processing steps is only allowed for replacement operations).

These tuples specify the desired patterns to match and replace from the raw LLM response.
This regex processing is useful in settings where you are unable to prompt the LLM to generate valid outputs 100% of the time,
but can easily transform the raw LLM outputs to be valid through regular expressions that extract and replace parts of the raw output string.
When this regex is applied, the processed results can be seen in the ``{new_column_name}`` column, and the raw outputs (before any regex processing)
will be saved in the ``{new_column_name}_log`` column of the results dataframe.

**Example 1:** ``regex = '.*The answer is: (Bird|[Rr]abbit).*'`` will extract strings that are the words 'Bird', 'Rabbit' or 'rabbit' after the characters "The answer is: " from the raw response.
**Example 2:** ``regex = [('True', 'T'), ('False', 'F')]`` will replace the words True and False with T and F.
**Example 3:** ``regex = (' Explanation:.*', '')`` will remove everything after and including the words "Explanation:".
For instance, the response "True. Explanation: 3+4=7, and 7 is an odd number." would return "True." after the regex replacement.
constrain_outputs (List[str], optional): List of all possible output values for the `metadata` column.
If specified, every entry in the `metadata` column will exactly match one of these values (for less open-ended data enrichment tasks). If None, the `metadata` column can contain arbitrary values (for more open-ended data enrichment tasks).
There may be additional transformations applied to ensure the returned value is one of these. If regex is also specified, then these transformations occur after your regex is applied.
If `optimize_prompt` is True, the prompt will be automatically adjusted to include a statement that the response must match one of the `constrain_outputs`.
The last value of this list should be considered the baseline value (e.g. “other”): it is the value that will be returned when no close match can be made between the raw LLM response and any of the classes mentioned.
optimize_prompt (bool, default = True): When False, your provided prompt will not be modified in any way. When True, your provided prompt may be automatically adjusted in an effort to produce better results.
For instance, if the constrain_outputs are constrained, we may automatically append the following statement to your prompt: "Your answer must exactly match one of the following values: `constrain_outputs`."
subset_indices (Tuple[int, int] | List[int], optional): What subset of the supplied data rows to generate metadata for. If None, we run on all of the data.
This can be either a list of unique indices or a range. These indices are passed into pandas ``.iloc`` method, so should be integers based on row order as opposed to row-index labels pointing to `df.index`.
We advise against collecting results for all of your data at first. First collect results for a smaller data subset, and use this subset to experiment with different values of the `prompt` or `regex` arguments. Only once the results look good for your subset should you run on the full dataset.
new_column_name (str): Optional name for the returned enriched column. Name acts as a prefix appended to all additional columns that are returned.
disable_warnings (bool, default = False): When True, warnings are disabled.
**kwargs: Optional keyword arguments to pass to the underlying TLM object, such as ``quality_preset`` and ``options`` to specify the TLM quality preset and TLMOptions respectively.
For more information on valid TLM arguments, view the TLM documentation here: https://help.cleanlab.ai/reference/python/studio/#method-tlm
data (Union[pd.DataFrame, List[dict]]): The dataset to enrich.
options (EnrichmentOptions): Options for enriching the dataset.
new_column_name (str): The name of the new column to store the results.
studio (Any): A required parameter for the Studio object.

Returns:
A DataFrame that contains `metadata` and `trustworthiness` columns related to the prompt in order of original data. Some columns names will have `new_column_name` prepended to them.
`metadata` column = responses to the prompt and other data mutations if `regex` or `constrain_outputs` is not specified.
`trustworthiness` column = trustworthiness of the prompt responses (which ignore the data mutations).
**Note**: If you specified the `regex` or `constrain_outputs` arguments, some additional transformations may be applied to raw LLM outputs to produce the returned values. In these cases, an additional `log` column will be added to the returned DataFrame that records the raw LLM outputs (feel free to disregard these).
Dict[str, Any]: A dictionary containing information about the enrichment job and the enriched dataset.
"""
if subset_indices:
df = extract_df_subset(data, subset_indices)
else:
df = data.copy()
pd = _get_pandas()
tqdm = _get_tqdm()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason these are not just being imported at the top of the file?


# Validate options
_validate_enrichment_options(options)

# Ensure data is a DataFrame
if isinstance(data, list):
data = pd.DataFrame(data)

df = data.copy()

# Extract options
prompt = options["prompt"]
regex = options.get("regex")
constrain_outputs = options.get("constrain_outputs")
optimize_prompt = options.get("optimize_prompt", True)
quality_preset = options.get("quality_preset", "medium")

if optimize_prompt:
prompt = get_optimized_prompt(prompt, constrain_outputs)

outputs = get_prompt_outputs(studio, prompt, df, **kwargs)
outputs = get_prompt_outputs(
studio, prompt, df, quality_preset=quality_preset, **options.get("tlm_options", {})
)
column_name_prefix = new_column_name + "_"

df[f"{column_name_prefix}trustworthiness"] = [
Expand All @@ -94,41 +74,78 @@ def enrich_data(
df[f"{new_column_name}"] = [
output["response"] if output is not None else None for output in outputs
]

if (
regex is None and constrain_outputs is None
): # we do not need to have a "log" column as original output is not augmented by regex replacements or contrained outputs
return df[[f"{new_column_name}", f"{column_name_prefix}trustworthiness"]]

df[f"{column_name_prefix}log"] = [
output["response"] if output is not None else None for output in outputs
]

if regex:
df[f"{new_column_name}"] = df[f"{new_column_name}"].apply(
lambda x: get_regex_match_or_replacement(x, regex)
)
if regex is None and constrain_outputs is None:
enriched_df = df[[f"{new_column_name}", f"{column_name_prefix}trustworthiness"]]
else:
if regex:
df[f"{new_column_name}"] = df[f"{new_column_name}"].apply(
lambda x: get_regex_match_or_replacement(x, regex)
)

if constrain_outputs:
df[f"{new_column_name}"] = df[f"{new_column_name}"].apply(
lambda x: get_constrain_outputs_match(
x, constrain_outputs, disable_warnings=disable_warnings
if constrain_outputs:
df[f"{new_column_name}"] = df[f"{new_column_name}"].apply(
lambda x: get_constrain_outputs_match(x, constrain_outputs)
)
)

return df[
[
f"{new_column_name}",
f"{column_name_prefix}trustworthiness",
f"{column_name_prefix}log",
enriched_df = df[
[
f"{new_column_name}",
f"{column_name_prefix}trustworthiness",
f"{column_name_prefix}log",
]
]
]

# Simulate the response structure of the run() method
job_info = {
"job_id": "run_online",
"status": "SUCCEEDED",
"num_rows": len(enriched_df),
"processed_rows": len(enriched_df),
"average_trustworthiness_score": enriched_df[f"{column_name_prefix}trustworthiness"].mean(),
"results": enriched_df,
}

return job_info


def _validate_enrichment_options(options: EnrichmentOptions) -> None:
    """Check the enrichment options mapping before any enrichment work starts.

    ``prompt`` is mandatory and must be a string. ``constrain_outputs``,
    ``optimize_prompt``, ``quality_preset`` and ``regex`` are optional; each is
    only type-checked when present with a non-``None`` value.

    Args:
        options: Mapping of enrichment option names to values.

    Raises:
        ValueError: If ``prompt`` is missing or ``None``.
        TypeError: If any provided option has an unexpected type.

    NOTE(review): a reviewer asked why this validator is defined separately
    rather than reusing the validation applied for ``run()``; that question
    is still open.
    """
    if options.get("prompt") is None:
        raise ValueError("'prompt' is required in the options.")

    if not isinstance(options["prompt"], str):
        raise TypeError("'prompt' must be a string.")

    # (option name, accepted type, error message), checked in this order.
    typed_optionals = (
        ("constrain_outputs", list, "'constrain_outputs' must be a list if provided."),
        ("optimize_prompt", bool, "'optimize_prompt' must be a boolean if provided."),
        ("quality_preset", str, "'quality_preset' must be a string if provided."),
    )
    for option_name, accepted_type, message in typed_optionals:
        candidate = options.get(option_name)
        if candidate is not None and not isinstance(candidate, accepted_type):
            raise TypeError(message)

    pattern_spec = options.get("regex")
    if pattern_spec is None:
        return

    if not isinstance(pattern_spec, (str, tuple, list)):
        raise TypeError("'regex' must be a string, tuple, or list of tuples.")

    # Only list entries are inspected; a bare string or tuple is accepted as-is.
    if isinstance(pattern_spec, list):
        for entry in pattern_spec:
            if not isinstance(entry, tuple):
                raise TypeError("All items in 'regex' list must be tuples.")


def process_regex(
column_data: Union[pd.Series, List[str]],
regex: Union[str, Replacement, List[Replacement]],
) -> Union[pd.Series, List[str]]:
column_data: Union["pd.Series", List[str]],
regex: Union[str, Tuple[str, str], List[Tuple[str, str]]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Replacement is a type alias for the Tuple[str, str] type (ref here), not entirely sure why you made this change?

) -> Union["pd.Series", List[str]]:
"""
Performs regex matches or replacements to the given string according to the given matching patterns and replacement strings.

Expand All @@ -153,6 +170,8 @@ def process_regex(
Returns:
Extracted matches to the provided regular expression from each element of the data column (specifically, the first match is returned).
"""
pd = _get_pandas()

if isinstance(column_data, list):
return [get_regex_match_or_replacement(x, regex) for x in column_data]
elif isinstance(column_data, pd.Series):
Expand Down
Loading