diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
index c78ee9e29..6056f9c6c 100644
--- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py
+++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -66,10 +66,12 @@ class BenchmarkLlama3_1_8B(BaseBenchmarkTest):
     def setUp(self):
         super().setUp()
         # TODO: add numpy files to Azure and download from it
-        self.artifacts_dir = Path("/shark-dev/data/llama3.1/weights/8b")
-        self.artifacts_dir_2048 = Path("/shark-dev/8b")
-        self.irpa_path = self.artifacts_dir / "fp16/llama3.1_8b_fp16.irpa"
-        self.irpa_path_fp8 = self.artifacts_dir / "f8/llama3.1_8b_fp8.irpa"
+        self.artifacts_dir = Path("/shark-dev/8b")
+        self.weights_dir = self.artifacts_dir / "instruct/weights"
+        self.irpa_path = self.weights_dir / "llama3.1_8b_instruct_fp16.irpa"
+        self.irpa_path_fp8 = (
+            self.artifacts_dir / "fp8/native_fp8_e4m3fnuz_llama3_8b.irpa"
+        )
         self.tensor_parallelism_size = 1
         self.dir_path_8b = self.dir_path / "llama-8b"
         self.temp_dir_8b = Path(self.dir_path_8b)
@@ -83,15 +85,6 @@ def setUp(self):
             tensor_parallelism_size=self.tensor_parallelism_size,
             block_seq_stride=32,
         )
-        self.llama8b_fp8_decomposed_artifacts = ExportArtifacts(
-            irpa_path=str(self.irpa_path_fp8),
-            batch_size=4,
-            iree_hip_target="gfx942",
-            iree_hal_target_device="hip",
-            attention_kernel="decomposed",
-            tensor_parallelism_size=self.tensor_parallelism_size,
-            block_seq_stride=32,
-        )
         self.llama8b_fp8_torch_sdpa_artifacts = ExportArtifacts(
             irpa_path=str(self.irpa_path_fp8),
             batch_size=4,
@@ -104,16 +97,16 @@ def setUp(self):
             attention_dtype="float8_e4m3fnuz",
         )
         self.prefill_args_bs4_128_stride_32_f16 = (
-            self.artifacts_dir / "prefill_args_bs4_128_stride_32"
+            self.artifacts_dir / "prefill_args_bs4_128_stride_32_tp1"
         )
         self.decode_args_bs4_128_stride_32_f16 = (
-            self.artifacts_dir / "decode_args_bs4_128_stride_32"
+            self.artifacts_dir / "decode_args_bs4_128_stride_32_tp1"
         )
         self.prefill_args_bs4_2048_stride_32_f16 = (
-            self.artifacts_dir_2048 / "prefill_args_bs4_2048_stride_32"
+            self.artifacts_dir / "prefill_args_bs4_2048_stride_32"
         )
         self.decode_args_bs4_2048_stride_32_f16 = (
-            self.artifacts_dir_2048 / "decode_args_bs4_2048_stride_32"
+            self.artifacts_dir / "decode_args_bs4_2048_stride_32"
         )
         self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
         self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
@@ -169,39 +162,6 @@ def setUp(self):
             "--benchmark_repetitions=3",
         ]
 
-    def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_128(self):
-        output_file_name = self.dir_path_8b / "f16_torch_prefill_128"
-        output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
-            suffix=".mlir", prefix=output_file_name
-        )
-        output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file(
-            suffix=".json", prefix=output_file_name
-        )
-        output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file(
-            suffix=".vmfb", prefix=output_file_name
-        )
-        export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir(
-            mlir_path=output_mlir,
-            json_path=output_json,
-            skip_decode=True,
-        )
-        self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb(
-            mlir_path=str(output_mlir),
-            vmfb_path=output_vmfb,
-            hal_dump_path=output_file_name,
-            cwd=self.repo_root,
-            args=self.compile_args,
-        )
-        # benchmark prefill
-        self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
-            hip_device_id=self.iree_device,
-            vmfb_name=output_vmfb,
-            irpa_path=self.irpa_path,
-            args=self.iree_run_prefill_nondecomposed_args_fp16,
-            cwd=self.repo_root,
-        )
-
     @skipif_run_quick_llama_test
     def testBenchmark8B_f16_Non_Decomposed_Input_Len_128(self):
         output_file_name = self.dir_path_8b / "f16_torch_128"
         output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
@@ -283,7 +243,7 @@ def testBenchmark8B_f16_Non_Decomposed_Input_Len_2048(self):
 
     @pytest.mark.xfail(
         reason="Benchmark inputs not configured yet.",
-        strict=False,
+        strict=True,
         raises=IreeBenchmarkException,
     )
     def testBenchmark8B_fp8_Non_Decomposed(self):
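
For context on the strict=True flip in the final hunk: with strict=False, an xfail-marked test that unexpectedly passes is only reported as XPASS and the run still succeeds, whereas strict=True turns an unexpected pass into a failure, so the marker must be removed once the fp8 benchmark inputs are configured. Below is a minimal, self-contained sketch of that behavior; the IreeBenchmarkException class here is a hypothetical stand-in defined only so the snippet runs on its own, not the exception type from the real test utilities.

    import pytest


    class IreeBenchmarkException(RuntimeError):
        """Hypothetical stand-in for the exception the real benchmark helpers raise."""


    @pytest.mark.xfail(
        reason="Benchmark inputs not configured yet.",
        strict=True,
        raises=IreeBenchmarkException,
    )
    def test_fp8_benchmark_placeholder():
        # Expected today: this raise is reported as XFAIL and the suite stays
        # green. If the benchmark ever starts passing, strict=True turns the
        # XPASS into a hard failure, so the stale marker gets noticed and removed.
        raise IreeBenchmarkException("fp8 benchmark inputs missing")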