Added new instances for 405B decoding (pytorch#2936)
Summary: Pull Request resolved: pytorch#2936

Reviewed By: jwfromm

Differential Revision: D60670405
zjing14 authored and facebook-github-bot committed Aug 6, 2024
1 parent 6607072 commit 2ad32dc
Showing 5 changed files with 97 additions and 4 deletions.
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (3 additions & 3 deletions)
@@ -45,9 +45,9 @@ def bench_with_rotating_buffer(self, fn, args):
         import copy
         import pickle

-        # torch.cuda.get_device_properties does not have L2 cache size,
-        # so hard code an overapproximation of L2 cache size to ensure L2 cache flush
-        total_buffer_size = 10000 * 1024 * 1024
+        # torch.cuda.get_device_properties does not have L2/L3 cache size,
+        # so hard code an overapproximation of L2/L3 cache size to ensure L2/L3 cache flush
+        total_buffer_size = 512 * 1024 * 1024

         # Use pickle to serialize model input to estimate total sizes of input
         input_sizes = len(pickle.dumps(args))
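
For reference, the rotating-buffer benchmarking idea that this constant feeds can be sketched in a few lines. This is a minimal illustration, not the actual quantize_ops.py implementation; the function name, the copy cap, and the event-timing loop are assumptions:

    import copy
    import pickle

    import torch

    def bench_with_rotating_buffer_sketch(fn, args, total_buffer_size=512 * 1024 * 1024):
        # Estimate the footprint of one argument set by serializing it.
        input_size = max(1, len(pickle.dumps(args)))
        # Keep enough independent copies of the inputs that consecutive calls
        # touch different memory, so stale inputs are evicted from L2/L3.
        num_copies = min(max(1, total_buffer_size // input_size), 100)  # cap is illustrative
        buffers = [copy.deepcopy(args) for _ in range(num_copies)]
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for rotated_args in buffers:
            fn(*rotated_args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / num_copies  # average ms per call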
@@ -57,7 +57,8 @@ static const std::unordered_map<
         {{1, 13312, 6656},
          fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
         {{1, 13312, 16384},
-         fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
+         //fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
+         fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
         {{1, 16384, 6656},
          fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
         {{1, 16384, 16384},
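For context, the table above maps exact (M, N, K) problem shapes to pre-tuned kernel instances; this change swaps the (1, 13312, 16384) entry to the new interwave_v2 instance with a deeper K tile (512 vs. 256). A minimal Python sketch of that dispatch pattern, with hypothetical names and a truncated table:

    # Hypothetical mirror of the C++ unordered_map: exact shape -> tuned instance.
    KERNEL_TABLE = {
        (1, 13312, 6656): "fp8_rowwise_64x16x16x256_..._intrawave_v1",
        (1, 13312, 16384): "fp8_rowwise_64x16x16x512_..._interwave_v2",  # updated entry
        (1, 16384, 6656): "fp8_rowwise_64x16x16x256_..._intrawave_v1",
    }

    def pick_kernel(m: int, n: int, k: int) -> str:
        # An exact-shape hit selects the tuned instance; unlisted shapes fall
        # back to whatever default heuristic the dispatcher uses.
        return KERNEL_TABLE.get((m, n, k), "default_heuristic_instance")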
@@ -0,0 +1,38 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor Y) {
  // The smallest kernel we have available. Works well for memory bound shapes.
  using DeviceGemmInstance = DeviceGemmHelper<
      128,
      16,
      32,
      512,
      16,
      16,
      1,
      1,
      S<8, 16, 1>,
      S<8, 16, 1>,
      S<1, 16, 1, 8>,
      S<4, 4, 1>,
      1,
      1,
      ck::BlockGemmPipelineScheduler::Interwave,
      ck::BlockGemmPipelineVersion::v2>;
  // Run kernel instance.
  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
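
In terms of what these instances compute, f8f8bf16_rowwise multiplies fp8 operands and applies per-row scales to produce a bf16 result. A reference sketch in PyTorch, assuming the usual XQ = (M, K), WQ = (N, K) layout with one scale per row of each operand (an illustration of the math, not the CK kernel):

    import torch

    def f8f8bf16_rowwise_reference(xq, wq, x_scale, w_scale):
        # xq: (M, K) fp8, wq: (N, K) fp8,
        # x_scale: (M,), w_scale: (N,) per-row dequantization scales.
        # Compute in float for clarity; the CK kernel fuses the scaling.
        y = xq.float() @ wq.float().t()
        y = y * x_scale.float()[:, None] * w_scale.float()[None, :]
        return y.to(torch.bfloat16)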
@@ -0,0 +1,38 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor Y) {
  // The smallest kernel we have available. Works well for memory bound shapes.
  using DeviceGemmInstance = DeviceGemmHelper<
      64,
      16,
      16,
      512,
      16,
      16,
      1,
      1,
      S<8, 8, 1>,
      S<8, 8, 1>,
      S<1, 16, 1, 4>,
      S<4, 4, 1>,
      1,
      1,
      ck::BlockGemmPipelineScheduler::Interwave,
      ck::BlockGemmPipelineVersion::v2>;
  // Run kernel instance.
  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
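
The long instance names encode the DeviceGemmHelper template arguments in order: thread-block size, the M/N/K block tile, wave tile, wave mapping, A/B/C block-transfer shapes, C-shuffle parameters, scheduler, and pipeline version. A small sketch that unpacks the fields; the field labels are inferred from how the name lines up with the template arguments above:

    def parse_instance_name(name: str) -> dict:
        fields = name.removeprefix("fp8_rowwise_").split("_")
        block, m, n, k = (int(v) for v in fields[0].split("x"))
        return {
            "block_size": block,            # 64 -> first template argument
            "block_tile_mnk": (m, n, k),    # 16, 16, 512
            "wave_tile": fields[1],         # "16x16"
            "wave_map": fields[2],          # "1x1"
            "a_block_transfer": fields[3],  # "8x8x1"    -> S<8, 8, 1>
            "b_block_transfer": fields[4],  # "8x8x1"    -> S<8, 8, 1>
            "c_block_transfer": fields[5],  # "1x16x1x4" -> S<1, 16, 1, 4>
            "c_shuffle": fields[6],         # "4x4x1"    -> S<4, 4, 1>
            "c_shuffle_mx_nx": fields[7],   # "1x1"      -> 1, 1
            "scheduler": fields[8],         # "interwave" or "intrawave"
            "pipeline_version": fields[9],  # "v2"
        }

    print(parse_instance_name(
        "fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2"))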
@@ -28,6 +28,14 @@ fp8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2(
     at::Tensor w_scale,
     at::Tensor Y);

+at::Tensor
+fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);
+
 // Alternate tiny kernel that seems to do well when M and K are all small.
 at::Tensor
 fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
@@ -64,6 +72,14 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
     at::Tensor w_scale,
     at::Tensor Y);

+at::Tensor
+fp8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);
+
 // Alternate tiny kernel that seems to do well when M and N are all small.
 at::Tensor
 fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2(
