Skip to content

Commit

Permalink
Merge branch 'main' into where-bug
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms authored Dec 17, 2024
2 parents e58a2b8 + 0aed232 commit 815c418
Show file tree
Hide file tree
Showing 11 changed files with 727 additions and 0 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ exclude_patterns = [
'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
'onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py', # onnxscript code
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
'onnxscript/tools/function_unittest_producer.py', # FIXME
Expand Down
66 changes: 66 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,78 @@
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import basic_constant_propagation


def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5) -> None:
"""Display the (backward or forward) subgraph from a given value or node upto a certain depth."""
slice = []

def visit(node: ir.Node, depth):
if node in slice:
return
slice.append(node)
if depth < depth_limit:
if backward:
for inp in node.inputs:
if inp is not None and inp.producer() is not None:
visit(inp.producer(), depth + 1) # type: ignore[arg-type]
else:
for out in node.outputs:
for consumer, _ in out.uses():
visit(consumer, depth + 1)

if isinstance(x, ir.Node):
visit(x, 0)
elif isinstance(x, ir.Value) and x.producer() is not None:
visit(x.producer(), 0) # type: ignore[arg-type]
if slice:
graph = slice[0].graph
if graph:
# Display nodes in same order as in graph:
# Currently doesn't handle (control-flow) subgraphs
for node in graph:
if node in slice:
node.display()
else:
for node in reversed(slice):
node.display()


def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
node = value.producer()
if node is not None:
basic_constant_propagation([node])
return value.const_value


def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
"""Convenience wrapper to get (optional) numpy value from an optional IR Value.
This is intended for use in optimizations/rewriting. Note that this does not
yet handle the distinction between inputs with default values (values that are
both graph inputs and graph initializers), which should not be treated as a
constant, and true constant values. The caller should make the distinction, as
a value does not contain enough information to determine this. (TODO)
"""
if val is None:
return None
const_value = val.const_value
if const_value is not None:
try:
return const_value.numpy()
except FileNotFoundError:
# External data is not available.
return None
return None


