Skip to content

Commit

Permalink
Improve Comet Logger pickled behavior
Browse files Browse the repository at this point in the history
* Delay the creation of the actual experiment object for as long as we can.
* Save the experiment id in case an Experiment object is created so we can
  continue the same experiment in the sub-processes.
* Run pre-commit on the comet file.
  • Loading branch information
Lothiraldan committed Jul 8, 2020
1 parent d3f5717 commit e0c7a15
Showing 1 changed file with 39 additions and 36 deletions.
75 changes: 39 additions & 36 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from comet_ml import ExistingExperiment as CometExistingExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml import BaseExperiment as CometBaseExperiment

try:
from comet_ml.api import API
except ImportError: # pragma: no-cover
# For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
from comet_ml.papi import API # pragma: no-cover

_COMET_AVAILABLE = True
except ImportError: # pragma: no-cover
CometExperiment = None
Expand All @@ -26,7 +26,6 @@
API = None
_COMET_AVAILABLE = False


import torch
from torch import is_tensor

Expand Down Expand Up @@ -90,19 +89,23 @@ class CometLogger(LightningLoggerBase):
experiment_key: Optional. If set, restores from existing experiment.
"""

def __init__(self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
**kwargs):
def __init__(
self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
**kwargs,
):

if not _COMET_AVAILABLE:
raise ImportError('You want to use `comet_ml` logger which is not installed yet,'
' install it with `pip install comet-ml`.')
raise ImportError(
"You want to use `comet_ml` logger which is not installed yet,"
" install it with `pip install comet-ml`."
)
super().__init__()
self._experiment = None
self._save_dir = save_dir
Expand All @@ -123,6 +126,7 @@ def __init__(self,
self.workspace = workspace
self.project_name = project_name
self.experiment_key = experiment_key
self.experiment_name = experiment_name
self._kwargs = kwargs

if rest_api_key is not None:
Expand All @@ -133,11 +137,6 @@ def __init__(self,
self.rest_api_key = None
self.comet_api = None

if experiment_name:
try:
self.name = experiment_name
except TypeError:
log.exception("Failed to set experiment name for comet.ml logger")
self._kwargs = kwargs

@property
Expand All @@ -158,10 +157,7 @@ def experiment(self) -> CometBaseExperiment:
if self.mode == "online":
if self.experiment_key is None:
self._experiment = CometExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
api_key=self.api_key, workspace=self.workspace, project_name=self.project_name, **self._kwargs
)
self.experiment_key = self._experiment.get_key()
else:
Expand All @@ -170,16 +166,19 @@ def experiment(self) -> CometBaseExperiment:
workspace=self.workspace,
project_name=self.project_name,
previous_experiment=self.experiment_key,
**self._kwargs
**self._kwargs,
)
else:
self._experiment = CometOfflineExperiment(
offline_directory=self.save_dir,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
**self._kwargs,
)

if self.experiment_name:
self._experiment.set_name(self.experiment_name)

return self._experiment

@rank_zero_only
Expand All @@ -189,13 +188,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
self.experiment.log_parameters(params)

@rank_zero_only
def log_metrics(
self,
metrics: Dict[str, Union[torch.Tensor, float]],
step: Optional[int] = None
) -> None:
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
Expand Down Expand Up @@ -225,18 +219,27 @@ def save_dir(self) -> Optional[str]:
return self._save_dir

@property
def name(self) -> str:
return str(self.experiment.project_name)
def name(self) -> Optional[str]:
# don't create an experiment if we don't have one
return self._experiment.project_name if self._experiment else None

@name.setter
def name(self, value: str) -> None:
self.experiment.set_name(value)
self.project_name = value

if self._experiment is not None:
self._experiment.set_name(value)

@property
def version(self) -> str:
return self.experiment.id
def version(self) -> Optional[str]:
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else None

def __getstate__(self):
state = self.__dict__.copy()
# args needed to reload correct experiment
state["experiment_key"] = self._experiment.id if self._experiment is not None else None

# cannot be pickled
state["_experiment"] = None
return state

0 comments on commit e0c7a15

Please sign in to comment.