diff --git a/vllm/worker/tt_device.py b/vllm/worker/tt_device.py index d8150b69ed83..893e9dab9269 100644 --- a/vllm/worker/tt_device.py +++ b/vllm/worker/tt_device.py @@ -84,7 +84,7 @@ def create_mesh_device(device_params: Optional[Dict] = None): updated_device_params = get_updated_device_params(params) device_ids = ttnn.get_device_ids() - + env_mesh_shape = parse_mesh_shape_from_env() if env_mesh_shape: default_mesh_shape = env_mesh_shape diff --git a/vllm/worker/tt_worker.py b/vllm/worker/tt_worker.py index 4fad34459a2f..afa41e097d7b 100644 --- a/vllm/worker/tt_worker.py +++ b/vllm/worker/tt_worker.py @@ -505,6 +505,8 @@ def get_fabric_config(override_tt_config, num_devices): "FABRIC_1D_RING": ttnn.FabricConfig.FABRIC_1D_RING, "FABRIC_2D": ttnn.FabricConfig.FABRIC_2D, "CUSTOM": ttnn.FabricConfig.CUSTOM, + "FABRIC_2D_DYNAMIC": ttnn.FabricConfig.FABRIC_2D_DYNAMIC, + "FABRIC_2D_DYNAMIC_TORUS_XY": ttnn.FabricConfig.FABRIC_2D_DYNAMIC_TORUS_XY, } fabric_config = fabric_config_map.get(fabric_config_str) assert fabric_config is not None, ( @@ -620,6 +622,25 @@ def get_mesh_grid(local_dp_rank=0): return mesh_grid +def get_dispatch_core_axis(override_tt_config): + dispatch_core_axis: ttnn.DispatchCoreAxis = ttnn.DispatchCoreAxis.ROW + + if override_tt_config is None: + return dispatch_core_axis + + dispatch_core_axis_config = override_tt_config.get("dispatch_core_axis", None) + + if dispatch_core_axis_config is None: + return dispatch_core_axis + + assert dispatch_core_axis_config in ["row", "col"], ( + f"Invalid dispatch_core_axis: {dispatch_core_axis_config}. " + "Expected: row, col.") + dispatch_core_axis = (ttnn.DispatchCoreAxis.COL + if dispatch_core_axis_config == "col" + else ttnn.DispatchCoreAxis.ROW) + return dispatch_core_axis + def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0): assert local_dp_rank == 0, "open_mesh_device must run on local DP rank 0" mesh_grid = get_mesh_grid(local_dp_rank) @@ -637,9 +658,12 @@ def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0): # dispatch_core_config=get_dispatch_core_config(override_tt_config), # **device_params, # ) + device_params = {"trace_region_size": 95449088} if fabric_config: device_params["fabric_config"] = fabric_config + device_params["dispatch_core_axis"] = get_dispatch_core_axis(override_tt_config) + mesh_device = create_mesh_device(device_params) # set_and_get_device_cache(mesh_device) # logger.info("multidevice with %d devices and grid %s is created",