Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dev] Fix some lint issues #241

Merged
merged 13 commits into from
Nov 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,14 @@ def __repr__(self):
def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"

M = self.M
# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

# Simple TIR Compute Expression
ir_module = matmul_select_implementation(
M=self.M,
M=M,
N=self.N,
K=self.K,
in_dtype=self.in_dtype,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def main(
# Apply memory layout optimizations
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
B_shared: make_swizzle_layout(B_shared, is_smooth=True),
})

# Optional rasterization for L2 locality enhancement
Expand Down
32 changes: 32 additions & 0 deletions bitblas/tl/base_hint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,50 @@
from typing import Dict


# Base class for Tensor Layout Hints that defines the interface and common functionality for derived classes.
class BaseTLHint(ABC):

# Constructor for the BaseTLHint class, takes variable arguments (*args and **kwargs) to allow flexibility.
def __init__(self, *args, **kwargs):
# Calls the superclass constructor (useful in complex inheritance hierarchies).
super().__init__(*args, **kwargs)

# Representation method to get a string representation of the object.
# This method is not implemented here, so derived classes should provide their own implementation.
def __repr__(self):
raise NotImplementedError("method __repr__ is not implemented")

# Class method to create an instance of BaseTLHint (or a derived class) from a Hint object.
# This method needs to be implemented by subclasses.
@classmethod
def from_roller_hint(self, hint: Hint) -> 'BaseTLHint':
raise NotImplementedError("method from_roller_hint is not implemented")

# Abstract method to retrieve configuration parameters.
# Derived classes must implement this method and return a dictionary of configuration parameters.
@abstractmethod
def get_config_params(self) -> Dict:
pass

# Allows the object to be accessed like a dictionary.
# Retrieves a configuration parameter by key using the dictionary returned by get_config_params.
def __getitem__(self, key):
return self.get_config_params()[key]

# Handles attempts to access non-existent attributes.
# If the attribute is `get_config_params`, it returns the method itself.
# Otherwise, raises an AttributeError if the attribute is not found.
def __getattr__(self, item):
# If the attribute is not found in the class, try to find it in the hint object
if item == 'get_config_params':
return self.get_config_params
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")

# Allows the object to be iterated as if it were a dictionary.
# Returns an iterator over the items (key-value pairs) in the configuration parameters.
def __iter__(self):
return iter(self.get_config_params().items())

# Returns the keys of the configuration parameters as if the object were a dictionary.
def keys(self):
return self.get_config_params()
Loading