diff --git a/cranberry/features/datasets.py b/cranberry/features/datasets.py index 3d79b15..30609b4 100644 --- a/cranberry/features/datasets.py +++ b/cranberry/features/datasets.py @@ -12,9 +12,7 @@ from tqdm import tqdm from cranberry import Tensor -_cache_dir: str = getenv( - "XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache") -) +_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")) def fetch( @@ -43,9 +41,7 @@ def fetch( progress_bar.update(f.write(chunk)) f.close() if (file_size := os.stat(f.name).st_size) < total_length: - raise RuntimeError( - f"fetch size incomplete, {file_size} < {total_length}" - ) + raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}") pathlib.Path(f.name).rename(fp) return fp @@ -53,9 +49,7 @@ def fetch( def _fetch_mnist(file, offset): return Tensor( np.frombuffer( - gzip.open( - fetch("https://storage.googleapis.com/cvdf-datasets/mnist/" + file) - ).read()[offset:], + gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/" + file)).read()[offset:], dtype=np.uint8, ) ) diff --git a/cranberry/nn/__init__.py b/cranberry/nn/__init__.py index c2bfdda..36b93d6 100644 --- a/cranberry/nn/__init__.py +++ b/cranberry/nn/__init__.py @@ -26,13 +26,9 @@ def parameters(self) -> List[Tensor]: class Linear(Module): # https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/__init__.py#L72-L80 def __init__(self, in_features: int, out_features: int, bias=True): - self.weight = Tensor.kaiming_uniform( - shape=(out_features, in_features), a=math.sqrt(5) - ) + self.weight = Tensor.kaiming_uniform(shape=(out_features, in_features), a=math.sqrt(5)) bound = 1 / math.sqrt(in_features) - self.bias = ( - Tensor.uniform(out_features, low=-bound, high=bound) if bias else None - ) + self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None def __call__(self, x: Tensor) -> Tensor: return x.linear(weight=self.weight.transpose(0, 1), bias=self.bias) diff --git a/cranberry/tensor.py b/cranberry/tensor.py index d36b36c..519258a 100644 --- a/cranberry/tensor.py +++ b/cranberry/tensor.py @@ -60,12 +60,8 @@ def __init__( else: raise ValueError(f"Invalid data type {type(data)}") - assert self._shape == Shape( - self._data.shape - ), f"shape {self._shape} must match data shape {self._data.shape}" - assert self._shape == Shape( - self._grad.shape - ), f"shape {self._shape} must match grad shape {self._grad.shape}" + assert self._shape == Shape(self._data.shape), f"shape {self._shape} must match data shape {self._data.shape}" + assert self._shape == Shape(self._grad.shape), f"shape {self._shape} must match grad shape {self._grad.shape}" # self._requires_grad self._requires_grad: bool = requires_grad @@ -81,12 +77,8 @@ def __init__( # ******************************************************** def backward(self): - assert ( - self._requires_grad - ), "cannot call backward on a tensor that doesn't require gradients" - assert ( - self.shape == tuple() - ), f"backward can only be called for scalar tensors, but it has shape {self.shape})" + assert self._requires_grad, "cannot call backward on a tensor that doesn't require gradients" + assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})" topo = [] visited = set() @@ -127,9 +119,7 @@ def _broadcasted(self, other: Tensor) -> Tuple[Tensor, Tensor]: # ******************************************************** def _unary_op(self, op: UnaryOps) -> Tensor: - out = Tensor._dummy( - shape=self._shape, requires_grad=self.requires_grad, prev=(self,), op=op - ) + out = Tensor._dummy(shape=self._shape, requires_grad=self.requires_grad, prev=(self,), op=op) if op == UnaryOps.NEG: out._data -= self._data @@ -197,19 +187,13 @@ def tanh(self) -> Tensor: # TODO: gelu sanity check def gelu(self) -> Tensor: - return ( - 0.5 - * self - * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh()) - ) + return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh()) # ******************************************************** # *************** binary ops *************** # ******************************************************** - def _binary_op( - self, other: Union[Tensor, int, float], reverse: bool, op: BinaryOps - ) -> Tensor: + def _binary_op(self, other: Union[Tensor, int, float], reverse: bool, op: BinaryOps) -> Tensor: if isinstance(other, (int, float)): other = Tensor(other) self, other = self._broadcasted(other) @@ -301,26 +285,18 @@ def __rtruediv__(self, other) -> Tensor: # ******************************************************** def _reduce_op(self, *args, op: ReduceOps) -> Tensor: - out = Tensor._dummy( - shape=Shape(()), requires_grad=self.requires_grad, prev=(self,), op=op - ) + out = Tensor._dummy(shape=Shape(()), requires_grad=self.requires_grad, prev=(self,), op=op) if op == ReduceOps.SUM: dim, keepdim = args out._data = self._data.sum(axis=dim, keepdims=keepdim) out._grad = np.zeros_like(out._data) - out._shape = ( - Shape(out._data.shape) - if isinstance(out._data, np.ndarray) - else Shape(()) - ) + out._shape = Shape(out._data.shape) if isinstance(out._data, np.ndarray) else Shape(()) def backward(): if dim is None or keepdim: self._grad += out._grad else: - o_new_shape = tuple( - 1 if i == dim else s for i, s in enumerate(self.shape) - ) + o_new_shape = tuple(1 if i == dim else s for i, s in enumerate(self.shape)) self._grad += out._grad.reshape(o_new_shape) out._backward = backward @@ -328,22 +304,14 @@ def backward(): dim, keepdim = args out._data = self._data.max(axis=dim, keepdims=keepdim) out._grad = np.zeros_like(out._data) - out._shape = ( - Shape(out._data.shape) - if isinstance(out._data, np.ndarray) - else Shape(()) - ) + out._shape = Shape(out._data.shape) if isinstance(out._data, np.ndarray) else Shape(()) def backward(): if dim is None or keepdim: self._grad += (self._data == out._data) * out._grad else: - o_new_shape = tuple( - 1 if i == dim else s for i, s in enumerate(self.shape) - ) - self._grad += ( - self._data == out._data.reshape(o_new_shape) - ) * out._grad.reshape(o_new_shape) + o_new_shape = tuple(1 if i == dim else s for i, s in enumerate(self.shape)) + self._grad += (self._data == out._data.reshape(o_new_shape)) * out._grad.reshape(o_new_shape) out._backward = backward else: @@ -390,9 +358,7 @@ def log_softmax(self, dim: int = -1) -> Tensor: # ******************************************************** def _movement_op(self, *args, op: MovementOps) -> Tensor: - out = Tensor._dummy( - shape=Shape(args[0]), requires_grad=self.requires_grad, prev=(self,), op=op - ) + out = Tensor._dummy(shape=Shape(args[0]), requires_grad=self.requires_grad, prev=(self,), op=op) if op == MovementOps.RESHAPE: out._data = self._data.reshape(args[0]) out._grad = self._grad.reshape(args[0]) @@ -413,9 +379,7 @@ def backward(): while len(s_shape) < len(o_shape): s_shape = (1,) + s_shape axis = tuple(i for i in range(len(o_shape)) if s_shape[i] == 1) - self._grad += out._grad.sum(axis=axis, keepdims=True).reshape( - self.shape - ) + self._grad += out._grad.sum(axis=axis, keepdims=True).reshape(self.shape) out._backward = backward elif op == MovementOps.PERMUTE: @@ -427,27 +391,21 @@ def backward(): def reshape(self, *shape: int) -> Tensor: assert shape.count(-1) <= 1, "can only specify one unknown dimension" - assert all( - s > 0 or s == -1 for s in shape - ), "shape dimensions must be positive or -1" + assert all(s > 0 or s == -1 for s in shape), "shape dimensions must be positive or -1" if shape.count(-1) == 1: assert ( prod(self._shape) % -prod(shape) == 0 ), f"cannot reshape tensor of size {prod(self._shape)} into shape {shape}" - shape = tuple( - s if s != -1 else prod(self._shape) // -prod(shape) for s in shape - ) - assert prod(shape) == prod( - self._shape - ), f"cannot reshape tensor of size {prod(self._shape)} into shape {shape}" + shape = tuple(s if s != -1 else prod(self._shape) // -prod(shape) for s in shape) + assert prod(shape) == prod(self._shape), f"cannot reshape tensor of size {prod(self._shape)} into shape {shape}" return self._movement_op(shape, op=MovementOps.RESHAPE) # https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html def expand(self, *shape: int) -> Tensor: - assert ( - len(shape) >= len(self.shape) + assert len(shape) >= len( + self.shape ), f"the expanded shape {shape} must have at least as many dimensions as the original shape {self.shape}" assert all( s == 1 or s == e for s, e in zip(self.shape, shape[-len(self.shape) :]) @@ -465,16 +423,8 @@ def permute(self, *dims: int) -> Tensor: def flatten(self, start_dim: int = 0, end_dim: int = -1) -> Tensor: if end_dim == -1: end_dim = len(self.shape) - assert ( - 0 <= start_dim < end_dim <= len(self.shape) - ), "invalid start_dim or end_dim" - return self.reshape( - *( - self.shape[:start_dim] - + (prod(self.shape[start_dim:end_dim]),) - + self.shape[end_dim:] - ) - ) + assert 0 <= start_dim < end_dim <= len(self.shape), "invalid start_dim or end_dim" + return self.reshape(*(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim]),) + self.shape[end_dim:])) def transpose(self, dim1: int, dim2: int) -> Tensor: dims = list(range(len(self.shape))) @@ -492,9 +442,7 @@ def matmul_2d(self, other: Tensor) -> Tensor: assert ( len(self.shape) == 2 and len(other.shape) == 2 ), "matmul_2d only supports 2D tensors, but got shapes {self.shape} and {other.shape}" - assert ( - self.shape[1] == other.shape[0] - ), f"matmul_2d shape mismatch: {self.shape} and {other.shape}" + assert self.shape[1] == other.shape[0], f"matmul_2d shape mismatch: {self.shape} and {other.shape}" N, M, K = self.shape[0], self.shape[1], other.shape[1] return (self.reshape(N, 1, M) * other.permute(1, 0).reshape(1, K, M)).sum(dim=2) @@ -502,9 +450,7 @@ def matmul(self, other: Tensor) -> Tensor: # https://pytorch.org/docs/stable/generated/torch.matmul.html # if both tensors are 1-dimensional, the dot product (scalar) is returned if len(self.shape) == 1 and len(other.shape) == 1: - assert ( - self.shape[0] == other.shape[0] - ), f"matmul shape mismatch: {self.shape} and {other.shape}" + assert self.shape[0] == other.shape[0], f"matmul shape mismatch: {self.shape} and {other.shape}" return self.mul(other).sum() # if both arguments are 2-dimensional, the matrix-matrix product is returned elif len(self.shape) == 2 and len(other.shape) == 2: @@ -521,14 +467,8 @@ def matmul(self, other: Tensor) -> Tensor: # if the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after # if the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after # the non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable) - elif ( - len(self.shape) >= 1 - and len(other.shape) >= 1 - and (len(self.shape) > 2 or len(other.shape) > 2) - ): - raise NotImplementedError( - "batched matrix multiply is not implemented yet: {self.shape} and {other.shape}" - ) + elif len(self.shape) >= 1 and len(other.shape) >= 1 and (len(self.shape) > 2 or len(other.shape) > 2): + raise NotImplementedError("batched matrix multiply is not implemented yet: {self.shape} and {other.shape}") else: raise RuntimeError(f"Invalid matmul shapes {self.shape} and {other.shape}") @@ -557,9 +497,7 @@ def sparse_categorical_crossentropy(self, Y: Tensor) -> Tensor: Y_pred = self.log_softmax() # TODO: need more efficient implementation. currently, it's not possible to use Y as a tensor of indices Y_onehot_data = np.zeros_like(Y_pred._data) - Y_onehot_data[ - np.arange(prod(Y._data.shape)), (Y._data + 1e-5).astype(np.int32) - ] = 1 + Y_onehot_data[np.arange(prod(Y._data.shape)), (Y._data + 1e-5).astype(np.int32)] = 1 Y_onehot = Tensor(Y_onehot_data) return -(Y_onehot * Y_pred).sum() / prod(Y._data.shape) # reduction="mean" @@ -606,9 +544,7 @@ def uniform(shape: Union[Tuple[int, ...], int], low=0.0, high=1.0) -> Tensor: def kaiming_uniform(shape: Union[Tuple[int, ...], int], a: float = 0.01) -> Tensor: if isinstance(shape, int): shape = (shape,) - bound = ( - math.sqrt(3.0) * math.sqrt(2.0 / (1 + a**2)) / math.sqrt(prod(shape[1:])) - ) + bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a**2)) / math.sqrt(prod(shape[1:])) return Tensor.uniform(shape, low=-bound, high=bound) # ******************************************************** @@ -616,9 +552,7 @@ def kaiming_uniform(shape: Union[Tuple[int, ...], int], a: float = 0.01) -> Tens # ******************************************************** @staticmethod - def _dummy( - shape: Shape, requires_grad: bool, prev: Optional[Tuple[Tensor, ...]], op: Op - ) -> Tensor: + def _dummy(shape: Shape, requires_grad: bool, prev: Optional[Tuple[Tensor, ...]], op: Op) -> Tensor: return Tensor( data=np.zeros(shape.dims), shape=shape, @@ -648,9 +582,7 @@ def ones(shape: Tuple[int], requires_grad: bool = False) -> Tensor: ) def detach(self) -> Tensor: - return Tensor( - data=self._data, shape=self._shape, requires_grad=False, prev=None, op=None - ) + return Tensor(data=self._data, shape=self._shape, requires_grad=False, prev=None, op=None) def numpy(self) -> np.ndarray: return self._data @@ -679,9 +611,7 @@ def size(self, dim: Optional[int] = None): return self.shape if dim is None else self.shape[dim] def item(self) -> float: - assert ( - self._shape == () - ), f"item() only supports tensors with a single element, but got shape {self.shape}" + assert self._shape == (), f"item() only supports tensors with a single element, but got shape {self.shape}" return self._data.item() def __hash__(self): diff --git a/examples/mnist.py b/examples/mnist.py index ca34ab8..433027f 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -7,9 +7,7 @@ X_train, Y_train, X_test, Y_test = mnist() X_train, X_test = X_train.flatten(1), X_test.flatten(1) -model = nn.Sequential( - nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10) -) +model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)) optimizer = optim.SGD(model.parameters(), lr=0.001) # TODO: use Adam diff --git a/poetry.lock b/poetry.lock index 0015a76..84704d6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -181,24 +181,24 @@ files = [ [[package]] name = "maturin" -version = "1.7.0" +version = "1.7.1" description = "Build and publish crates with pyo3, cffi and uniffi bindings as well as rust binaries as python packages" optional = false python-versions = ">=3.7" files = [ - {file = "maturin-1.7.0-py3-none-linux_armv6l.whl", hash = "sha256:15fe7920391a128897714f6ed38ebbc771150410b795a55cefca73f089d5aecb"}, - {file = "maturin-1.7.0-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:87a1fae70f1a6ad694832c735abf9f010edc4971c5cf89d2e7a54651a1a3792a"}, - {file = "maturin-1.7.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6fd312c56846d3cafa7c45e362d96b526170e79b9adb5b8ea02a10c88906069c"}, - {file = "maturin-1.7.0-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:928b82ceba924b1642c53f6684271e814b5ce5049cb4d35ff36bed078837eb83"}, - {file = "maturin-1.7.0-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:7460122333971b2492154c102d2981ae337ae0486dde7f4df7e645d724de59a5"}, - {file = "maturin-1.7.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:1f521ebe0344db8260df0d12779aefc06c1f763cd654151cf4a238fe14f65dc1"}, - {file = "maturin-1.7.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:0af4f2a4cfb99206d414dec138dd3aac3f506eb8928b7e38dfac570461b393d6"}, - {file = "maturin-1.7.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:29187d5c3e1e166c14eaadc63a8adc25b6bbb3e5b055d1bc87f6ca92b4b6e331"}, - {file = "maturin-1.7.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9cd5b992b6c131c5f47c85e7bc266bf5bf94f29720856678431ce6c91b726df"}, - {file = "maturin-1.7.0-py3-none-win32.whl", hash = "sha256:c1ae0b4162fb1152aea83098bf1b66a7bf6dd73fd1b108e6c4e22160118a997c"}, - {file = "maturin-1.7.0-py3-none-win_amd64.whl", hash = "sha256:2bd8227e020a9308c076253f29224c53b08b2a4ed41fcd94b4eb9349684fcfe7"}, - {file = "maturin-1.7.0-py3-none-win_arm64.whl", hash = "sha256:7c05226547778f31b73d48a19d11f57792bcc44f4047b84c73ea66cae2e62473"}, - {file = "maturin-1.7.0.tar.gz", hash = "sha256:1ba5277dd7832dc6181d69a005182b97b3520945825058484ffd9296f2efb59c"}, + {file = "maturin-1.7.1-py3-none-linux_armv6l.whl", hash = "sha256:372a141b31ae7396728d2dedc6061fe4522c1803ae1c05700d37008e1d1a2cc9"}, + {file = "maturin-1.7.1-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:49939608095d9bcdf19d081dfd6ac1e8f915c645115090514c7b86e1e382f241"}, + {file = "maturin-1.7.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:973126a36cfb9861b3207df579678c1bcd7c348578a41ccfbe80d811a84f1740"}, + {file = "maturin-1.7.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:6eec984d26f707b18765478f4892e58ac72e777287cd2ba721d6e2ef6da1f66e"}, + {file = "maturin-1.7.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:0df0a6aaf7e9ab92cce2490b03d80b8f5ecbfa0689747a2ea4dfb9e63877b79c"}, + {file = "maturin-1.7.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:09cca3491c756d1bce6ffff13f004e8a10e67c72a1cba9579058f58220505881"}, + {file = "maturin-1.7.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:00f0f8f5051f4c0d0f69bdd0c6297ea87e979f70fb78a377eb4277c932804e2d"}, + {file = "maturin-1.7.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:7bb184cfbac4e3c55ca21d322e4801e0f75e7932287e156c280c279eae60b69e"}, + {file = "maturin-1.7.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5e8e61468d7d79790f0b54f2ed24f2fefbce3518548bc4e1a1f0c7be5bad710"}, + {file = "maturin-1.7.1-py3-none-win32.whl", hash = "sha256:07c8800603e551a45e16fe7ad1742977097ea43c18b28e491df74d4ca15c5857"}, + {file = "maturin-1.7.1-py3-none-win_amd64.whl", hash = "sha256:c5e7e6d130072ca76956106daa276f24a66c3407cfe6cf64c196d4299fd4175c"}, + {file = "maturin-1.7.1-py3-none-win_arm64.whl", hash = "sha256:acf9f539f53a7ad64d406a40b27b768f67d75e6e4e93cb04b29025144a74ef45"}, + {file = "maturin-1.7.1.tar.gz", hash = "sha256:147754cb3d81177ee12d9baf575d93549e76121dacd3544ad6a50ab718de2b9c"}, ] [package.extras] @@ -253,63 +253,64 @@ files = [ [[package]] name = "numpy" -version = "2.1.0" +version = "2.1.1" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.10" files = [ - {file = "numpy-2.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8"}, - {file = "numpy-2.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911"}, - {file = "numpy-2.1.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673"}, - {file = "numpy-2.1.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b"}, - {file = "numpy-2.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b"}, - {file = "numpy-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f"}, - {file = "numpy-2.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84"}, - {file = "numpy-2.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33"}, - {file = "numpy-2.1.0-cp310-cp310-win32.whl", hash = "sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211"}, - {file = "numpy-2.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa"}, - {file = "numpy-2.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce"}, - {file = "numpy-2.1.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1"}, - {file = "numpy-2.1.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e"}, - {file = "numpy-2.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb"}, - {file = "numpy-2.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3"}, - {file = "numpy-2.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36"}, - {file = "numpy-2.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd"}, - {file = "numpy-2.1.0-cp311-cp311-win32.whl", hash = "sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e"}, - {file = "numpy-2.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b"}, - {file = "numpy-2.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736"}, - {file = "numpy-2.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529"}, - {file = "numpy-2.1.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1"}, - {file = "numpy-2.1.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8"}, - {file = "numpy-2.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745"}, - {file = "numpy-2.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111"}, - {file = "numpy-2.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0"}, - {file = "numpy-2.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574"}, - {file = "numpy-2.1.0-cp312-cp312-win32.whl", hash = "sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02"}, - {file = "numpy-2.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc"}, - {file = "numpy-2.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667"}, - {file = "numpy-2.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e"}, - {file = "numpy-2.1.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33"}, - {file = "numpy-2.1.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b"}, - {file = "numpy-2.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195"}, - {file = "numpy-2.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977"}, - {file = "numpy-2.1.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1"}, - {file = "numpy-2.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62"}, - {file = "numpy-2.1.0-cp313-cp313-win32.whl", hash = "sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324"}, - {file = "numpy-2.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5"}, - {file = "numpy-2.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06"}, - {file = "numpy-2.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883"}, - {file = "numpy-2.1.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300"}, - {file = "numpy-2.1.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d"}, - {file = "numpy-2.1.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd"}, - {file = "numpy-2.1.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6"}, - {file = "numpy-2.1.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a"}, - {file = "numpy-2.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb"}, - {file = "numpy-2.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b"}, - {file = "numpy-2.1.0-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d"}, - {file = "numpy-2.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575"}, - {file = "numpy-2.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2"}, - {file = "numpy-2.1.0.tar.gz", hash = "sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2"}, + {file = "numpy-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9"}, + {file = "numpy-2.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd"}, + {file = "numpy-2.1.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f"}, + {file = "numpy-2.1.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab"}, + {file = "numpy-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7"}, + {file = "numpy-2.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6"}, + {file = "numpy-2.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0"}, + {file = "numpy-2.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647"}, + {file = "numpy-2.1.1-cp310-cp310-win32.whl", hash = "sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728"}, + {file = "numpy-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae"}, + {file = "numpy-2.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550"}, + {file = "numpy-2.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f"}, + {file = "numpy-2.1.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0"}, + {file = "numpy-2.1.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95"}, + {file = "numpy-2.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca"}, + {file = "numpy-2.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf"}, + {file = "numpy-2.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e"}, + {file = "numpy-2.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2"}, + {file = "numpy-2.1.1-cp311-cp311-win32.whl", hash = "sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d"}, + {file = "numpy-2.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e"}, + {file = "numpy-2.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e"}, + {file = "numpy-2.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe"}, + {file = "numpy-2.1.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f"}, + {file = "numpy-2.1.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521"}, + {file = "numpy-2.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b"}, + {file = "numpy-2.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201"}, + {file = "numpy-2.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a"}, + {file = "numpy-2.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313"}, + {file = "numpy-2.1.1-cp312-cp312-win32.whl", hash = "sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed"}, + {file = "numpy-2.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270"}, + {file = "numpy-2.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5"}, + {file = "numpy-2.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5"}, + {file = "numpy-2.1.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136"}, + {file = "numpy-2.1.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0"}, + {file = "numpy-2.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb"}, + {file = "numpy-2.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df"}, + {file = "numpy-2.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78"}, + {file = "numpy-2.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556"}, + {file = "numpy-2.1.1-cp313-cp313-win32.whl", hash = "sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b"}, + {file = "numpy-2.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0"}, + {file = "numpy-2.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553"}, + {file = "numpy-2.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480"}, + {file = "numpy-2.1.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f"}, + {file = "numpy-2.1.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468"}, + {file = "numpy-2.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef"}, + {file = "numpy-2.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f"}, + {file = "numpy-2.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c"}, + {file = "numpy-2.1.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec"}, + {file = "numpy-2.1.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5"}, + {file = "numpy-2.1.1-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504"}, + {file = "numpy-2.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd"}, + {file = "numpy-2.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39"}, + {file = "numpy-2.1.1.tar.gz", hash = "sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd"}, ] [[package]] @@ -435,14 +436,14 @@ files = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.6.20" +version = "12.6.68" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb"}, - {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79"}, - {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6"}, + {file = "nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b"}, + {file = "nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab"}, + {file = "nvidia_nvjitlink_cu12-12.6.68-py3-none-win_amd64.whl", hash = "sha256:a55744c98d70317c5e23db14866a8cc2b733f7324509e941fc96276f9f37801d"}, ] [[package]] @@ -513,13 +514,13 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] [[package]] name = "pyright" -version = "1.1.377" +version = "1.1.379" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.377-py3-none-any.whl", hash = "sha256:af0dd2b6b636c383a6569a083f8c5a8748ae4dcde5df7914b3f3f267e14dd162"}, - {file = "pyright-1.1.377.tar.gz", hash = "sha256:aabc30fedce0ded34baa0c49b24f10e68f4bfc8f68ae7f3d175c4b0f256b4fcf"}, + {file = "pyright-1.1.379-py3-none-any.whl", hash = "sha256:01954811ac71db8646f50de1577576dc275ffb891a9e7324350e676cf6df323f"}, + {file = "pyright-1.1.379.tar.gz", hash = "sha256:6f426cb6443786fa966b930c23ad1941c8cb9fe672e4589daea8d80bb34193ea"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index d5ff7b8..656de90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,9 @@ pytest = "^8.3.1" torch = "^2.3.1" maturin = "^1.7.0" +[tool.ruff] +line-length = 120 + [build-system] requires = ["maturin>=0.13"] build-backend = "maturin" diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 5072976..95bb37a 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -680,9 +680,7 @@ def test_pytorch(): A = torch.tensor(A_np, requires_grad=True) W = torch.tensor(C_np) b = torch.tensor(E_np) - out = torch.nn.functional.linear( - A, W.transpose(1, 0), b - ).sum() # transpose to match cranberry + out = torch.nn.functional.linear(A, W.transpose(1, 0), b).sum() # transpose to match cranberry out.backward() assert A.grad is not None return out.detach().numpy(), A.grad.detach().numpy()