diff --git a/CHANGELOG.md b/CHANGELOG.md index 63040db8c..026fba7f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ ### Features - Series objects accept `timestamps` and `steps` in their constructors ([#1318](https://github.com/neptune-ai/neptune-client/pull/1318)) +- Added support for `pytorch` integration ([#1337](https://github.com/neptune-ai/neptune-client/pull/1337)) ## neptune 1.1.1 diff --git a/pyproject.toml b/pyproject.toml index e2256f567..775530193 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ neptune-lightgbm = { version = "*", optional = true } pytorch-lightning = { version = "*", optional = true } neptune-optuna = { version = "*", optional = true } neptune-prophet = { version = "*", optional = true } +neptune-pytorch = { version = "*", optional = true } neptune-sacred = { version = "*", optional = true } neptune-sklearn = { version = "*", optional = true } neptune-tensorflow-keras = { version = "*", optional = true } @@ -88,6 +89,7 @@ kedro = ["kedro-neptune"] lightgbm = ["neptune-lightgbm"] optuna = ["neptune-optuna"] prophet = ["neptune-prophet"] +pytorch = ["neptune-pytorch"] pytorch-lightning = ["pytorch-lightning"] sacred = ["neptune-sacred"] sklearn = ["neptune-sklearn"] diff --git a/src/neptune/integrations/pytorch/__init__.py b/src/neptune/integrations/pytorch/__init__.py new file mode 100644 index 000000000..1a30a144c --- /dev/null +++ b/src/neptune/integrations/pytorch/__init__.py @@ -0,0 +1,28 @@ +# +# Copyright (c) 2023, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +try: + from neptune_pytorch.impl import * # noqa: F401,F403 +except ModuleNotFoundError as e: + if e.name == "neptune_pytorch": + from neptune.new.exceptions import NeptuneIntegrationNotInstalledException + + raise NeptuneIntegrationNotInstalledException( + integration_package_name="neptune-pytorch", + framework_name="pytorch", + ) from None + else: + raise