Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions matrix/app_server/llm/ray_serve_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,22 @@ def _build_app(cli_args: Dict[str, Any], use_grpc) -> serve.Application:
{"CPU": ray_resources.get("num_cpus", 1), accelerator: 1}
) # for the vLLM actors

ray_resources.pop("num_cpus", None)
ray_resources.pop("num_gpus", None)
custom_resources = ray_resources.pop("resources", None)

# Add custom resources to placement group bundles if specified
# This ensures the deployment is scheduled on nodes with the required resources
if custom_resources:
for bundle in pg_resources:
bundle.update(custom_resources)
# When pp == 1 we use the "STRICT_PACK" strategy below to ensure all vLLM actors
# are placed on the same Ray node; otherwise the "PACK" strategy is used.
cls = VLLMDeployment if not use_grpc else GrpcDeployment
return cls.options( # type: ignore[union-attr]
placement_group_bundles=pg_resources,
placement_group_strategy="STRICT_PACK" if pp == 1 else "PACK",
**ray_resources,
).bind(
engine_args,
parsed_args.response_role,
Expand Down
4 changes: 4 additions & 0 deletions matrix/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def start_cluster(
force_new_head: bool = False,
use_array: bool = True,
prometheus_scrape_interval: int = 10,
logical_resources: tp.Dict[str, int] | None = None,
) -> tp.Dict[str, tp.Any]:
"""
Starts the Ray cluster with the specified number of workers and additional configuration.
Expand All @@ -96,6 +97,8 @@ def start_cluster(
enable_grafana (bool, optional): If True, enable prometheus and grafana dashboard.
force_new_head (bool): If True, force removal of head.json when 'matrix stop_cluster' has not been run.
use_array (bool): If True, use Slurm job arrays for workers (default: True).
logical_resources (dict, optional): Custom logical resources to add to workers.
Keys are resource names, values are counts. Defaults to None, which is treated as an empty dict.

Returns:
dict: The cluster start status, converted to a JSON-compatible form.
Expand All @@ -108,6 +111,7 @@ def start_cluster(
force_new_head=force_new_head,
use_array=use_array,
prometheus_scrape_interval=prometheus_scrape_interval,
logical_resources=logical_resources or {},
)
return convert_to_json_compatible(status)

Expand Down
10 changes: 6 additions & 4 deletions matrix/cluster/ray_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def start(
force_new_head: bool = False,
use_array: bool = True,
prometheus_scrape_interval: int = 10,
logical_resources: tp.Dict[str, int] | None = None,
):
"""
Starts a Ray cluster on Slurm.
Expand Down Expand Up @@ -440,7 +441,7 @@ def start(
cluster=executor,
)
s_executor.update_parameters(
name=f"ray_head_{self.cluster_id}",
name=f"matrix_head_{self.cluster_id}",
**head_params,
)
head_job = s_executor.submit(
Expand Down Expand Up @@ -492,18 +493,19 @@ def start(
num_jobs = 1
else:
num_jobs = add_workers
logical_resources = {
worker_logical_resources = {
f"{key}-{value}": 1
for key, value in worker_params.items()
if key in _SLURM_KEY_ALIASES.values()
}
worker_logical_resources.update(logical_resources or {})
print(f"Worker Slurm parameters: {worker_params}")

s_executor = submitit.AutoExecutor(
folder=str(self._log_dir), cluster=executor
)
s_executor.update_parameters(
name=f"ray_worker_{self.cluster_id}",
name=f"matrix_worker_{self.cluster_id}",
**worker_params,
)

Expand All @@ -522,7 +524,7 @@ def start(
cluster_info,
worker_wait_timeout_seconds,
start_wait_time_seconds,
logical_resources,
worker_logical_resources,
worker_params,
)
)
Expand Down
1 change: 1 addition & 0 deletions matrix/cluster/ray_worker_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _start_ray_worker(
worker_env (Dict[str, str]): Worker environment variables
num_cpus (int): Number of CPUs
num_gpus (int): Number of GPUs
logical_resources (str): JSON string of logical resources
"""
subprocess.run(
[
Expand Down