#16391: propagate sub_device_ids to mesh #16410

Open · wants to merge 7 commits into main
9 changes: 6 additions & 3 deletions models/demos/t3000/llama2_70b/tt/llama_common.py
@@ -6,7 +6,7 @@
import math
from loguru import logger
import re
from typing import Tuple
from typing import Tuple, List
import numpy as np
import torch
import ttnn
@@ -70,8 +70,11 @@ def __init__(self, mesh_device, dims, cluster_shape):
self.cluster_shape = cluster_shape
self.mesh_device = mesh_device

def compose(self, tensor: ttnn.Tensor) -> torch.Tensor:
tt_shards = [ttnn.to_torch(tt_input_tensor) for tt_input_tensor in ttnn.get_device_tensors(tensor)]
def compose(self, tensor: ttnn.Tensor, sub_device_ids: List[ttnn.SubDeviceId] = []) -> torch.Tensor:
tt_shards = [
ttnn.to_torch(tt_input_tensor, sub_device_ids=sub_device_ids)
for tt_input_tensor in ttnn.get_device_tensors(tensor)
]

row_concat = []
for cluster_row in range(self.cluster_shape[1]):
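For context, a minimal usage sketch of the updated composer. The dims, cluster_shape, and tensor/variable names below are illustrative placeholders, not values from this PR; the default sub_device_ids=[] keeps existing call sites unchanged.

import ttnn
from models.demos.t3000.llama2_70b.tt.llama_common import ConcatMesh2dToTensor

def compose_on_worker_subdevice(mesh_device, tt_sharded_tensor):
    # Scope the device-to-host reads to a single worker sub-device
    # (assumes sub-device 0 was configured as the worker sub-device).
    worker_sub_device_id = ttnn.SubDeviceId(0)
    composer = ConcatMesh2dToTensor(mesh_device, dims=(3, 1), cluster_shape=(4, 8))
    return composer.compose(tt_sharded_tensor, sub_device_ids=[worker_sub_device_id])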
@@ -180,17 +180,9 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
False,
)
output_mem_config = ttnn.MemoryConfig(tensor_memory_layout, buffer_type=buffer_type, shard_spec=output_shard_spec)
ttnn_tensor = ttnn.from_torch(
full_input_tensor_unfractured,
tile=ttnn.Tile(tile),
dtype=input_dtype,
device=mesh_device,
layout=layout,
memory_config=input_mem_config,
mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dims),
)
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)

worker_subdevice_ids = []
fabric_torn_down = True
if use_all_gather_async:
compute_grid_size = mesh_device.compute_with_storage_grid_size()
worker_sub_device = ttnn.SubDevice(
@@ -205,89 +197,117 @@
]
)
worker_sub_device_id = ttnn.SubDeviceId(0)
worker_subdevice_ids = [worker_sub_device_id]
if create_persistent_fabric:
logger.info("Create persistent fabric interface")
mesh_sub_device_manager_id = create_and_load_sub_device_manager_with_fabric_interface(
mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric
)
logger.info("Done Create persistent fabric interface")

# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
if trace_mode:
ttnn_tensor_out = run_with_trace(
input_tensor=ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
output_mem_config=output_mem_config,
all_gather_topology=ttnn.Topology.Linear,
num_iter=num_iters,
fabric_torn_down = False
elif teardown_persistent_fabric:
fabric_torn_down = False

try:
ttnn_tensor = ttnn.from_torch(
full_input_tensor_unfractured,
tile=ttnn.Tile(tile),
dtype=input_dtype,
device=mesh_device,
layout=layout,
memory_config=input_mem_config,
mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dims),
sub_device_ids=worker_subdevice_ids,
)
else:
for _ in range(num_iters):
if use_all_gather_async:
logger.info("Running all-gather async")
ttnn_tensor_out = ttnn.experimental.all_gather_async(
ttnn_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
create_semaphore_handles=True,
)
else:
ttnn_tensor_out = ttnn.all_gather(
ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=ttnn.Topology.Linear,
)
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, sub_device_ids=worker_subdevice_ids)

# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
if trace_mode:
ttnn_tensor_out = run_with_trace(
input_tensor=ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
output_mem_config=output_mem_config,
all_gather_topology=ttnn.Topology.Linear,
num_iter=num_iters,
)
else:
for _ in range(num_iters):
if use_all_gather_async:
ttnn_tensor_out = ttnn.experimental.all_gather_async(
ttnn_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
create_semaphore_handles=True,
)
else:
ttnn_tensor_out = ttnn.all_gather(
ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=ttnn.Topology.Linear,
)

if enable_persistent_fabric:
logger.info(f"Waiting for op completion")
for d in mesh_device.get_devices():
ttnn.synchronize_device(d, sub_device_ids=worker_subdevice_ids)
logger.info(f"Done synchronizing with op")

# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor_out)
tt_output_tensor = ttnn.to_torch(
ttnn_tensor_out,
mesh_composer=ConcatMesh2dToTensor(mesh_device, mesh_shape=mesh_shape, dims=concat_dims),
sub_device_ids=worker_subdevice_ids,
)
output_tensors_list = torch.chunk(
tt_output_tensor, num_all_gather_instances, dim=all_gather_instances_concat_dim
)
output_golden = torch.zeros(tt_output_tensor.shape)

# Repeat the input tensor to represent the fact that the full concatenated input tensor lives across every
# device in the line
repeat_factor = [1] * len(output_golden.shape)
repeat_factor[dim] = num_devices_per_line
output_golden[:, :, :, :] = full_input_tensor_unfractured.repeat(repeat_factor)

eq = True
logger.info("Comparing output tensors")
if input_dtype == ttnn.bfloat16:
eq, output = comp_equal(tt_output_tensor, output_golden)
if not eq and debug is True:
logger.error(f"found mismatches")
report_mismatches(tt_output_tensor, output_golden, 100)
print_tile_corners_of_tensor(tt_output_tensor)
else:
eq, output = comp_pcc(tt_output_tensor, output_golden)
if not eq:
logger.error(f"output mismatch for tensor: {output}")

if enable_persistent_fabric:
logger.info(f"Waiting for op {i}")
for d in mesh_device.get_devices():
ttnn.synchronize_device(d, sub_device_ids=[worker_sub_device_id])
logger.info(f"Done iteration {i}")
logger.info("Done op call")

if enable_persistent_fabric and teardown_persistent_fabric:
logger.info("Tearing down persistent fabric interface")
teardown_fabric_interface(mesh_device)
logger.info("Done tearing down persistent fabric interface")
if enable_persistent_fabric and teardown_persistent_fabric:
logger.info("Tearing down persistent fabric interface")
teardown_fabric_interface(mesh_device)
logger.info("Done tearing down persistent fabric interface")
fabric_torn_down = True

# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor_out)
tt_output_tensor = ttnn.to_torch(
ttnn_tensor_out, mesh_composer=ConcatMesh2dToTensor(mesh_device, mesh_shape=mesh_shape, dims=concat_dims)
)
output_tensors_list = torch.chunk(tt_output_tensor, num_all_gather_instances, dim=all_gather_instances_concat_dim)
output_golden = torch.zeros(tt_output_tensor.shape)

# Repeat the input tensor to represent the fact that the full concatenated input tensor lives across every
# device in the line
repeat_factor = [1] * len(output_golden.shape)
repeat_factor[dim] = num_devices_per_line
output_golden[:, :, :, :] = full_input_tensor_unfractured.repeat(repeat_factor)

eq = True
if input_dtype == ttnn.bfloat16:
eq, output = comp_equal(tt_output_tensor, output_golden)
if not eq and debug is True:
logger.error(f"found mismatches")
report_mismatches(tt_output_tensor, output_golden, 100)
print_tile_corners_of_tensor(tt_output_tensor)
else:
eq, output = comp_pcc(tt_output_tensor, output_golden)
if not eq:
logger.error(f"output mismatch for tensor: {output}")

assert eq, f"FAILED: {output}"
assert eq, f"FAILED: {output}"

except Exception as e:
if create_persistent_fabric and not fabric_torn_down:
logger.error(f"Tearing down persistent fabric after failure")
teardown_fabric_interface(mesh_device)
raise e


# Enumerate the post-commit cases explicitly
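Taken together, the test change threads a single worker sub-device id through every host/device transfer and synchronization point. A condensed sketch of that pattern is below; the ttnn calls and their sub_device_ids arguments mirror the diff above, while mesh_shape, dims, dtype, and the function name are placeholders.

import torch
import ttnn

def roundtrip_on_worker_subdevice(mesh_device, torch_input, mesh_shape=(8, 4)):
    worker_subdevice_ids = [ttnn.SubDeviceId(0)]  # assumes sub-device 0 is the worker sub-device

    tt_tensor = ttnn.from_torch(
        torch_input,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=mesh_device,
        mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=(None, 3)),
        sub_device_ids=worker_subdevice_ids,  # host -> device write scoped to the worker sub-device
    )
    # The test also calls to_device after from_torch; mirrored here for completeness.
    tt_tensor = ttnn.to_device(tt_tensor, mesh_device, sub_device_ids=worker_subdevice_ids)

    # ... launch all-gather or other mesh ops on tt_tensor here ...

    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d, sub_device_ids=worker_subdevice_ids)

    return ttnn.to_torch(
        tt_tensor,
        mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, mesh_shape=mesh_shape, dims=(0, 3)),
        sub_device_ids=worker_subdevice_ids,  # device -> host read scoped to the worker sub-device
    )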