[llm] ray.llm support custom accelerators #51359

Merged
Changes from 3 commits
5 changes: 5 additions & 0 deletions python/ray/llm/_internal/batch/processor/base.py
@@ -26,6 +26,11 @@ class ProcessorConfig(BaseModelExtended):
"You can tune the batch size to balance the throughput and fault-tolerance "
"based on your use case. Defaults to 64.",
)
resources_per_worker: Optional[Dict[str, float]] = Field(
default=None,
description="This will override the default resources config for actors/workers, "
"the default resource config for LLM Stage may be something like {'GPU': 1}.",
)
accelerator_type: Optional[str] = Field(
default=None,
description="The accelerator type used by the LLM stage in a processor. "
@@ -190,6 +190,7 @@ def build_vllm_engine_processor(
# This is used to make sure we overlap batches to avoid the tail
# latency of each batch.
max_concurrency=config.max_concurrent_batches,
resources=config.resources_per_worker,
accelerator_type=config.accelerator_type,
runtime_env=config.runtime_env,
),
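Together, these two hunks let a batch processor request arbitrary resources instead of the hard-coded GPU bundle. A minimal usage sketch follows; the vLLMEngineProcessorConfig/build_llm_processor entry points, the model id, and the preprocess/postprocess shapes are assumptions for illustration, while resources_per_worker is the field added in this PR.

# Hypothetical sketch: run the batch LLM stage on a custom "NPU" resource
# instead of the default {"GPU": 1}. Only resources_per_worker comes from
# this PR; the other values are assumed/typical.
import ray
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

config = vLLMEngineProcessorConfig(
    model_source="Qwen/Qwen2.5-0.5B-Instruct",  # assumed model id
    engine_kwargs={"tensor_parallel_size": 1},
    resources_per_worker={"NPU": 1},  # overrides the default {"GPU": 1}
    concurrency=1,
    batch_size=64,
)

processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["prompt"]}],
        sampling_params={"max_tokens": 32},
    ),
    postprocess=lambda row: dict(answer=row["generated_text"]),
)

ds = processor(ray.data.from_items([{"prompt": "What is Ray Data?"}]))
ds.show(limit=1)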
36 changes: 25 additions & 11 deletions python/ray/llm/_internal/batch/stages/vllm_engine_stage.py
@@ -573,26 +573,30 @@ def __del__(self):
self.llm.shutdown()


def _ray_scheduling_strategy_fn(num_gpus_per_instance: int, accelerator_type: str):
def _ray_scheduling_strategy_fn(
num_workers_per_instance: int,
accelerator_type: str,
resources: Optional[Dict[str, float]] = None,
):
"""
Create a Ray scheduling strategy for vLLM engine.

Args:
num_gpus_per_instance: The number of GPUs per instance.
num_workers_per_instance: The number of workers per instance.
accelerator_type: The accelerator type.

Returns:
The Ray scheduling strategy.
"""

def _get_bundle() -> Dict[str, float]:
bundle: Dict[str, float] = {"GPU": 1, "CPU": 1}
bundle: Dict[str, float] = resources if resources else {"GPU": 1, "CPU": 1}
if accelerator_type:
bundle[f"accelerator_type:{accelerator_type}"] = 0.001
return bundle

pg = ray.util.placement_group(
[_get_bundle()] * num_gpus_per_instance,
[_get_bundle()] * num_workers_per_instance,
strategy="STRICT_PACK",
)
return dict(
@@ -621,6 +625,7 @@ def post_init(cls, values):
The updated values.
"""
map_batches_kwargs = values["map_batches_kwargs"]
resources_per_worker = map_batches_kwargs.get("resources")
accelerator_type = map_batches_kwargs.get("accelerator_type", "")
fn_constructor_kwargs = values["fn_constructor_kwargs"]
engine_kwargs = fn_constructor_kwargs.get("engine_kwargs", {})
@@ -629,29 +634,38 @@
if accelerator_type:
ray_remote_args["accelerator_type"] = accelerator_type

# Setup num_gpus required per vLLM engine.
# Setup num_workers required per vLLM engine.
tp_size = engine_kwargs.get("tensor_parallel_size", 1)
pp_size = engine_kwargs.get("pipeline_parallel_size", 1)
num_gpus = tp_size * pp_size
num_workers = tp_size * pp_size

# Use the MP backend by default.
engine_kwargs.setdefault("distributed_executor_backend", "mp")
executor_backend = engine_kwargs.get("distributed_executor_backend")

# When Ray is used in the vLLM engine, we set num_gpus to 0 so that
# When Ray is used in the vLLM engine, we set num_devices to 0 so that
# Ray Data won't reserve GPUs in advance. Instead, we specify scheduling
# strategy in .map_batches() arguments and let vLLM Ray executor to
# create placement groups for each TP/PP worker.
if executor_backend == "ray" and num_gpus > 1:
num_mp_workers = num_workers
if executor_backend == "ray" and num_workers > 1:
# Note that we have to use partial() to pass a function
# instead of an object.
map_batches_kwargs["ray_remote_args_fn"] = partial(
_ray_scheduling_strategy_fn,
num_gpus,
num_workers,
accelerator_type,
resources_per_worker,
)
num_gpus = 0
num_mp_workers = 0

if not resources_per_worker:
map_batches_kwargs["num_gpus"] = num_mp_workers
else:
ray_remote_args["resources"] = {
Contributor
I realize now that there is some naming confusion. resources_per_worker in the top-level API refers to the resources required per worker within the replica, but it might also be interpreted as resources per replica. Say you want to do tp=2 and pp=2 on NPUs: is resources_per_worker={"NPU": 4} the correct value, or is resources_per_worker={"NPU": 1} the right thing? "Worker" could mean the number of workers seen from Ray's perspective. Inside this function, however, resources_per_worker seems to refer to resources per vLLM worker, i.e. workers from vLLM's perspective. We need to find consistent naming to differentiate them. Here is my suggested implementation; I can push over your changes if that's ok?

  1. Change _ray_scheduling_strategy_fn to be more explicit about the meaning of these arguments and about what can and cannot be None (accelerator_type can be None and should also be ignored when custom resources are passed in).
def _ray_scheduling_strategy_fn(
    num_bundles_per_replica: int, 
    accelerator_type: Optional[str] = None,
    resources_per_bundle: Optional[Dict[str, float]] = None,
):
    """Create a Ray scheduling strategy for the engine.

    Args:
        num_bundles_per_replica: The number of device bundles per 
            engine replica.
        accelerator_type: The accelerator type. If None, the 
            accelerator_type label will not be set.
        resources_per_bundle: The custom resources per bundle. 
            If None, we default to 1xGPU + 1xCPU bundle.

    Returns:
        The Ray scheduling strategy.
    """

    def _get_bundle() -> Dict[str, float]:

        # Custom resources
        if resources_per_bundle:
            return resources_per_bundle
        
        # GPU bundles
        bundle = {"GPU": 1, "CPU": 1}
        if accelerator_type:
            bundle[f"accelerator_type:{accelerator_type}"] = 0.001
        return bundle

    pg = ray.util.placement_group(
        [_get_bundle()] * num_bundles_per_replica,
        strategy="STRICT_PACK",
    )
    return dict(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            pg, placement_group_capture_child_tasks=True
        )
    )
  2. Change the stage post_init implementation to reflect the new names and consistently use ray_remote_args in case the ray_remote_args_fn condition does not get exercised.
class vLLMEngineStage(StatefulStage):
    """
    A stage that runs vLLM engine.
    """

    fn: Type[StatefulStageUDF] = vLLMEngineStageUDF

    @root_validator(pre=True)
    def post_init(cls, values):
        """Post-initialize the stage. Specifically,
        this function determines the num_gpus and Ray remote args
        for the .map_batches() call in this stage.

        Args:
            values: The raw stage values.
        Returns:
            The updated values.
        """
        map_batches_kwargs = values["map_batches_kwargs"]
        resources_per_bundle = map_batches_kwargs.get("resources_per_bundle")
        accelerator_type = map_batches_kwargs.get("accelerator_type", "")
        fn_constructor_kwargs = values["fn_constructor_kwargs"]
        engine_kwargs = fn_constructor_kwargs.get("engine_kwargs", {})

        ray_remote_args = {}
        if accelerator_type:
            ray_remote_args["accelerator_type"] = accelerator_type

        # Setup num_workers required per vLLM engine.
        tp_size = engine_kwargs.get("tensor_parallel_size", 1)
        pp_size = engine_kwargs.get("pipeline_parallel_size", 1)
        num_bundles_per_replica = tp_size * pp_size

        # Use the MP backend by default.
        engine_kwargs.setdefault("distributed_executor_backend", "mp")
        executor_backend = engine_kwargs.get("distributed_executor_backend")

        # When Ray is used in the vLLM engine, we set num_devices to 0 so that
        # Ray Data won't reserve GPUs in advance. Instead, we specify scheduling
        # strategy in .map_batches() arguments and let vLLM Ray executor to
        # create placement groups for each TP/PP worker.
        if executor_backend == "ray" and num_bundles_per_replica > 1:
            # Note that we have to use partial() to pass a function
            # instead of an object.
            map_batches_kwargs["ray_remote_args_fn"] = partial(
                _ray_scheduling_strategy_fn,
                num_bundles_per_replica,
                accelerator_type,
                resources_per_bundle,
            )

        if not resources_per_bundle:
            # Default to GPUs per bundle if custom resources are not specified.
            ray_remote_args["num_gpus"] = num_bundles_per_replica
        else:
            ray_remote_args["resources"] = {
                resource_key: resource_count * num_bundles_per_replica
                for resource_key, resource_count in resources_per_bundle.items()
            }

        map_batches_kwargs.update(ray_remote_args)
        return values
  3. Reflect the resources_per_bundle name in the public API to save the user from the confusion.

key: value * num_mp_workers
for key, value in resources_per_worker.items()
}

map_batches_kwargs["num_gpus"] = num_gpus
map_batches_kwargs.update(ray_remote_args)
return values
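
For a concrete sense of the arithmetic above, here is a minimal sketch with illustrative values (not taken from the PR's tests): with the default MP executor backend, tensor_parallel_size=2, pipeline_parallel_size=1, and resources_per_worker={"NPU": 1}, the .map_batches() actor reserves both devices itself.

# Minimal sketch of the resource arithmetic in post_init (illustrative values).
tp_size, pp_size = 2, 1
resources_per_worker = {"NPU": 1}

num_mp_workers = tp_size * pp_size  # two vLLM workers inside one engine replica
map_batches_resources = {
    key: value * num_mp_workers for key, value in resources_per_worker.items()
}
assert map_batches_resources == {"NPU": 2}  # replaces num_gpus=2 in the GPU path

When executor_backend is "ray" and there is more than one worker, num_mp_workers is instead set to 0, and the STRICT_PACK placement group built by _ray_scheduling_strategy_fn carries the per-worker bundles, so the map_batches() actor itself reserves nothing.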
6 changes: 6 additions & 0 deletions python/ray/llm/_internal/serve/configs/server_models.py
@@ -190,6 +190,12 @@ class LLMConfig(BaseModelExtended):
),
)

resources_per_worker: Optional[Dict[str, float]] = Field(
default=None,
description="This will pass to config like `VLLMEngineConfig` and override "
"the resources config for the workers in vLLM engine.",
)

accelerator_type: Optional[str] = Field(
default=None,
description=f"The type of accelerator runs the model on. Only the following values are supported: {str([t.value for t in GPUType])}",
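As a hedged illustration of the serving side, the sketch below constructs an LLMConfig with the new field; the model_loading_config value and the tensor parallel size are assumptions, and only resources_per_worker comes from this diff.

# Hypothetical serving config: schedule each vLLM engine worker on an "NPU"
# custom resource instead of a GPU. Everything except resources_per_worker is
# an assumed/typical value, not taken from this PR.
from ray.serve.llm import LLMConfig

llm_config = LLMConfig(
    model_loading_config={"model_id": "Qwen/Qwen2.5-0.5B-Instruct"},
    engine_kwargs={"tensor_parallel_size": 2},
    resources_per_worker={"NPU": 1},  # forwarded to VLLMEngineConfig below
)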
@@ -51,6 +51,11 @@ class VLLMEngineConfig(BaseModelExtended):
None,
description="Configuration for cloud storage mirror. This is for where the weights are downloaded from.",
)
resources_per_worker: Optional[Dict[str, float]] = Field(
default=None,
description="This overrides the vLLM engine worker's default resource configuration, "
"the number of resources returned by `placement_bundles`.",
)
accelerator_type: Optional[GPUType] = Field(
None,
description="The type of accelerator to use. This is used to determine the placement group strategy.",
@@ -104,6 +109,7 @@ def from_llm_config(cls, llm_config: LLMConfig) -> "VLLMEngineConfig":
model_id=llm_config.model_id,
hf_model_id=hf_model_id,
mirror_config=mirror_config,
resources_per_worker=llm_config.resources_per_worker,
accelerator_type=llm_config.accelerator_type,
engine_kwargs=llm_config.engine_kwargs,
runtime_env=llm_config.runtime_env,
@@ -134,7 +140,10 @@ def placement_strategy(self) -> str:

@property
def placement_bundles(self) -> List[Dict[str, float]]:
bundle = {"GPU": 1}
if not self.resources_per_worker:
Contributor
A similar change should happen here as well.

bundle = {"GPU": 1}
else:
bundle = self.resources_per_worker
if self.accelerator_type:
bundle[self.ray_accelerator_type()] = 0.001
bundles = [bundle for _ in range(self.num_gpu_workers)]
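The bundle logic above is small enough to sketch standalone. The hypothetical version below (the function signature and num_workers argument are illustrative, not the real VLLMEngineConfig API) shows that a custom resource simply replaces the default {"GPU": 1} bundle, while the accelerator_type label is still appended when present.

# Standalone sketch of the placement_bundles logic (illustrative names only).
from typing import Dict, List, Optional


def placement_bundles(
    resources_per_worker: Optional[Dict[str, float]],
    accelerator_type: Optional[str],
    num_workers: int,
) -> List[Dict[str, float]]:
    bundle = dict(resources_per_worker) if resources_per_worker else {"GPU": 1}
    if accelerator_type:
        # A tiny fractional resource that only steers placement onto nodes
        # labeled with this accelerator type.
        bundle[f"accelerator_type:{accelerator_type}"] = 0.001
    return [dict(bundle) for _ in range(num_workers)]


assert placement_bundles({"NPU": 1}, None, 2) == [{"NPU": 1}, {"NPU": 1}]
assert placement_bundles(None, "L4", 1) == [{"GPU": 1, "accelerator_type:L4": 0.001}]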
2 changes: 2 additions & 0 deletions python/ray/util/accelerators/accelerators.py
@@ -26,6 +26,8 @@
GOOGLE_TPU_V5P = "TPU-V5P"
GOOGLE_TPU_V5LITEPOD = "TPU-V5LITEPOD"
GOOGLE_TPU_V6E = "TPU-V6E"
HUAWEI_NPU_910B = "Ascend910B"
HUAWEI_NPU_910B4 = "Ascend910B4"

# Use these instead of NVIDIA_A100 if you need a specific accelerator size. Note that
# these labels are not auto-added to nodes, you'll have to add them manually in
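A brief, hedged example of the new constants follows; it assumes the cluster's Ascend nodes advertise an "NPU" custom resource and the matching accelerator_type label, which may need to be configured manually depending on the Ray version.

# Hypothetical usage of the new Ascend constant with @ray.remote.
import ray
from ray.util.accelerators import HUAWEI_NPU_910B

ray.init()


@ray.remote(resources={"NPU": 1}, accelerator_type=HUAWEI_NPU_910B)
def run_on_910b() -> str:
    return "scheduled on an Ascend 910B node"


print(ray.get(run_on_910b.remote()))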