Commit 8ab20e0

Merge branch 'perplexity-vmfb' of https://github.com/nod-ai/SHARK-Platform into perplexity-vmfb

archana-ramalingam committed Oct 26, 2024
2 parents d1ed9a2 + da04fd1, commit 8ab20e0
Showing 22 changed files with 592 additions and 195 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64-libshortfin.yml

@@ -56,7 +56,7 @@ jobs:
         repository: iree-org/iree
         path: ${{ env.IREE_REPO_DIR }}
         submodules: false
-        ref: candidate-20240904.1006
+        ref: candidate-20241025.1058

     - name: Initialize IREE submodules
       working-directory: ${{ env.IREE_REPO_DIR }}
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64_asan-libshortfin.yml

@@ -109,7 +109,7 @@ jobs:
         repository: iree-org/iree
         path: ${{ env.IREE_SOURCE_DIR }}
         submodules: false
-        ref: candidate-20240904.1006
+        ref: candidate-20241025.1058

     - name: Initialize IREE submodules
       working-directory: ${{ env.IREE_SOURCE_DIR }}
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64_nogil-libshortfin.yml

@@ -57,7 +57,7 @@ jobs:
         repository: iree-org/iree
         path: ${{ env.IREE_REPO_DIR }}
         submodules: false
-        ref: candidate-20240904.1006
+        ref: candidate-20241025.1058

     - name: Initialize IREE submodules
       working-directory: ${{ env.IREE_REPO_DIR }}
2 changes: 1 addition & 1 deletion .github/workflows/ci_windows_x64-libshortfin.yml

@@ -54,7 +54,7 @@ jobs:
         repository: iree-org/iree
         path: ${{ env.IREE_REPO_DIR }}
         submodules: false
-        ref: candidate-20240904.1006
+        ref: candidate-20241025.1058

     - name: Initialize IREE submodules
       working-directory: ${{ env.IREE_REPO_DIR }}
37 changes: 16 additions & 21 deletions sharktank/sharktank/examples/export_paged_llm_v1.py

@@ -59,31 +59,26 @@ def main():
         default="decomposed",
         choices=["decomposed", "torch_sdpa"],
     )
-    parser.add_argument(
-        "--tensor-parallelism-size",
-        type=int,
-        default=1,
-        help="How many devices are involved for tensor parallel sharding.",
-    )

     args = cli.parse(parser)
     dataset_type = cli.get_input_data_files(args)
     dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
     dataset = cli.get_input_dataset(args)

     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
-    llama_config = LlamaModelConfig(hp)
-    if args.tensor_parallelism_size > 1:
-        dataset.root_theta = shard_theta(dataset.root_theta, llama_config)
-    llama_config.use_hf = False
-    llama_config.static_tables = False  # Rely on the compiler for hoisting tables.
-    llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
-    llama_config.attention_kernel = args.attention_kernel
-
-    # This is a bit gross and should be changed in the future. Best Idea I had so far.
-    attn_q_weight = dataset.root_theta.tensor("blk")["0"]["attn_q"]["weight"]
-    if isinstance(attn_q_weight, SplitPrimitiveTensor):
-        llama_config.tensor_parallelism_size = attn_q_weight.shard_count
+    tensor_parallelism_size = (
+        dataset.properties["tensor_parallelism_size"]
+        if "tensor_parallelism_size" in dataset.properties
+        else 1
+    )
+    llama_config = LlamaModelConfig(
+        hp,
+        tensor_parallelism_size=tensor_parallelism_size,
+        use_hf=False,
+        static_tables=False,  # Rely on the compiler for hoisting tables.
+        kv_cache_type="direct" if args.bs == [1] else "paged",
+        attention_kernel=args.attention_kernel,
+    )

     if llama_config.hp.expert_count:
         if llama_config.hp.model_arch == "grok":

@@ -127,7 +122,7 @@ def setup_cache(model, shard_count):
         # Direct cache dimensions:
         # 2 * transformer_block_count of...
         # [bs, seq_length, attn_head_count, attn_head_dim]
-        dynamic_shapes = (2 * hp.block_count) * [{}]
+        dynamic_shapes = [None]
     else:
         raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")

@@ -148,7 +143,7 @@ def setup_cache(model, shard_count):
        for i in range(llama_config.tensor_parallelism_size):
            arg_affinities[i] = DeviceAffinity(str(i))

-        return unpacked, shard_dim, dynamic_shapes, arg_affinities
+        return torch.stack(unpacked), shard_dim, dynamic_shapes, arg_affinities

     def repack_cache(cache, shard_dim):
         return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]

@@ -189,7 +184,7 @@ def generate_batch_prefill(bs: int):
         arg_device=arg_affinities,
     )
     def _(model, tokens, seq_lens, seq_block_ids, cs):
-        cache_tensors = cs
+        cache_tensors = torch.unbind(cs)

         sl = tokens.shape[1]
         input_mask = model.input_mask(seq_lens, sl)
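
A minimal sketch (illustrative, not the repository's code) of the stack/unbind round trip this change relies on: the Python list of per-block cache slabs is stacked into a single tensor so the cache can cross the export boundary as one argument, then unbound back into per-block views inside the exported function. The shapes below are made up for the example.

    import torch

    block_count = 4
    # 2 slabs (K and V) per transformer block, each [bs, seq_len, heads, head_dim].
    cache_list = [torch.zeros(2, 16, 8, 32) for _ in range(2 * block_count)]

    stacked = torch.stack(cache_list)  # one tensor, shape [8, 2, 16, 8, 32]
    unpacked = torch.unbind(stacked)   # tuple of 8 views along dim 0, no copy

    assert len(unpacked) == len(cache_list)
    assert all(torch.equal(a, b) for a, b in zip(unpacked, cache_list))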
19 changes: 13 additions & 6 deletions sharktank/sharktank/examples/sharding/shard_llm_dataset.py

@@ -10,7 +10,8 @@
 weights of an LLM by converting the RHS of all eligible layers to a sharded
 form.
 """
-from ...transforms.dataset import MmtRHSShardingTransform
+from ...models.llama.sharding import shard_theta
+from ...layers import LlamaHParams, LlamaModelConfig
 from ...types import *


@@ -21,16 +22,22 @@ def main(raw_args=None):
     cli.add_input_dataset_options(parser)
     cli.add_output_dataset_options(parser)
     parser.add_argument(
-        "--num-shards", type=int, required=True, help="Number of shards to split"
+        "--tensor-parallelism-size",
+        type=int,
+        required=True,
+        help="Number of shards to split",
     )
     args = cli.parse(parser, args=raw_args)
     dataset = cli.get_input_dataset(args)

-    tr = MmtRHSShardingTransform(
-        r"^blk\.[0-9]+\.(attn_k|attn_q|attn_v|ffn_gate|ffn_up|ffn_down)\.weight$",
-        num_shards=8,
+    hp = LlamaHParams.from_gguf_props(dataset.properties)
+    llama_config = LlamaModelConfig(
+        hp, tensor_parallelism_size=args.tensor_parallelism_size
     )
-    dataset.transform(tr)
+    sharded_theta = shard_theta(dataset.root_theta, llama_config)
+    sharded_theta.rename_tensors_to_paths()
+    dataset.root_theta = sharded_theta
+    dataset.properties["tensor_parallelism_size"] = args.tensor_parallelism_size
     dataset.save(args.output_irpa_file, io_report_callback=print)
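
A hedged sketch of the property handshake this change introduces: the sharding script stamps tensor_parallelism_size into the dataset properties, and export_paged_llm_v1.py (above) later reads it back, defaulting to 1 for unsharded datasets. The `properties` dict below stands in for Dataset.properties.

    # Writer side (shard_llm_dataset.py):
    properties = {"general.architecture": "llm"}
    properties["tensor_parallelism_size"] = 8

    # Reader side (export_paged_llm_v1.py):
    tensor_parallelism_size = (
        properties["tensor_parallelism_size"]
        if "tensor_parallelism_size" in properties
        else 1
    )
    assert tensor_parallelism_size == 8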
24 changes: 2 additions & 22 deletions sharktank/sharktank/ops/qconv_impls.py

@@ -12,6 +12,7 @@
 import warnings

 import torch
+import torch.nn.functional as F

 from sharktank import kernels

@@ -119,7 +120,7 @@ def qconv2d_tensor_scaled(
     padding = _expand_int_to_2_tuple(padding)
     dilation = _expand_int_to_2_tuple(dilation)
     extended_padding_list = [item for item in padding for _ in range(2)]
-    padded_input = _pad_last_2d(input_qs, extended_padding_list)
+    padded_input = F.pad(input_qs, pad=extended_padding_list)
     y_qs = _invoke_conv2d_kernel(
         padded_input,
         weight_qs,

@@ -258,27 +259,6 @@ def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dtype):
     return output


-def _pad_last_2d(input_tensor, pad_width):
-    # pad_width should be in the format [pad_left, pad_right, pad_top, pad_bottom]
-    pad_left, pad_right, pad_top, pad_bottom = pad_width
-    batch_size, channels, height, width = input_tensor.shape
-
-    # Create a new tensor with the desired padded size filled with zeros
-    padded_height = height + pad_top + pad_bottom
-    padded_width = width + pad_left + pad_right
-    padded_tensor = torch.zeros(
-        (batch_size, channels, padded_height, padded_width),
-        dtype=input_tensor.dtype,
-        device=input_tensor.device,
-    )
-
-    # Copy the values from the input tensor to the appropriate location in the padded tensor
-    padded_tensor[
-        :, :, pad_top : pad_top + height, pad_left : pad_left + width
-    ] = input_tensor
-    return padded_tensor
-
-
 def _flatten_input_scale_offset_channels(d, m):
     """Flattens either a 4d or 0d scale/offset as [N, C, H, W] to 1D.
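
A quick equivalence check (illustrative, not from the repository): torch.nn.functional.pad with a 4-element pad list zero-pads the last two dims in [left, right, top, bottom] order, which is exactly what the removed _pad_last_2d helper implemented by hand with torch.zeros and a slice copy, so the swap preserves behavior.

    import torch
    import torch.nn.functional as F

    x = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)  # NCHW
    padding = [1, 2]
    extended = [item for item in padding for _ in range(2)]  # [1, 1, 2, 2]

    padded = F.pad(x, pad=extended)  # zero padding is the default mode
    # Width grows by 1 on each side, height by 2 on each side.
    assert padded.shape == (1, 2, 3 + 2 + 2, 4 + 1 + 1)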
7 changes: 4 additions & 3 deletions sharktank/tests/kernels/conv_2d_nchw_fchw_test.py

@@ -12,10 +12,10 @@
 from parameterized import parameterized

 import torch
+import torch.nn.functional as F

 from iree.turbine import aot
 from sharktank import kernels
-from sharktank.ops.qconv_impls import _pad_last_2d


 class conv_2d_nchw_fchw_test(unittest.TestCase):

@@ -36,7 +36,8 @@ def testBS32(self, input_dtype, output_dtype_name, atol, rtol):
         inputs = (torch.rand([2, 4, 64, 64]) * 64).to(input_dtype)
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(inputs, extended_list)
+        inputs_pad = F.pad(inputs, pad=extended_list)
+
         weights = (torch.rand([8, 4, 3, 3]) * 64).to(input_dtype)
         bias = (torch.rand([8]) * 64).to(dtype=output_dtype)
         result = kernels.conv_2d_nchw_fchw(

@@ -68,7 +69,7 @@ def forward(self, a, b, c):
         inputs = torch.rand([2, 320, 64, 64]) * 64
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(inputs, extended_list)
+        inputs_pad = F.pad(inputs, pad=extended_list)
         ep = torch.export.export(
             mod,
             args=(
6 changes: 3 additions & 3 deletions sharktank/tests/kernels/pooling_nchw_sum_test.py

@@ -12,10 +12,10 @@
 from parameterized import parameterized

 import torch
+import torch.nn.functional as F

 from iree.turbine import aot
 from sharktank import kernels
-from sharktank.ops.qconv_impls import _pad_last_2d


 class pooling_nchw_sum_test(unittest.TestCase):

@@ -34,7 +34,7 @@ def testBS32(self, atol, rtol):
         a = (torch.randint(0, 100, (2, 1, 128, 128))).to(torch.float32)
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(a, extended_list)
+        inputs_pad = F.pad(a, pad=extended_list)
         weight_shape = [3, 3]
         stride = [1, 1]
         dilations = [1, 1]

@@ -62,7 +62,7 @@ def forward(self, a):
         inputs = torch.rand([2, 1, 128, 128]) * 64
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(inputs, extended_list)
+        inputs_pad = F.pad(inputs, pad=extended_list)
         ep = torch.export.export(
             mod,
             args=((inputs_pad).to(dtype),),
36 changes: 23 additions & 13 deletions sharktank/tests/transforms/dataset_transforms_test.py

@@ -19,18 +19,28 @@
 from sharktank.utils.testing import MainRunnerTestBase


-class MmtRHSShardingTransformTest(MainRunnerTestBase):
-    def testPrimitive(self):
+class DatasetShardingTransformTest(MainRunnerTestBase):
+    def testShardLlmDataset(self):
         orig_pts = [
             DefaultPrimitiveTensor(
                 name="blk.1.attn_k.weight", data=torch.randn([32, 128])
             ),
             DefaultPrimitiveTensor(
                 name="blk.2.attn_q.weight", data=torch.randn([48, 64])
             ),
-            DefaultPrimitiveTensor(name="other", data=torch.randn([2, 2])),
         ]
-        ds_orig = Dataset({}, Theta(orig_pts))
+        ds_orig = Dataset(
+            {
+                "general.architecture": "llm",
+                "llm.attention.head_count": 1,
+                "llm.context_length": 2,
+                "llm.embedding_length": 3,
+                "llm.block_count": 4,
+                "llm.feed_forward_length": 5,
+                "llm.attention.layer_norm_rms_epsilon": 0.1,
+            },
+            Theta(orig_pts),
+        )
         input_path = self.save_dataset(ds_orig, "input")
         output_path = self.get_irpa_path("output")
         from sharktank.examples.sharding import shard_llm_dataset

@@ -41,38 +51,38 @@ def testPrimitive(self):
             input_path,
             "--output-irpa-file",
             output_path,
-            "--num-shards",
+            "--tensor-parallelism-size",
             8,
         )
         ds_tran = Dataset.load(output_path, mmap=False)

+        ds_tran.properties["tensor_parallelism_size"] = 8
+
         # Verify.
         flat_sts = ds_tran.root_theta.flatten()
-        self.assertEqual(3, len(flat_sts))
+        self.assertEqual(2, len(flat_sts))
         st_1 = flat_sts["blk.1.attn_k.weight"]
         st_2 = flat_sts["blk.2.attn_q.weight"]
-        pt_3 = flat_sts["other"]
         self.assertIsInstance(st_1, SplitPrimitiveTensor)
         self.assertIsInstance(st_2, SplitPrimitiveTensor)
-        self.assertIsInstance(pt_3, DefaultPrimitiveTensor)
         self.assertListEqual(st_1.shape, [32, 128])
         self.assertListEqual(st_2.shape, [48, 64])

         # Verify component shapes for st_1.
         self.assertEqual(8, len(st_1.shards))
-        self.assertTrue(all(pt.shape == [32, 16] for pt in st_1.shards))
+        self.assertTrue(all(pt.shape == [4, 128] for pt in st_1.shards))
         self.assertTrue(
-            all(list(pt.as_torch().shape) == [32, 16] for pt in st_1.shards)
+            all(list(pt.as_torch().shape) == [4, 128] for pt in st_1.shards)
         )

         # Verify component shapes for st_2.
         self.assertEqual(8, len(st_2.shards))
-        self.assertTrue(all(pt.shape == [48, 8] for pt in st_2.shards))
-        self.assertTrue(all(list(pt.as_torch().shape) == [48, 8] for pt in st_2.shards))
+        self.assertTrue(all(pt.shape == [6, 64] for pt in st_2.shards))
+        self.assertTrue(all(list(pt.as_torch().shape) == [6, 64] for pt in st_2.shards))

         # Verify contents for one shard for sanity.
         new_t = st_1.shards[0].as_torch()
-        torch.testing.assert_close(new_t, orig_pts[0].as_torch().split(16, dim=1)[0])
+        torch.testing.assert_close(new_t, orig_pts[0].as_torch().split(4, dim=0)[0])


 if __name__ == "__main__":
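
An illustrative check (plain torch, no sharktank types) of why the expected shard shapes changed: the old MmtRHSShardingTransform split each weight along dim 1, so a [32, 128] tensor produced 8 shards of [32, 16], while shard_theta splits this weight along dim 0, producing 8 shards of [4, 128], which matches the updated asserts and the split(4, dim=0) comparison above.

    import torch

    w = torch.randn(32, 128)

    old_shards = w.split(128 // 8, dim=1)  # 8 shards of [32, 16] (old expectation)
    new_shards = w.split(32 // 8, dim=0)   # 8 shards of [4, 128] (new expectation)

    assert all(s.shape == (32, 16) for s in old_shards)
    assert all(s.shape == (4, 128) for s in new_shards)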
2 changes: 1 addition & 1 deletion shortfin/CMakeLists.txt

@@ -183,7 +183,7 @@ elseif (SHORTFIN_BUNDLE_DEPS)
   FetchContent_Declare(
     shortfin_iree
     GIT_REPOSITORY https://github.com/iree-org/iree.git
-    GIT_TAG candidate-20240904.1006
+    GIT_TAG candidate-20241025.1058
     GIT_SUBMODULES ${IREE_SUBMODULES}
     GIT_SHALLOW TRUE
     SYSTEM