Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/worker/tt_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# from tests.scripts.common import get_updated_device_params
import ttnn


Check failure on line 7 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/worker/tt_device.py:7:81: E501 Line too long (124 > 80)
# Device helpers to be shared with top-level conftest.py and other conftest.py files that will handle open/close of devices.


Expand All @@ -17,18 +17,18 @@
dispatch_core_type = new_device_params.pop("dispatch_core_type", None)
fabric_tensix_config = new_device_params.get("fabric_tensix_config", None)

if ttnn.device.is_blackhole():

Check failure on line 20 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/worker/tt_device.py:20:81: E501 Line too long (133 > 80)
# Only when both fabric_config and fabric_tensix_config are set, we can use ROW dispatch, otherwise force to use COL dispatch
fabric_config = new_device_params.get("fabric_config", None)
if not (fabric_config and fabric_tensix_config):
# When not both are set, force COL dispatch
if dispatch_core_axis == ttnn.DispatchCoreAxis.ROW:
logger.warning(

Check failure on line 26 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/worker/tt_device.py:26:81: E501 Line too long (110 > 80)
"ROW dispatch requires both fabric and tensix config, using DispatchCoreAxis.COL instead."
)
dispatch_core_axis = ttnn.DispatchCoreAxis.COL
elif fabric_config and fabric_tensix_config:
logger.warning(

Check failure on line 31 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/worker/tt_device.py:31:81: E501 Line too long (131 > 80)
f"Blackhole with fabric_config and fabric_tensix_config enabled, using fabric_tensix_config={fabric_tensix_config}"
)

Expand Down Expand Up @@ -58,13 +58,13 @@
except (ValueError, AttributeError) as e:
raise ValueError(f"Invalid TT_MESHDEVICE_SHAPE format '{env_shape}': expected 'rows,cols' (e.g., '2,4')")


def create_mesh_device(device_params: Optional[Dict] = None):

Check failure on line 62 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/worker/tt_device.py:62:81: E501 Line too long (99 > 80)
"""Create mesh device with appropriate mesh shape based on available devices.

Check failure on line 63 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (B904)

vllm/worker/tt_device.py:61:9: B904 Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling

Mesh shape selection:
- TT_MESHDEVICE_SHAPE env var (e.g., "2,4" or "1,8") if set
- Galaxy (32 devices): 4x8

Check failure on line 67 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/worker/tt_device.py:67:81: E501 Line too long (81 > 80)
- T3000 (8+ devices): 2x4
- Single/Few devices: 1x{num_devices}

Expand All @@ -73,7 +73,7 @@

Returns:
Initialized mesh device
"""

Check failure on line 76 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/worker/tt_device.py:76:81: E501 Line too long (90 > 80)
global _printed_device_info

params = dict(device_params or {})
Expand All @@ -84,7 +84,7 @@

updated_device_params = get_updated_device_params(params)
device_ids = ttnn.get_device_ids()

env_mesh_shape = parse_mesh_shape_from_env()
if env_mesh_shape:
default_mesh_shape = env_mesh_shape
Expand All @@ -111,7 +111,7 @@
if "trace_region_size" in params:
trace_mb = params["trace_region_size"] // (1024 * 1024)
print(f" Trace Region Size: {trace_mb}MB")
print(f"=" * 60)

Check failure on line 114 in vllm/worker/tt_device.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/worker/tt_device.py:114:81: E501 Line too long (84 > 80)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created with shape {mesh_device.shape}")

Expand Down
24 changes: 24 additions & 0 deletions vllm/worker/tt_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,8 @@ def get_fabric_config(override_tt_config, num_devices):
"FABRIC_1D_RING": ttnn.FabricConfig.FABRIC_1D_RING,
"FABRIC_2D": ttnn.FabricConfig.FABRIC_2D,
"CUSTOM": ttnn.FabricConfig.CUSTOM,
"FABRIC_2D_DYNAMIC": ttnn.FabricConfig.FABRIC_2D_DYNAMIC,
"FABRIC_2D_DYNAMIC_TORUS_XY": ttnn.FabricConfig.FABRIC_2D_DYNAMIC_TORUS_XY,
}
fabric_config = fabric_config_map.get(fabric_config_str)
assert fabric_config is not None, (
Expand Down Expand Up @@ -620,6 +622,25 @@ def get_mesh_grid(local_dp_rank=0):
return mesh_grid


def get_dispatch_core_axis(override_tt_config):
dispatch_core_axis: ttnn.DispatchCoreAxis = ttnn.DispatchCoreAxis.ROW

if override_tt_config is None:
return dispatch_core_axis

dispatch_core_axis_config = override_tt_config.get("dispatch_core_axis", None)

if dispatch_core_axis_config is None:
return dispatch_core_axis

assert dispatch_core_axis_config in ["row", "col"], (
f"Invalid dispatch_core_axis: {dispatch_core_axis_config}. "
"Expected: row, col.")
dispatch_core_axis = (ttnn.DispatchCoreAxis.COL
if dispatch_core_axis_config == "col"
else ttnn.DispatchCoreAxis.ROW)
return dispatch_core_axis

def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0):
assert local_dp_rank == 0, "open_mesh_device must run on local DP rank 0"
mesh_grid = get_mesh_grid(local_dp_rank)
Expand All @@ -637,9 +658,12 @@ def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0):
# dispatch_core_config=get_dispatch_core_config(override_tt_config),
# **device_params,
# )

device_params = {"trace_region_size": 95449088}
if fabric_config:
device_params["fabric_config"] = fabric_config
device_params["dispatch_core_axis"] = get_dispatch_core_axis(override_tt_config)

mesh_device = create_mesh_device(device_params)
# set_and_get_device_cache(mesh_device)
# logger.info("multidevice with %d devices and grid %s is created",
Expand Down
Loading