From 3c184c075cdb94ffae05ea7424e33dd98a4c09f9 Mon Sep 17 00:00:00 2001 From: Julian Hoever Date: Sat, 9 Sep 2023 11:50:25 +0200 Subject: [PATCH] feat: add parameter getters --- elasticai/creator/nn/fixed_point/linear/layer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/elasticai/creator/nn/fixed_point/linear/layer.py b/elasticai/creator/nn/fixed_point/linear/layer.py index c7c62661..749bcafb 100644 --- a/elasticai/creator/nn/fixed_point/linear/layer.py +++ b/elasticai/creator/nn/fixed_point/linear/layer.py @@ -87,6 +87,22 @@ def __init__( device=device, ) + @property + def lin_weight(self) -> torch.Tensor: + return self._linear.weight + + @property + def lin_bias(self) -> torch.Tensor | None: + return self._linear.bias + + @property + def bn_weight(self) -> torch.Tensor: + return self._batch_norm.weight + + @property + def bn_bias(self) -> torch.Tensor: + return self._batch_norm.bias + def forward(self, x: torch.Tensor) -> torch.Tensor: has_batches = x.dim() == 2 input_shape = x.shape if has_batches else (1, -1)