#7560: Add SD tests to FD nightly with legacy pass
mtatsumiTT committed Apr 30, 2024
1 parent 7b6fdb4 commit 984b420
Showing 17 changed files with 112 additions and 33 deletions.
3 changes: 3 additions & 0 deletions conftest.py
@@ -105,6 +105,9 @@ def pytest_addoption(parser):
help="Path to json file with inputs",
)
parser.addoption("--cli-input", action="store", default=None, help="Enter prompt if --input-method=cli")
parser.addoption(
"--option", action="store", default="", help="Selectively run legacy pass for SD tests if --option legacy"
)


def pytest_generate_tests(metafunc):
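Note: the new --option flag defaults to "" and is consumed in each SD test file further down through a session-scoped `option` fixture, from which `use_legacy_4096` is derived as `option == "legacy"`. A minimal, self-contained sketch of that pattern follows; the test body here is hypothetical and only illustrates the flow, it is not part of the commit.

import pytest


@pytest.fixture(scope="session")
def option(pytestconfig):
    # Reads the value passed via `--option`; empty string when the flag is not given.
    # The flag itself is registered by the pytest_addoption hunk above in conftest.py.
    return pytestconfig.getoption("option")


def test_uses_legacy_flag(option):  # hypothetical test, for illustration only
    use_legacy_4096 = option == "legacy"
    # The real tests pass this flag down to the model call,
    # e.g. model(..., use_legacy_4096=use_legacy_4096).
    assert isinstance(use_legacy_4096, bool)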
@@ -56,6 +56,7 @@ def __call__(
norm_elementwise_affine: bool = True,
attention_bias: bool = False,
attention_head_dim=None,
use_legacy_4096: bool = False,
):
use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
@@ -120,6 +121,7 @@ def __call__(
cross_attention_dim=cross_attention_dim,
dim_head=attention_head_dim,
upcast_attention=upcast_attention,
use_legacy_4096=use_legacy_4096,
)

if use_ada_layer_norm_zero:
@@ -156,6 +158,7 @@ def __call__(
cross_attention_dim=cross_attention_dim,
dim_head=attention_head_dim,
upcast_attention=upcast_attention,
use_legacy_4096=use_legacy_4096,
)
if attn_output.memory_config() != hidden_states.memory_config():
if attn_output.memory_config().is_sharded():
@@ -547,12 +547,14 @@ def sharded_attention(self, query, key, value, original_seq_len, head_size, inde
v_sharded.deallocate()
return ttnn.reshape(attention_scores, (2, 8, attention_scores.shape[-2], attention_scores.shape[-1]))

def get_attention_scores_opt(self, query, t_key, value, original_seq_len, head_size, index=-1):
if (query.shape[-2] == 4096 and t_key.shape[-1] == 4096) or (
def get_attention_scores_opt(
self, query, t_key, value, original_seq_len, head_size, index=-1, use_legacy_4096=False
):
if (query.shape[-2] == 4096 and t_key.shape[-1] == 4096 and not use_legacy_4096) or (
query.shape[-2] == 1024 and t_key.shape[-1] == 1024
):
return self.time_sharded_attention(query, t_key, value, head_size)
else:
elif not (query.shape[-2] == 4096 and t_key.shape[-1] == 4096 and use_legacy_4096):
return self.sharded_attention(query, t_key, value, original_seq_len, head_size, index)

print("Legacy path")
@@ -566,18 +568,20 @@ def get_attention_scores_opt(self, query, t_key, value, original_seq_len, head_s
ttnn.deallocate(query)
ttnn.deallocate(t_key)
orig_shape = attention_scores.shape
attention_scores = ttnn.reshape(
attention_scores,
(
1,
attention_scores.shape[-4] * attention_scores.shape[-3],
attention_scores.shape[-2],
attention_scores.shape[-1],
),
)
attention_scores = ttnn.transformer.attention_softmax_(
attention_scores, attention_mask=attention_mask, head_size=head_size
)
# attention_scores = ttnn.reshape(
# attention_scores,
# (
# 1,
# attention_scores.shape[-4] * attention_scores.shape[-3],
# attention_scores.shape[-2],
# attention_scores.shape[-1],
# ),
# )
attention_scores = attention_scores * self.scales[head_size]
# attention_scores = ttnn.transformer.attention_softmax_(
# attention_scores, attention_mask=attention_mask, head_size=head_size
# )
attention_scores = ttnn.experimental.operations.primary.softmax_in_place(attention_scores)
attention_scores = ttnn.reshape(attention_scores, orig_shape)
if attention_scores.shape[-2] > original_seq_len:
attention_scores = attention_scores[:, :, :original_seq_len, :]
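For context: after this change, 4096x4096 attention still takes the time-sharded path by default, 1024x1024 always does, other shapes go to sharded_attention, and 4096x4096 with use_legacy_4096 set falls through to the legacy path above, which scales the scores by self.scales[head_size] and runs softmax_in_place instead of the fused attention_softmax_ call. A standalone sketch of just that routing decision, in plain Python with no ttnn dependency; the path labels are illustrative only.

# Illustration only: mirrors the dispatch in get_attention_scores_opt above.
def pick_attention_path(q_len: int, k_len: int, use_legacy_4096: bool = False) -> str:
    if (q_len == 4096 and k_len == 4096 and not use_legacy_4096) or (
        q_len == 1024 and k_len == 1024
    ):
        return "time_sharded"
    if not (q_len == 4096 and k_len == 4096 and use_legacy_4096):
        return "sharded"
    return "legacy"


assert pick_attention_path(4096, 4096) == "time_sharded"
assert pick_attention_path(4096, 4096, use_legacy_4096=True) == "legacy"
assert pick_attention_path(64, 64) == "sharded"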
@@ -714,6 +718,7 @@ def __call__(
upcast_softmax: bool = False,
cross_attention_kwargs={},
index=-1,
use_legacy_4096: bool = False,
):
assert dim_head in self.scales
original_seq_len = hidden_states.shape[-2] // 2 # 2 is the batch size
@@ -811,11 +816,12 @@ def __call__(
)
ttnn.deallocate(hidden_states)

M, K, N = (
encoder_hidden_states.shape[-2],
encoder_hidden_states.shape[-1],
self.parameters.kv.weight.shape[-1],
)
if encoder_hidden_states is not None:
M, K, N = (
encoder_hidden_states.shape[-2],
encoder_hidden_states.shape[-1],
self.parameters.kv.weight.shape[-1],
)
grid_sizes = {8192: (8, 2), 2048: (8, 2), 512: (8, 2), 128: (4, 2)}
grid_size = grid_size = grid_sizes[hidden_states.shape[-2]]
in0_block_h, in0_block_w, out_subblock_h, out_subblock_w, out_block_h, out_block_w = determine_blocking(
@@ -907,6 +913,7 @@ def __call__(
original_seq_len,
dim_head,
index=index,
use_legacy_4096=use_legacy_4096,
)

hidden_states = ttnn.transformer.concatenate_heads(hidden_states, memory_config=ttnn.L1_MEMORY_CONFIG)
@@ -68,6 +68,7 @@ def __call__(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift: str = "default",
use_legacy_4096: bool = False,
):
output_states = ()

@@ -101,6 +102,7 @@ def __call__(
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_legacy_4096=use_legacy_4096,
)

output_states += (ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG),)
@@ -94,6 +94,7 @@ def __call__(
attn_num_head_channels=1,
only_cross_attention: bool = False,
index=-1,
use_legacy_4096: bool = False,
):
for i, (resnet, attention) in enumerate(zip(self.resnets, self.attentions)):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -167,6 +168,7 @@ def __call__(
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
output_bfloat16=(not add_upsample) and (i == len(self.resnets) - 1),
use_legacy_4096=use_legacy_4096,
)
else:
assert False, "We do not support Dual Transformer2DModel"
@@ -194,6 +194,7 @@ def __call__(
eps=1e-5,
norm_elementwise_affine: bool = True,
output_bfloat16: bool = False,
use_legacy_4096: bool = False,
):
inner_dim = num_attention_heads * attention_head_dim
assert norm_num_groups == 32
@@ -307,6 +308,7 @@ def __call__(
upcast_attention=upcast_attention,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
use_legacy_4096=use_legacy_4096,
)

# 3. Output
@@ -348,6 +348,7 @@ def __call__(
return_dict: bool = True,
reader_patterns_cache: Optional[Dict] = None,
dtype: Optional[ttnn.DataType] = None,
use_legacy_4096: bool = False,
):
num_upsamplers = len(block_out_channels) - 1
default_overall_up_factor = 2**num_upsamplers
@@ -463,6 +464,7 @@ def __call__(
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_legacy_4096=use_legacy_4096,
)
elif down_block_type == "DownBlock2D":
sample, res_samples = down_block(
@@ -509,6 +511,7 @@ def __call__(
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
use_legacy_4096=use_legacy_4096,
)

# 5.up
@@ -570,6 +573,7 @@ def __call__(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
index=i,
use_legacy_4096=use_legacy_4096,
)
elif up_block_type == "UpBlock2D":
sample = up_block(
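As the hunks above show, use_legacy_4096 is plumbed as a keyword argument from the UNet __call__ through the down, mid, and up blocks into the attention module, with a default of False so callers that do not pass it keep the existing time-sharded path. A toy sketch of that pass-through pattern, using hypothetical function names rather than the real module signatures:

# Illustration of the keyword pass-through used in this commit; the actual call
# chain is UNet -> cross-attention down/mid/up blocks -> transformer block -> attention.
def attention(scores, use_legacy_4096: bool = False):
    return "legacy" if use_legacy_4096 else "time_sharded"


def transformer_block(scores, use_legacy_4096: bool = False):
    return attention(scores, use_legacy_4096=use_legacy_4096)


def unet(scores, use_legacy_4096: bool = False):
    # Defaulting to False leaves the existing fast path untouched for callers
    # that never pass the flag.
    return transformer_block(scores, use_legacy_4096=use_legacy_4096)


assert unet(None) == "time_sharded"
assert unet(None, use_legacy_4096=True) == "legacy"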
@@ -51,6 +51,7 @@ def __call__(
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
use_legacy_4096=False,
):
has_cross_attention = True

@@ -90,6 +91,7 @@ def __call__(
eps=1e-5,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_legacy_4096=use_legacy_4096,
)
else:
assert False, "We do not support Dual Transformer"
1 change: 1 addition & 0 deletions tests/scripts/nightly/run_wh_b0_only.sh
@@ -10,6 +10,7 @@ fi
echo "Running nightly tests for WH B0 only"

env pytest tests/ttnn/integration_tests/unet # -> failing: issue #7556
env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/integration_tests/stable_diffusion -k 512 --option legacy
# env pytest tests/ttnn/integration_tests/stable_diffusion # -> failing/hanging: issue #7560

env pytest models/demos/mamba/tests/test_mamba_ssm.py
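The new nightly line above runs only the 512x512 SD variants with the legacy pass enabled. For reference, a roughly equivalent programmatic invocation; this is a sketch that assumes it is run from the repo root with the same environment the script sets up.

# Programmatic equivalent of the nightly command above (sketch).
import os
import pytest

os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
exit_code = pytest.main(
    [
        "tests/ttnn/integration_tests/stable_diffusion",
        "-k",
        "512",  # select only the 512x512 variants
        "--option",
        "legacy",  # enable the legacy 4096-length attention path
    ]
)
assert exit_code == 0  # 0 means all selected tests passed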
@@ -26,6 +26,11 @@
)


@pytest.fixture(scope="session")
def option(pytestconfig):
return pytestconfig.getoption("option")


@skip_for_grayskull()
@pytest.mark.parametrize("model_name", ["CompVis/stable-diffusion-v1-4"])
@pytest.mark.parametrize(
@@ -155,14 +160,15 @@ def test_basic_transformer_block_256x256(device, model_name, N, C, H, W, index,
),
],
)
def test_basic_transformer_block_512x512(device, model_name, N, C, H, W, index, attention_head_dim):
def test_basic_transformer_block_512x512(device, model_name, N, C, H, W, index, attention_head_dim, option):
torch.manual_seed(0)

pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float32)
model = pipe.unet
model.eval()
config = model.config
basic_transformer = pipe.unet.up_blocks[index].attentions[1].transformer_blocks[0]
use_legacy_4096 = option == "legacy"

hidden_states_shape = torch.Size([N, C, H, W])
hidden_states = torch.rand(hidden_states_shape) * 0.01
@@ -209,6 +215,7 @@ def test_basic_transformer_block_512x512(device, model_name, N, C, H, W, index,
class_labels=class_labels,
config=config,
attention_head_dim=attention_head_dim,
use_legacy_4096=use_legacy_4096,
)

ttnn_output = ttnn.reshape(ttnn_output, [1, 2, ttnn_output.shape[-2] // 2, ttnn_output.shape[-1]])
@@ -21,6 +21,11 @@
)


@pytest.fixture(scope="session")
def option(pytestconfig):
return pytestconfig.getoption("option")


@skip_for_grayskull()
@pytest.mark.parametrize("model_name", ["CompVis/stable-diffusion-v1-4"])
@pytest.mark.parametrize(
@@ -210,9 +215,11 @@ def test_cross_attention_256x256(device, model_name, N, C, H, W, index, has_enco
),
],
)
def test_cross_attention_512x512(device, model_name, N, C, H, W, index, has_encoder_hidden_states):
def test_cross_attention_512x512(device, model_name, N, C, H, W, index, has_encoder_hidden_states, option):
torch.manual_seed(0)

use_legacy_4096 = option == "legacy"

pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float32)
model = pipe.unet
model.eval()
@@ -258,6 +265,7 @@ def test_cross_attention_512x512(device, model_name, N, C, H, W, index, has_enco
ttnn_encoder_hidden_states,
attention_mask=None,
dim_head=W // 8,
use_legacy_4096=use_legacy_4096,
)

ttnn_output = ttnn.from_device(ttnn_output)
@@ -31,6 +31,11 @@
)


@pytest.fixture(scope="session")
def option(pytestconfig):
return pytestconfig.getoption("option")


def ttnn_to_torch(input):
input = ttnn.to_layout(input, ttnn.ROW_MAJOR_LAYOUT)
input = ttnn.from_device(input)
@@ -212,6 +217,7 @@ def test_cross_attn_up_block_2d_512x512(
prev_output_channel,
in_channels,
out_channels,
option,
):
# setup pytorch model
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
@@ -220,6 +226,7 @@ def test_cross_attn_up_block_2d_512x512(
config = unet.config
state_dict = unet.state_dict()
unet_upblock = pipe.unet.up_blocks[index]
use_legacy_4096 = option == "legacy"

parameters = preprocess_model_parameters(
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
@@ -352,6 +359,7 @@ def test_cross_attn_up_block_2d_512x512(
attn_num_head_channels=attn_num_head_channels,
attention_mask=attention_mask,
cross_attention_dim=cross_attention_dim,
use_legacy_4096=use_legacy_4096,
)

op = ttnn_to_torch(op)
@@ -361,4 +369,4 @@ def test_cross_attn_up_block_2d_512x512(
op = torch.reshape(op, (N, H * 2, W * 2, Cout))
op = op.permute(0, 3, 1, 2)

assert_with_pcc(torch_output, op, 0.92)
assert_with_pcc(torch_output, op, 0.91)
@@ -154,9 +154,8 @@ def test_resnet_block_2d_512x512(
else:
parameters = parameters.mid_block.resnets[index2]
resnet = pipe.unet.mid_block.resnets[index2]
torch.save(resnet, "resnet.pt")
torch.save(config, "config.pt")

# torch.save(resnet, "resnet.pt")
# torch.save(config, "config.pt")
else:
resnet = torch.load("resnet.pt")
config = torch.load("config.pt")
@@ -23,6 +23,11 @@
)


@pytest.fixture(scope="session")
def option(pytestconfig):
return pytestconfig.getoption("option")


@skip_for_grayskull()
@pytest.mark.parametrize(
"input_shape, index1, index2, attention_head_dim, block",
@@ -171,14 +176,15 @@ def test_transformer_2d_model_256x256(
@pytest.mark.parametrize("model_name", ["CompVis/stable-diffusion-v1-4"])
@pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True)
def test_transformer_2d_model_512x512(
input_shape, index1, index2, block, attention_head_dim, model_name, device, reset_seeds
input_shape, index1, index2, block, attention_head_dim, model_name, device, reset_seeds, option
):
torch.manual_seed(0)
encoder_hidden_states = [1, 2, 77, 768]
timestep = (None,)
class_labels = (None,)
cross_attention_kwargs = (None,)
return_dict = True
use_legacy_4096 = option == "legacy"

num_layers = 1
num_attention_heads = 8
@@ -270,6 +276,7 @@ def test_transformer_2d_model_512x512(
norm_type=norm_type,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_legacy_4096=use_legacy_4096,
)

output = post_process_output(
(remaining changed files not loaded)