unify solution with comet_ml.start()
japdubengsub committed Aug 29, 2024
1 parent 009cab0 commit ecc3d98
Showing 1 changed file with 79 additions and 109 deletions.
188 changes: 79 additions & 109 deletions src/lightning/pytorch/loggers/comet.py
@@ -26,17 +26,19 @@
from torch.nn import Module
from typing_extensions import override

from lightning.fabric.utilities.logger import _add_prefix, _convert_params
from lightning.fabric.utilities.logger import _convert_params
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_only

if TYPE_CHECKING:
from comet_ml import ExistingExperiment, Experiment, OfflineExperiment
from comet_ml import ExistingExperiment, Experiment, OfflineExperiment, ExperimentConfig, BaseExperiment

log = logging.getLogger(__name__)
_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml")

comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"]
framework = "pytorch-lightning"


class CometLogger(Logger):
r"""Track your parameters, metrics, source code and more using `Comet
@@ -104,6 +106,9 @@ def __init__(self, *args, **kwarg):
# log multiple parameters
logger.log_hyperparams({"batch_size": 16, "learning_rate": 0.001})
# log nested parameters
logger.log_hyperparams({"specific": {'param': {'subparam': "value"}}})
**Log Metrics:**
.. code-block:: python
@@ -114,6 +119,9 @@ def __init__(self, *args, **kwarg):
# add multiple metrics
logger.log_metrics({"train/loss": 0.001, "val/loss": 0.002})
# add nested metrics
logger.log_hyperparams({"specific": {'metric': {'submetric': "value"}}})
**Access the Comet Experiment object:**
You can gain access to the underlying Comet
@@ -164,92 +172,73 @@ def __init__(self, *args, **kwarg):
- `Comet Documentation <https://www.comet.com/docs/v2/integrations/ml-frameworks/pytorch-lightning/>`__
Args:
api_key: Required in online mode. API key, found on Comet.com. If not given, this
will be loaded from the environment variable COMET_API_KEY or ~/.comet.config
if either exists.
save_dir: Required in offline mode. The path for the directory to save local
comet logs. If given, this also sets the directory for saving checkpoints.
project_name: Optional. Send your experiment to a specific project.
Otherwise, will be sent to Uncategorized Experiments.
If the project name does not already exist, Comet.com will create a new project.
experiment_name: Optional. String representing the name for this particular experiment on Comet.com.
experiment_key: Optional. If set, restores from existing experiment.
offline: If api_key and save_dir are both given, this determines whether
the experiment will be in online or offline mode. This is useful if you use
save_dir to control the checkpoints directory and have a ~/.comet.config
file but still want to run offline experiments.
prefix: A string to put at the beginning of metric keys.
**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by
api_key (str, optional): Comet API key. It's recommended to configure the API Key with `comet login`.
workspace (str, optional): Comet workspace name. If not provided, uses the default workspace.
project (str, optional): Comet project name. Defaults to `Uncategorized`.
experiment_key (str, optional): The Experiment identifier to be used for logging. This is used either to append
data to an Existing Experiment or to control the key of new experiments (for example to match another
identifier). Must be an alphanumeric string whose length is between 32 and 50 characters.
mode (str, optional): Control how the Comet experiment is started.
* ``"get_or_create"``: Starts a fresh experiment if required, or persists logging to an existing one.
* ``"get"``: Continue logging to an existing experiment identified by the ``experiment_key`` value.
* ``"create"``: Always creates of a new experiment, useful for HPO sweeps.
online (boolean, optional): If True, the data will be logged to Comet server, otherwise it will be stored
locally in an offline experiment. Default is ``True``.
**kwargs: Additional arguments like `experiment_name`, `log_code`, `prefix`, `offline_directory` etc. used by
:class:`CometExperiment` can be passed as keyword arguments in this logger.
Raises:
ModuleNotFoundError:
If required Comet package is not installed on the device.
MisconfigurationException:
If neither ``api_key`` nor ``save_dir`` are passed as arguments.
ValueError: If no API Key is set in online mode.
ExperimentNotFound: If mode="get" and the experiment_key doesn't exist, or you don't have access to it.
InvalidExperimentMode:
* If mode="get" but no experiment_key was passed or configured.
* If mode="create", an experiment_key was passed or configured and
an Experiment with that Key already exists.
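Example of the unified setup (an illustrative sketch added by the editor, not text from this commit; the API
key, workspace and project names are placeholders):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.loggers import CometLogger

    logger = CometLogger(
        api_key="YOUR_COMET_API_KEY",  # placeholder; prefer configuring the key with `comet login`
        workspace="my-workspace",  # placeholder workspace name
        project="my-project",  # placeholder project name
        mode="get_or_create",  # or "get" / "create", see the ``mode`` argument above
        online=True,  # set to False to log to a local offline experiment instead
    )
    trainer = Trainer(logger=logger)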
"""

LOGGER_JOIN_CHAR = "-"

def __init__(
self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
project_name: Optional[str] = None,
experiment_name: Optional[str] = None,
workspace: Optional[str] = None,
project: Optional[str] = None,
experiment_key: Optional[str] = None,
offline: bool = False,
prefix: str = "",
mode: Optional[Literal["get_or_create", "get", "create"]] = None,
online: Optional[bool] = None,
**kwargs: Any,
):
if not _COMET_AVAILABLE:
raise ModuleNotFoundError(str(_COMET_AVAILABLE))

self._save_dir: Optional[str]
self.api_key: str
self.mode: Literal["online", "offline"]

super().__init__()

# needs to be set before the first `comet_ml` import
# because comet_ml may be imported after other machine learning libraries (e.g. Torch)
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"

self._prefix = kwargs.pop("prefix", None)

import comet_ml

comet_experiment = Union[comet_ml.Experiment, comet_ml.ExistingExperiment, comet_ml.OfflineExperiment]
self._experiment: Optional[comet_experiment] = None

# Determine online or offline mode based on which arguments were passed to CometLogger
api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config())

if api_key is not None and save_dir is not None:
self.mode = "offline" if offline else "online"
self.api_key = api_key
self._save_dir = save_dir
elif api_key is not None:
self.mode = "online"
self.api_key = api_key
self._save_dir = None
elif save_dir is not None:
self.mode = "offline"
self._save_dir = save_dir
else:
# If neither api_key nor save_dir are passed as arguments, raise an exception
raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.")

log.info(f"CometLogger will be initialized in {self.mode} mode")

self._project_name: Optional[str] = project_name
self._experiment_key: str = experiment_key or os.environ.get("COMET_EXPERIMENT_KEY") or comet_ml.generate_guid()
self._experiment_name: Optional[str] = experiment_name
self._prefix: str = prefix
self._kwargs: Dict[str, Any] = kwargs
comet_config = comet_ml.ExperimentConfig(**kwargs)

self._experiment = comet_ml.start(
api_key=api_key,
workspace=workspace,
project=project,
experiment_key=experiment_key,
mode=mode,
online=online,
experiment_config=comet_config,
)

self._experiment.log_other("Created from", "pytorch-lightning")

@property
@rank_zero_experiment
def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment"]:
def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment", "BaseExperiment"]:
r"""Actual Comet object. To use Comet features in your :class:`~lightning.pytorch.core.LightningModule` do the
following.
@@ -258,34 +247,16 @@ def experiment(self) -> Union["Experiment", "ExistingExperi
self.logger.experiment.some_comet_function()
"""
if self._experiment is not None and self._experiment.alive:
return self._experiment

import comet_ml

comet_comfig = comet_ml.ExperimentConfig(
offline_directory=self._save_dir,
name=self._experiment_name,
**self._kwargs,
)

self._experiment = comet_ml.start(
api_key=self.api_key,
project=self._project_name,
experiment_key=self._experiment_key,
online=self.mode == "online",
experiment_config=comet_comfig,
)

self._experiment.log_other("Created from", "pytorch-lightning")

return self._experiment

@override
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = _convert_params(params)
self.experiment.log_parameters(params)
self.experiment.__internal_api__log_parameters__(
parameters=params,
framework=framework,
)

@override
@rank_zero_only
@@ -298,23 +269,25 @@ def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optiona
metrics_without_epoch[key] = val.cpu().detach()

epoch = metrics_without_epoch.pop("epoch", None)
metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR)
self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)
self.experiment.__internal_api__log_metrics__(
metrics_without_epoch,
step=step,
epoch=epoch,
prefix=self._prefix,
framework=framework,
)

@override
@rank_zero_only
def finalize(self, status: str) -> None:
"""
We will not end experiment (self._experiment.end()) here
to have an ability to continue using it after training is complete
but instead of ending we will upload/save all the data
"""
"""We will not end experiment (self._experiment.end()) here to have an ability to continue using it after
training is complete but instead of ending we will upload/save all the data."""
if self._experiment is None:
# When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
# initialized there
return

# just save the data
self.experiment.flush()
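# Illustrative note (editor's sketch, not part of this commit): because finalize() only flushes and never
# calls end(), the same experiment can keep receiving data after training, e.g.:
#
#     trainer.fit(model)                                      # hypothetical model
#     logger.experiment.log_other("stage", "post-training")   # still usable after fit()
#     logger.experiment.end()                                 # end it explicitly when done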

@property
@@ -326,46 +299,41 @@ def save_dir(self) -> Optional[str]:
The path to the save directory.
"""
return self._save_dir
import comet_ml
if isinstance(self._experiment, comet_ml.OfflineExperiment):
return self._experiment.offline_directory

return None
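# Illustrative note (editor's sketch, not part of this commit): with online=False, kwargs such as
# offline_directory are forwarded to comet_ml.ExperimentConfig, and save_dir then reports that directory:
#
#     logger = CometLogger(online=False, offline_directory="./comet-offline")  # hypothetical path
#     print(logger.save_dir)  # the directory used by the underlying OfflineExperiment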

@property
@override
def name(self) -> str:
"""Gets the project name.
Returns:
The project name if it is specified, else "comet-default".
The project name.
"""
# Don't create an experiment if we don't have one
if self._experiment is not None and self._experiment.project_name is not None:
return self._experiment.project_name

if self._project_name is not None:
return self._project_name

return "comet-default"
return self._experiment.project_name

@property
@override
def version(self) -> str:
"""Gets the version.
Returns:
experiment id/key
"""
if self._experiment is not None:
return self._experiment.id
experiment key
return self._experiment_key
"""
return self._experiment.get_key()

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()

# Save the experiment id in case an experiment object already exists,
# this way we could create an ExistingExperiment pointing to the same
# experiment
state["_experiment_key"] = self._experiment.id if self._experiment is not None else None
state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None

# Remove the experiment object as it contains hard to pickle objects
# (like network connections), the experiment object will be recreated if
@@ -375,5 +343,7 @@ def __getstate__(self) -> Dict[str, Any]:

@override
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
if self._experiment is not None:
self._experiment.set_model_graph(model)
self._experiment.__internal_api__set_model_graph__(
graph=model,
framework=framework,
)
