diff --git a/src/liger_kernel/transformers/layer_norm.py b/src/liger_kernel/transformers/layer_norm.py
index 2b045129c..9590898a7 100644
--- a/src/liger_kernel/transformers/layer_norm.py
+++ b/src/liger_kernel/transformers/layer_norm.py
@@ -11,6 +11,8 @@ def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"):
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+        self.hidden_size = hidden_size
+        self.eps = eps
         self.weight = nn.Parameter(
             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
         )
@@ -23,3 +25,6 @@ def forward(self, hidden_states):
         return LigerLayerNormFunction.apply(
             hidden_states, self.weight, self.bias, self.variance_epsilon
         )
+
+    def extra_repr(self):
+        return f"{self.hidden_size}, eps={self.eps}"
diff --git a/src/liger_kernel/transformers/rms_norm.py b/src/liger_kernel/transformers/rms_norm.py
index c610573e6..3191ac24f 100644
--- a/src/liger_kernel/transformers/rms_norm.py
+++ b/src/liger_kernel/transformers/rms_norm.py
@@ -30,3 +30,6 @@ def forward(self, hidden_states):
             self.offset,
             self.casting_mode,
         )
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
diff --git a/src/liger_kernel/triton/monkey_patch.py b/src/liger_kernel/triton/monkey_patch.py
index 70863f4e3..590842a83 100644
--- a/src/liger_kernel/triton/monkey_patch.py
+++ b/src/liger_kernel/triton/monkey_patch.py
@@ -37,6 +37,6 @@ def apply_liger_triton_cache_manager():
     Experimental feature to get around transient FileNotFoundError in triton compilation.
     For more details please see https://github.com/triton-lang/triton/pull/4295
     """
-    os.environ["TRITON_CACHE_MANAGER"] = (
-        "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
-    )
+    os.environ[
+        "TRITON_CACHE_MANAGER"
+    ] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"