Skip to content

Commit

Permalink
Run lint
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Dec 20, 2024
1 parent 90f0b7b commit 1fdc19b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
return None
return default


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = _get_input(node, 0)
Expand All @@ -326,6 +327,7 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return op.Identity(input)
return None


@register("Cast")
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = _get_input(node, 0)
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ def is_singleton_value(
assert rtol is not None
return math.isclose(scalar, expected, rtol=rtol)

Check warning on line 98 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L97-L98

Added lines #L97 - L98 were not covered by tests

Check failure

Code scanning / lintrunner

MYPY/call-arg Error

Unexpected keyword argument "rtol" for "isclose" To disable, use # type: ignore[call-arg]

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "isclose" has incompatible type "float | Callable[..., Any]"; expected "SupportsFloat | SupportsIndex" To disable, use # type: ignore[arg-type]


def has_rank(value: ir.Value | None, rank: int) -> bool:
"""Returns True if the value is statically known to have the given rank, and False otherwise."""
if value is None:
return False
shape = value.shape
return (shape is not None) and (shape.rank() == rank)
return (shape is not None) and (shape.rank() == rank)

Check warning on line 106 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L104-L106

Added lines #L104 - L106 were not covered by tests
10 changes: 7 additions & 3 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils, pattern

# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops.

# Original code (from transformers) for computing cos/sin cache for RoPE:
# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135
# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135
# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
# position_ids_expanded = position_ids[:, None, :].float()
# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
# emb = torch.cat((freqs, freqs), dim=-1)
# cos = emb.cos()
# sin = emb.sin()


class CosSinCacheFusion(pattern.RewriteRuleClassBase):
def __init__(self, name: str, max_pos_id: int):
super().__init__(name)
Expand Down Expand Up @@ -48,16 +50,18 @@ def rewrite(self, op, inv_freq, position_ids, **_):
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
sin_value = np.sin(angles)
cos_2d= op.Constant(value=ir.tensor(cos_value))
cos_2d = op.Constant(value=ir.tensor(cos_value))
cos = op.Gather(cos_2d, position_ids, axis=0)
sin_2d = op.Constant(value=ir.tensor(sin_value))
sin = op.Gather(sin_2d, position_ids, axis=0)
return cos, sin

Check warning on line 57 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L48-L57

Added lines #L48 - L57 were not covered by tests


_rule = CosSinCacheFusion.rule("CosSinCache", 2048)

cos_sin_cache_rules = pattern.RewriteRuleSet([_rule])


def fuse_cos_sin_cache(model: ir.Model) -> None:
count = cos_sin_cache_rules.apply_to_model(model)
print(f"CosSinCache count: {count}")
print(f"CosSinCache count: {count}")

Check warning on line 67 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L66-L67

Added lines #L66 - L67 were not covered by tests

0 comments on commit 1fdc19b

Please sign in to comment.