diff --git a/tests/ttnn/unit_tests/operations/test_max.py b/tests/ttnn/unit_tests/operations/test_max.py
index 2a9c6f91ed9..0a851877853 100644
--- a/tests/ttnn/unit_tests/operations/test_max.py
+++ b/tests/ttnn/unit_tests/operations/test_max.py
@@ -90,3 +90,29 @@ def test_max_global(device, batch_size, h, w):
     output_tensor = output_tensor[0, 0, 0]
 
     assert_with_pcc(torch_output_tensor, output_tensor)
+
+
+@pytest.mark.parametrize(
+    "input_shape_and_dim",
+    [
+        ((32, 32, 32, 64), -4),
+        ((2, 32, 32, 64), -3),
+        ((32, 32, 64), -3),
+    ],
+)
+@pytest.mark.parametrize("keepdim", [True, False])
+def test_max_dim(device, input_shape_and_dim, keepdim):
+    input_shape, max_dim = input_shape_and_dim
+
+    torch_input_tensor = torch_random(input_shape, -100, 100, dtype=torch.bfloat16)
+    torch_output_tensor, _ = torch.max(torch_input_tensor, dim=max_dim, keepdim=keepdim)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+
+    output_tensor = ttnn.max(input_tensor, dim=max_dim, keepdim=keepdim)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor)
diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
index ca77f783e11..1f971e25b52 100644
--- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
@@ -22,10 +22,6 @@ static Tensor reduce_impl(
     float scalar,
     bool reshape) {
     using ttnn::operations::experimental::auto_format::AutoFormat;
-    if (not keepdim) {
-        TT_THROW("keepdim=False is not supported");
-    }
-
     auto input_shape = input_tensor_arg.get_shape();
     auto rank = input_shape.size();
     auto memory_config = memory_config_arg.value_or(input_tensor_arg.memory_config());
@@ -58,41 +54,39 @@ static Tensor reduce_impl(
            rank);
    }

-    if (dim.size() == 1 && rank == 4) {
-        if (dim[0] == rank - 3) {
-            auto out_shape = input_tensor_arg.get_legacy_shape();
-            out_shape[1] = 1;
-
-            Tensor output = ttnn::transpose(input_tensor_arg, 1, -2, memory_config);
-            output = reduce_impl(output, 2, keepdim, memory_config, compute_kernel_config, scalar, false);
-            output = ttnn::transpose(output, 1, -2, memory_config);
-            return AutoFormat::format_output_tensor(output, out_shape, input_tensor_arg.device(), Layout::TILE);
-        } else if (dim[0] == 0) {
-            auto out_shape = input_tensor_arg.get_legacy_shape();
-            out_shape[0] = 1;
-
-            Tensor output = ttnn::transpose(input_tensor_arg, 0, -2, memory_config);
-            output = reduce_impl(output, 2, keepdim, memory_config, compute_kernel_config, scalar, false);
-            output = ttnn::transpose(output, 0, -2, memory_config);
-            return AutoFormat::format_output_tensor(output, out_shape, input_tensor_arg.device(), Layout::TILE);
-        }
-    }
-
     std::sort(dim.begin(), dim.end());
 
     ttnn::SmallVector<uint32_t> output_shape;
-    ttnn::SmallVector<uint32_t> padded_output_shape;
     for (int axis = 0; axis < input_shape.size(); axis++) {
         if (std::find(dim.begin(), dim.end(), axis) != dim.end()) {
             if (keepdim) {
                 output_shape.push_back(1);
-                padded_output_shape.push_back(axis >= rank - 2 ? ttnn::TILE_SIZE : 1);
             }
         } else {
             // Get the shape for the output tensor
             output_shape.push_back(input_shape[axis]);
-            // Get the padded shape for the output tensor
-            padded_output_shape.push_back(input_shape.value[axis]);
+        }
+    }
+
+    if (dim.size() == 1 && (rank == 3 || rank == 4)) {
+        if (dim[0] == 1 && rank == 4) {
+            Tensor output = ttnn::transpose(input_tensor_arg, 1, -2, memory_config);
+            output = reduce_impl(
+                output, 2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
+            output = ttnn::transpose(output, 1, -2, memory_config);
+            if (reshape) {
+                output = ttnn::reshape(output, ttnn::Shape{output_shape});
+            }
+            return output;
+        } else if (dim[0] == 0) {
+            Tensor output = ttnn::transpose(input_tensor_arg, 0, -2, memory_config);
+            output = reduce_impl(
+                output, -2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
+            output = ttnn::transpose(output, 0, -2, memory_config);
+            if (reshape) {
+                output = ttnn::reshape(output, ttnn::Shape{output_shape});
+            }
+            return output;
         }
     }
 
@@ -199,7 +193,7 @@ static Tensor reduce_impl(
     }
 
     if (reshape) {
-        output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape, padded_output_shape});
+        output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape});
     }
 
     return output_tensor;