def get_singleton_value(val: ir.Value | None):
"""Returns element of a single element tensor constant value, and None otherwise."""
np_val = get_numpy_value(val)
if np_val is not None and np_val.size == 1:
return np_val.item()
return None
3 changes: 3 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
253 changes: 253 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""
A one-layer SmolLM model test case.
This is an onnxscript version of the model.
"""

import numpy
from onnx.helper import make_tensor

import onnxscript.ir as ir
from onnxscript import script
from onnxscript.onnx_opset import opset18
from onnxscript.onnx_types import FLOAT, INT64


def make_model(
input_layernorm_weight_0,
post_attention_layernorm_weight0,
norm_weight,
head_weight,
self_attn_q_proj_weight0,
self_attn_k_proj_weight0,
self_attn_v_proj_weight0,
self_attn_o_proj_weight0,
mlp_gate_proj_weight0,
mlp_up_proj_weight0,
mlp_down_proj_weight0,
):
@script()
def main_graph(
input0: INT64[1, 10], input1: FLOAT[1, 10], input2: INT64[1, 10]
) -> (FLOAT[1, 10, 49152], FLOAT[1, 32, 10, 64], FLOAT[1, 32, 10, 64]):
model_layers_0_input_layernorm_weight = opset18.Constant(
value=input_layernorm_weight_0
)
model_layers_0_post_attention_layernorm_weight = opset18.Constant(
value=post_attention_layernorm_weight0
)
model_norm_weight = opset18.Constant(value=norm_weight)
lm_head_weight = opset18.Constant(value=head_weight)
model_layers_0_self_attn_q_proj_weight = opset18.Constant(
value=self_attn_q_proj_weight0
)
model_layers_0_self_attn_k_proj_weight = opset18.Constant(
value=self_attn_k_proj_weight0
)
model_layers_0_self_attn_v_proj_weight = opset18.Constant(
value=self_attn_v_proj_weight0
)
model_layers_0_self_attn_o_proj_weight = opset18.Constant(
value=self_attn_o_proj_weight0
)
model_layers_0_mlp_gate_proj_weight = opset18.Constant(value=mlp_gate_proj_weight0)
model_layers_0_mlp_up_proj_weight = opset18.Constant(value=mlp_up_proj_weight0)
model_layers_0_mlp_down_proj_weight = opset18.Constant(value=mlp_down_proj_weight0)

embedding = opset18.Gather(lm_head_weight, input0, axis=0)
minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38])
mask_10x10 = opset18.Trilu(minus_inf_10x10, 1)
slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10])
unsqueeze_2 = opset18.Unsqueeze(input1, 1)
unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2)
add = slice_5 + unsqueeze_3
eq = add == 0.0
slice_10 = slice_5
masked_fill = opset18.Where(eq, -3.4028235e38, slice_10)
val_179 = opset18.Transpose(masked_fill, perm=[2, 1, 0, 3])
slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3])
val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3])
slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3])
unsqueeze_6 = opset18.Unsqueeze(input2, 1)
_to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
view_1 = opset18.Constant(
value=make_tensor(
"value",
1,
dims=[1, 32, 1],
vals=[
1.0,
0.7498942017555237,
0.5623413324356079,
0.4216965138912201,
0.3162277638912201,
0.23713736236095428,
0.17782793939113617,
0.1333521455526352,
0.10000000149011612,
0.07498941570520401,
0.05623412877321243,
0.04216964915394783,
0.03162277862429619,
0.0237137358635664,
0.017782794311642647,
0.01333521492779255,
0.009999999776482582,
0.007498942315578461,
0.005623413249850273,
0.0042169648222625256,
0.003162277862429619,
0.0023713738191872835,
0.0017782794311642647,
0.0013335214462131262,
0.0010000000474974513,
0.0007498941849917173,
0.000562341301701963,
0.00042169648804701865,
0.0003162277862429619,
0.0002371373848291114,
0.00017782794020604342,
0.0001333521504420787,
],
)
)
view_2 = opset18.Reshape(_to_copy_1, [1, 1, 10], allowzero=0)
bmm = view_1 @ view_2
view_3 = opset18.Reshape(bmm, [1, 32, 10], allowzero=0)
transpose = opset18.Transpose(view_3, perm=[0, 2, 1])
cat = opset18.Concat(transpose, transpose, axis=-1)
cos = opset18.Cos(cat)
sin = opset18.Sin(cat)
pow_1 = embedding**2.0
mean = opset18.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0)
add_1 = mean + 1e-05
val_244 = opset18.Sqrt(add_1)
rsqrt = opset18.Reciprocal(val_244)
mul_3 = embedding * rsqrt
mul_4 = model_layers_0_input_layernorm_weight * mul_3
t = opset18.Transpose(model_layers_0_self_attn_q_proj_weight, perm=[1, 0])
view_5 = mul_4 @ t
t_1 = opset18.Transpose(model_layers_0_self_attn_k_proj_weight, perm=[1, 0])
view_7 = mul_4 @ t_1
t_2 = opset18.Transpose(model_layers_0_self_attn_v_proj_weight, perm=[1, 0])
view_9 = mul_4 @ t_2
view_10 = opset18.Reshape(view_5, [1, 10, 32, 64], allowzero=0)
transpose_1 = opset18.Transpose(view_10, perm=[0, 2, 1, 3])
view_11 = opset18.Reshape(view_7, [1, 10, 32, 64], allowzero=0)
transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3])
view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0)
transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
unsqueeze_7 = opset18.Unsqueeze(cos, 1)
unsqueeze_8 = opset18.Unsqueeze(sin, 1)
mul_5 = transpose_1 * unsqueeze_7
val_267 = opset18.Constant(value_ints=[1])
slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267)
val_277 = opset18.Constant(value_ints=[1])
slice_20 = opset18.Slice(transpose_1, [32], [9223372036854775807], [3], val_277)
neg = opset18.Neg(slice_20)
cat_1 = opset18.Concat(neg, slice_19, axis=-1)
mul_6 = cat_1 * unsqueeze_8
add_2 = mul_5 + mul_6
mul_7 = transpose_2 * unsqueeze_7
val_287 = opset18.Constant(value_ints=[1])
slice_21 = opset18.Slice(transpose_2, [0], [32], [3], val_287)
val_297 = opset18.Constant(value_ints=[1])
slice_22 = opset18.Slice(transpose_2, [32], [9223372036854775807], [3], val_297)
neg_1 = opset18.Neg(slice_22)
cat_2 = opset18.Concat(neg_1, slice_21, axis=-1)
mul_8 = cat_2 * unsqueeze_8
add_3 = mul_7 + mul_8
val_346 = opset18.Reshape(add_3, [-1, 10, 64], allowzero=0)
val_347 = opset18.Transpose(val_346, perm=[0, 2, 1])
val_349 = opset18.Reshape(val_347, [1, 32, 64, 10], allowzero=0)
val_351 = add_2 * [0.35355338]
val_353 = val_349 * [0.35355338]
val_354 = val_351 @ val_353
val_355 = val_354 + slice_scatter_1
val_356 = opset18.Softmax(val_355, axis=-1)
getitem = val_356 @ transpose_3
transpose_4 = opset18.Transpose(getitem, perm=[0, 2, 1, 3])
view_13 = opset18.Reshape(transpose_4, [1, 10, -1], allowzero=0)
t_3 = opset18.Transpose(model_layers_0_self_attn_o_proj_weight, perm=[1, 0])
view_15 = view_13 @ t_3
add_4 = embedding + view_15
pow_2 = add_4**2.0
mean_1 = opset18.ReduceMean(pow_2, [-1], keepdims=1, noop_with_empty_axes=0)
add_5 = mean_1 + 1e-05
val_379 = opset18.Sqrt(add_5)
rsqrt_1 = opset18.Reciprocal(val_379)
mul_9 = add_4 * rsqrt_1
mul_10 = model_layers_0_post_attention_layernorm_weight * mul_9
t_4 = opset18.Transpose(model_layers_0_mlp_gate_proj_weight, perm=[1, 0])
view_17 = mul_10 @ t_4
val_383 = opset18.Sigmoid(view_17)
silu = view_17 * val_383
t_5 = opset18.Transpose(model_layers_0_mlp_up_proj_weight, perm=[1, 0])
view_19 = mul_10 @ t_5
mul_11 = silu * view_19
t_6 = opset18.Transpose(model_layers_0_mlp_down_proj_weight, perm=[1, 0])
view_21 = mul_11 @ t_6
add_6 = add_4 + view_21
pow_3 = add_6**2.0
mean_2 = opset18.ReduceMean(pow_3, [-1], keepdims=1, noop_with_empty_axes=0)
add_7 = mean_2 + 1e-05
val_391 = opset18.Sqrt(add_7)
rsqrt_2 = opset18.Reciprocal(val_391)
mul_12 = add_6 * rsqrt_2
mul_13 = model_norm_weight * mul_12
t_7 = opset18.Transpose(lm_head_weight, perm=[1, 0])
view_23 = mul_13 @ t_7
_to_copy_12 = opset18.Identity(view_23)
return _to_copy_12, add_3, transpose_3

model = main_graph.to_model_proto()
return model


def make_model_with_random_weights():
input_layernorm_weight_0 = numpy.random.rand(2048).astype(numpy.float32)
post_attention_layernorm_weight0 = numpy.random.rand(2048).astype(numpy.float32)
norm_weight = numpy.random.rand(2048).astype(numpy.float32)
head_weight = numpy.random.rand(49152, 2048).astype(numpy.float32)
self_attn_q_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32)
self_attn_k_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32)
self_attn_v_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32)
self_attn_o_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32)
mlp_gate_proj_weight0 = numpy.random.rand(8192, 2048).astype(numpy.float32)
mlp_up_proj_weight0 = numpy.random.rand(8192, 2048).astype(numpy.float32)
mlp_down_proj_weight0 = numpy.random.rand(2048, 8192).astype(numpy.float32)
model = make_model(
input_layernorm_weight_0,
post_attention_layernorm_weight0,
norm_weight,
head_weight,
self_attn_q_proj_weight0,
self_attn_k_proj_weight0,
self_attn_v_proj_weight0,
self_attn_o_proj_weight0,
mlp_gate_proj_weight0,
mlp_up_proj_weight0,
mlp_down_proj_weight0,
)
return model


class _SmollmTestData:
def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
model_proto = make_model_with_random_weights()
model = ir.serde.deserialize_model(model_proto)
self._onnx_model = model
return self._onnx_model

def get_ort_inputs(self):
if not hasattr(self, "_ort_inputs"):
inputs = {
"input0": numpy.random.randint(0, 49152, (1, 10)).astype(numpy.int64),
"input1": numpy.ones((1, 10), dtype=numpy.float32),
"input2": numpy.arange(10, dtype=numpy.int64).reshape(1, 10),
}
self._ort_inputs = inputs
return self._ort_inputs
Loading

0 comments on commit 815c418

Please sign in to comment.