Commit 123cec3

comaniac authored and Dhakshin Suriakannu committed
[ray.data.llm] Propose log_input_column_names() (ray-project#51441)
## Why are these changes needed?

It's tricky for users to implement the `preprocess` function when constructing a Processor, because they may not know what the input dataset should look like (i.e., the expected schema). This PR proposes a new API, `log_input_column_names()`, that logs the expected schema. Example:

```python
import ray
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

processor_config = vLLMEngineProcessorConfig(...)
processor = build_llm_processor(...)
processor.log_input_column_names()
# The first stage of the processor is ChatTemplateStage.
# Required input columns:
#     messages: A list of messages in OpenAI chat format. See https://platform.openai.com/docs/api-reference/chat/create for details.

processor_config = vLLMEngineProcessorConfig(
    apply_chat_template=False,
    tokenize=False,
)
processor = build_llm_processor(...)
processor.log_input_column_names()
# The first stage of the processor is vLLMEngineStage.
# Required input columns:
#     prompt: The text prompt (str).
#     sampling_params: The sampling parameters. See https://docs.vllm.ai/en/latest/api/inference_params.html#sampling-parameters for details.
# Optional input columns:
#     tokenized_prompt: The tokenized prompt. If provided, the prompt will not be tokenized by the vLLM engine.
#     images: The images to generate text from. If provided, the prompt will be a multimodal prompt.
#     model: The model to use for this request. If the model is different from the model set in the stage, then this is a LoRA request.
```

## Related issue number

<!-- For example: "Closes ray-project#1234" -->

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [x] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Dhakshin Suriakannu <d_suriakannu@apple.com>
1 parent add057f commit 123cec3

16 files changed: +185 -120 lines changed

doc/source/data/working-with-llms.rst

Lines changed: 31 additions & 18 deletions
@@ -77,32 +77,26 @@ Upon execution, the Processor object instantiates replicas of the vLLM engine (u

     {'answer': 'Snowflakes gently fall\nBlanketing the winter scene\nFrozen peaceful hush'}

-Some models may require a Hugging Face token to be specified. You can specify the token in the `runtime_env` argument.
+Each processor requires specific input columns. You can get more info by using the following API:

 .. testcode::

-    config = vLLMEngineProcessorConfig(
-        model_source="unsloth/Llama-3.1-8B-Instruct",
-        runtime_env={"env_vars": {"HF_TOKEN": "your_huggingface_token"}},
-        concurrency=1,
-        batch_size=64,
-    )
+    processor.log_input_column_names()

-If your model is hosted on AWS S3, you can specify the S3 path in the `model_source` argument, and specify `load_format="runai_streamer"` in the `engine_kwargs` argument.
+.. testoutput::
+    :options: +MOCK

-.. note::
-    Install vLLM with runai dependencies: `pip install -U "vllm[runai]==0.7.2"`
+    The first stage of the processor is ChatTemplateStage.
+    Required input columns:
+        messages: A list of messages in OpenAI chat format. See https://platform.openai.com/docs/api-reference/chat/create for details.
+
+Some models may require a Hugging Face token to be specified. You can specify the token in the `runtime_env` argument.

 .. testcode::

     config = vLLMEngineProcessorConfig(
-        model_source="s3://your-bucket/your-model/",  # Make sure to add the trailing slash!
-        engine_kwargs={"load_format": "runai_streamer"},
-        runtime_env={"env_vars": {
-            "AWS_ACCESS_KEY_ID": "your_access_key_id",
-            "AWS_SECRET_ACCESS_KEY": "your_secret_access_key",
-            "AWS_REGION": "your_region",
-        }},
+        model_source="unsloth/Llama-3.1-8B-Instruct",
+        runtime_env={"env_vars": {"HF_TOKEN": "your_huggingface_token"}},
         concurrency=1,
         batch_size=64,
     )
@@ -146,7 +140,10 @@ The underlying `Processor` object instantiates replicas of the vLLM engine and a
 configure parallel workers to handle model parallelism (for tensor parallelism and pipeline parallelism,
 if specified).

-To optimize model loading, you can configure the `load_format` to `runai_streamer` or `tensorizer`:
+To optimize model loading, you can configure the `load_format` to `runai_streamer` or `tensorizer`.
+
+.. note::
+    In this case, install vLLM with runai dependencies: `pip install -U "vllm[runai]==0.7.2"`

 .. testcode::

@@ -157,6 +154,22 @@ To optimize model loading, you can configure the `load_format` to `runai_streame
         batch_size=64,
     )

+If your model is hosted on AWS S3, you can specify the S3 path in the `model_source` argument, and specify `load_format="runai_streamer"` in the `engine_kwargs` argument.
+
+.. testcode::
+
+    config = vLLMEngineProcessorConfig(
+        model_source="s3://your-bucket/your-model/",  # Make sure to add the trailing slash!
+        engine_kwargs={"load_format": "runai_streamer"},
+        runtime_env={"env_vars": {
+            "AWS_ACCESS_KEY_ID": "your_access_key_id",
+            "AWS_SECRET_ACCESS_KEY": "your_secret_access_key",
+            "AWS_REGION": "your_region",
+        }},
+        concurrency=1,
+        batch_size=64,
+    )
+
 To do multi-LoRA batch inference, you need to set LoRA related parameters in `engine_kwargs`. See :doc:`the vLLM with LoRA example</llm/examples/batch/vllm-with-lora>` for details.

 .. testcode::

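The docs change above states that the first stage, `ChatTemplateStage`, requires a `messages` column in OpenAI chat format, and the commit message describes `preprocess` as the way to bridge a dataset that lacks it. A minimal sketch of such a function follows. The input column name `question`, the system prompt, and the sampling values are hypothetical, chosen only for illustration.

```python
from typing import Any, Dict


def preprocess(row: Dict[str, Any]) -> Dict[str, Any]:
    """Map a raw row to the columns the first stage expects.

    Assumption: the raw dataset has a "question" column (hypothetical).
    """
    return {
        # Required by ChatTemplateStage: a list of OpenAI-style chat messages.
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": row["question"]},
        ],
        # Illustrative sampling parameters; tune these for your model.
        "sampling_params": {"temperature": 0.3, "max_tokens": 64},
    }


row = {"question": "Write a haiku about winter."}
out = preprocess(row)
print(sorted(out.keys()))
```

A function of this shape would typically be passed as the `preprocess` argument of `build_llm_processor`; calling `processor.log_input_column_names()` first is a quick way to confirm which columns it needs to produce.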
python/ray/data/llm.py

Lines changed: 9 additions & 0 deletions
@@ -147,6 +147,15 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
                 ),
             )

+            # The processor requires specific input columns, which depend on
+            # your processor config. You can use the following API to check
+            # the required input columns:
+            processor.log_input_column_names()
+            # Example log:
+            # The first stage of the processor is ChatTemplateStage.
+            # Required input columns:
+            #     messages: A list of messages in OpenAI chat format.
+
             ds = ray.data.range(300)
             ds = processor(ds)
             for row in ds.take_all():

python/ray/llm/_internal/batch/processor/base.py

Lines changed: 23 additions & 0 deletions
@@ -1,3 +1,4 @@
+import logging
 from collections import OrderedDict
 from typing import Optional, List, Type, Callable, Dict

@@ -15,6 +16,9 @@
 from ray.llm._internal.common.base_pydantic import BaseModelExtended


+logger = logging.getLogger(__name__)
+
+
 class ProcessorConfig(BaseModelExtended):
     """The processor configuration."""

@@ -158,6 +162,25 @@ def get_stage_by_name(self, name: str) -> StatefulStage:
             return self.stages[name]
         raise ValueError(f"Stage {name} not found")

+    def log_input_column_names(self):
+        """Log.info the input stage and column names of this processor.
+        If the input dataset does not contain these columns, you have to
+        provide a preprocess function to bridge the gap.
+        """
+        name, stage = list(self.stages.items())[0]
+        expected_input_keys = stage.get_required_input_keys()
+        optional_input_keys = stage.get_optional_input_keys()
+
+        message = f"The first stage of the processor is {name}."
+        if expected_input_keys:
+            message += "\nRequired input columns:\n"
+            message += "\n".join(f"\t{k}: {v}" for k, v in expected_input_keys.items())
+        if optional_input_keys:
+            message += "\nOptional input columns:\n"
+            message += "\n".join(f"\t{k}: {v}" for k, v in optional_input_keys.items())
+
+        logger.info(message)
+

 @DeveloperAPI
 class ProcessorBuilder:

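The message assembly in `log_input_column_names()` above can be exercised without a processor. The sketch below copies the same string-building logic into a standalone helper; the name `format_input_columns` is hypothetical and not part of Ray. The logger is configured at INFO because plain Python logging suppresses INFO records by default.

```python
import logging
from typing import Dict

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def format_input_columns(
    stage_name: str, required: Dict[str, str], optional: Dict[str, str]
) -> str:
    """Build the same message layout as the method in the diff above."""
    message = f"The first stage of the processor is {stage_name}."
    if required:
        message += "\nRequired input columns:\n"
        message += "\n".join(f"\t{k}: {v}" for k, v in required.items())
    if optional:
        message += "\nOptional input columns:\n"
        message += "\n".join(f"\t{k}: {v}" for k, v in optional.items())
    return message


message = format_input_columns(
    "ChatTemplateStage",
    {"messages": "A list of messages in OpenAI chat format."},
    {},
)
logger.info(message)
```

Empty dictionaries skip their section entirely, so a stage that declares nothing logs only the first line.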
python/ray/llm/_internal/batch/stages/base.py

Lines changed: 19 additions & 16 deletions
@@ -1,6 +1,6 @@
 """The base class for all stages."""
 import logging
-from typing import Any, Dict, AsyncIterator, List, Callable, Type
+from typing import Any, Dict, AsyncIterator, List, Callable, Type, Optional

 import pyarrow
 from pydantic import BaseModel, Field
@@ -71,14 +71,18 @@ class StatefulStageUDF:
             __call__ method will take the data column as the input of the udf
             method, and encapsulate the output of the udf method into the data
             column for the next stage.
+        expected_input_keys: The expected input keys of the stage.
     """

     # The internal column name for the index of the row in the batch.
     # This is used to align the output of the UDF with the input batch.
     IDX_IN_BATCH_COLUMN: str = "__idx_in_batch"

-    def __init__(self, data_column: str):
+    def __init__(
+        self, data_column: str, expected_input_keys: Optional[List[str]] = None
+    ):
         self.data_column = data_column
+        self.expected_input_keys = set(expected_input_keys or [])

     async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]:
         """A stage UDF wrapper that processes the input and output columns
@@ -195,8 +199,6 @@ def validate_inputs(self, inputs: List[Dict[str, Any]]):
         Raises:
             ValueError: If the required keys are not found.
         """
-        expected_input_keys = set(self.expected_input_keys)
-
         for inp in inputs:
             input_keys = set(inp.keys())

@@ -206,26 +208,16 @@ def validate_inputs(self, inputs: List[Dict[str, Any]]):
                     "for internal use."
                 )

-            if not expected_input_keys:
+            if not self.expected_input_keys:
                 continue

-            missing_required = expected_input_keys - input_keys
+            missing_required = self.expected_input_keys - input_keys
             if missing_required:
                 raise ValueError(
                     f"Required input keys {missing_required} not found at the input of "
                     f"{self.__class__.__name__}. Input keys: {input_keys}"
                 )

-    @property
-    def expected_input_keys(self) -> List[str]:
-        """A list of required input keys. Missing required keys will raise
-        an exception.
-
-        Returns:
-            A list of required input keys.
-        """
-        return []
-
     async def udf(self, rows: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
         raise NotImplementedError("StageUDF must implement the udf method")

@@ -247,6 +239,14 @@ class StatefulStage(BaseModel):
         description="The arguments of .map_batches(). Default {'concurrency': 1}.",
     )

+    def get_required_input_keys(self) -> Dict[str, str]:
+        """The required input keys of the stage and their descriptions."""
+        return {}
+
+    def get_optional_input_keys(self) -> Dict[str, str]:
+        """The optional input keys of the stage and their descriptions."""
+        return {}
+
     def get_dataset_map_batches_kwargs(
         self,
         batch_size: int,
@@ -280,6 +280,9 @@ def get_dataset_map_batches_kwargs(
             )

         kwargs["fn_constructor_kwargs"]["data_column"] = data_column
+        kwargs["fn_constructor_kwargs"]["expected_input_keys"] = list(
+            self.get_required_input_keys().keys()
+        )
         return kwargs

     class Config:

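After this change, the UDF stores `expected_input_keys` as a set in its constructor and `validate_inputs` checks each row with a set difference. The sketch below reproduces only that part in a standalone class. `SketchUDF` is a hypothetical stand-in, not the real `StatefulStageUDF`, and it omits the reserved-column check.

```python
from typing import Any, Dict, List, Optional


class SketchUDF:
    """Hypothetical stand-in that mirrors the required-key validation."""

    def __init__(
        self, data_column: str, expected_input_keys: Optional[List[str]] = None
    ):
        self.data_column = data_column
        # Stored as a set so missing keys are a simple set difference.
        self.expected_input_keys = set(expected_input_keys or [])

    def validate_inputs(self, inputs: List[Dict[str, Any]]) -> None:
        for inp in inputs:
            input_keys = set(inp.keys())
            if not self.expected_input_keys:
                continue  # Nothing declared, so nothing to enforce.
            missing_required = self.expected_input_keys - input_keys
            if missing_required:
                raise ValueError(
                    f"Required input keys {missing_required} not found at the input of "
                    f"{self.__class__.__name__}. Input keys: {input_keys}"
                )


udf = SketchUDF("__data", ["messages"])
udf.validate_inputs([{"messages": [], "extra": 1}])  # Extra keys are allowed.

error_message = None
try:
    udf.validate_inputs([{"prompt": "hi"}])  # "messages" is missing.
except ValueError as exc:
    error_message = str(exc)
print(error_message)
```

Only required keys are enforced; optional keys exist for documentation through `get_optional_input_keys()` and are never checked here.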
python/ray/llm/_internal/batch/stages/chat_template_stage.py

Lines changed: 11 additions & 6 deletions
@@ -13,6 +13,7 @@ class ChatTemplateUDF(StatefulStageUDF):
     def __init__(
         self,
         data_column: str,
+        expected_input_keys: List[str],
         model: str,
         chat_template: Optional[str] = None,
     ):
@@ -21,14 +22,15 @@ def __init__(

         Args:
             data_column: The data column name.
+            expected_input_keys: The expected input keys of the stage.
             model: The model to use for the chat template.
             chat_template: The chat template in Jinja template format. This is
                 usually not needed if the model checkpoint already contains the
                 chat template.
         """
         from transformers import AutoProcessor

-        super().__init__(data_column)
+        super().__init__(data_column, expected_input_keys)

         # NOTE: We always use processor instead of tokenizer in this stage,
         # because tokenizers of VLM models may not have chat template attribute.
@@ -95,15 +97,18 @@ def _should_add_generation_prompt(self, conversation: List[Dict[str, Any]]) -> b
         """
         return conversation[-1]["role"] == "user"

-    @property
-    def expected_input_keys(self) -> List[str]:
-        """The expected input keys."""
-        return ["messages"]
-

 class ChatTemplateStage(StatefulStage):
     """
     A stage that applies chat template.
     """

     fn: Type[StatefulStageUDF] = ChatTemplateUDF
+
+    def get_required_input_keys(self) -> Dict[str, str]:
+        """The required input keys of the stage and their descriptions."""
+        return {
+            "messages": "A list of messages in OpenAI chat format. "
+            "See https://platform.openai.com/docs/api-reference/chat/create "
+            "for details."
+        }

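The pattern in this file moves the key declaration from the UDF (a property returning a list) to the stage (a method returning a name-to-description mapping), and the stage then forwards only the names to the UDF constructor. A minimal sketch of that wiring with plain classes follows. `SketchStage`, `SketchChatTemplateStage`, and `fn_constructor_kwargs` are hypothetical names, and the real `StatefulStage` is a pydantic model with more fields.

```python
from typing import Any, Dict


class SketchStage:
    """Hypothetical base: stages declare no input keys by default."""

    def get_required_input_keys(self) -> Dict[str, str]:
        return {}

    def get_optional_input_keys(self) -> Dict[str, str]:
        return {}

    def fn_constructor_kwargs(self, data_column: str) -> Dict[str, Any]:
        # Descriptions stay on the stage for logging; the UDF only needs names.
        return {
            "data_column": data_column,
            "expected_input_keys": list(self.get_required_input_keys().keys()),
        }


class SketchChatTemplateStage(SketchStage):
    def get_required_input_keys(self) -> Dict[str, str]:
        return {"messages": "A list of messages in OpenAI chat format."}


kwargs = SketchChatTemplateStage().fn_constructor_kwargs("__data")
print(kwargs)
```

Keeping descriptions on the stage means the same dictionary can feed both the human-readable log and the runtime validation.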
python/ray/llm/_internal/batch/stages/http_request_stage.py

Lines changed: 10 additions & 5 deletions
@@ -13,6 +13,7 @@ class HttpRequestUDF(StatefulStageUDF):
     def __init__(
         self,
         data_column: str,
+        expected_input_keys: List[str],
         url: str,
         additional_header: Optional[Dict[str, Any]] = None,
         qps: Optional[int] = None,
@@ -22,11 +23,12 @@ def __init__(

         Args:
             data_column: The data column name.
+            expected_input_keys: The expected input keys of the stage.
             url: The URL to send the HTTP request to.
             additional_header: The additional headers to send with the HTTP request.
             qps: The maximum number of requests per second.
         """
-        super().__init__(data_column)
+        super().__init__(data_column, expected_input_keys)
         self.url = url
         self.additional_header = additional_header or {}
         self.qps = qps
@@ -90,14 +92,17 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
                         "http_response": resp_json,
                     }

-    @property
-    def expected_input_keys(self) -> List[str]:
-        return ["payload"]
-

 class HttpRequestStage(StatefulStage):
     """
     A stage that sends HTTP requests.
     """

     fn: Type[StatefulStageUDF] = HttpRequestUDF
+
+    def get_required_input_keys(self) -> Dict[str, str]:
+        """The required input keys of the stage and their descriptions."""
+        return {
+            "payload": "The payload to send to the HTTP request. "
+            "It should be in JSON format."
+        }

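`HttpRequestStage` declares a single required column, `payload`, described as JSON. A rough sketch of a preprocess-style helper that produces such a row follows. The payload fields (`model`, `messages`), the model name, and the `question` input column are assumptions for illustration, since the expected body depends on the endpoint being called.

```python
import json
from typing import Any, Dict


def to_payload_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a raw row into the `payload` column.

    Assumption: the endpoint accepts an OpenAI-style chat body (hypothetical).
    """
    payload = {
        "model": "my-model",  # Hypothetical model name.
        "messages": [{"role": "user", "content": row["question"]}],
    }
    # Fail early if the payload cannot be serialized as JSON.
    json.dumps(payload)
    return {"payload": payload}


out = to_payload_row({"question": "What is Ray Data?"})
print(json.dumps(out["payload"]))
```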
python/ray/llm/_internal/batch/stages/prepare_image_stage.py

Lines changed: 10 additions & 7 deletions
@@ -304,8 +304,8 @@ async def process(self, images: List[_ImageType]) -> List["Image.Image"]:


 class PrepareImageUDF(StatefulStageUDF):
-    def __init__(self, data_column: str):
-        super().__init__(data_column)
+    def __init__(self, data_column: str, expected_input_keys: List[str]):
+        super().__init__(data_column, expected_input_keys)
         self.Image = importlib.import_module("PIL.Image")
         self.image_processor = ImageProcessor()

@@ -365,13 +365,16 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
             img_start_idx += num_images_in_req
             yield ret

-    @property
-    def expected_input_keys(self) -> List[str]:
-        """The expected input keys."""
-        return ["messages"]
-

 class PrepareImageStage(StatefulStage):
     """A stage to prepare images from OpenAI chat template messages."""

     fn: StatefulStageUDF = PrepareImageUDF
+
+    def get_required_input_keys(self) -> Dict[str, str]:
+        """The required input keys of the stage and their descriptions."""
+        return {
+            "messages": "A list of messages in OpenAI chat format. "
+            "See https://platform.openai.com/docs/api-reference/chat/create "
+            "for details."
+        }
