Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add internal counters to send only the latest datapoints to studio #788

Merged
merged 9 commits into from
Feb 28, 2024
5 changes: 0 additions & 5 deletions src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ def log_metrics(
sync: Optional[bool] = False,
) -> None:
if not sync and _should_sync():
if step == self.experiment._latest_studio_step: # noqa: SLF001
# We are in log_eval_end_metrics but there has been already
# a studio request sent with `step`.
# We decrease the number to bypass `live.studio._get_unsent_datapoints`
self.experiment._latest_studio_step -= 1 # noqa: SLF001
sync = True
super().log_metrics(metrics, step, sync)

Expand Down
2 changes: 2 additions & 0 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def __init__(
self._latest_studio_step = self.step if resume else -1
self._studio_events_to_skip: Set[str] = set()
self._dvc_studio_config: Dict[str, Any] = {}
self._num_points_sent_to_studio: Dict[str, int] = {}
self._num_points_read_from_file: Dict[str, int] = {}
self._init_studio()

def _init_resume(self):
Expand Down
28 changes: 12 additions & 16 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
logger = logging.getLogger("dvclive")


def _get_unsent_datapoints(plot, latest_step):
    """Return the datapoints of ``plot`` whose step is after ``latest_step``.

    Args:
        plot: Iterable of datapoint mappings. Each one must have a ``"step"``
            entry that ``int()`` can convert.
        latest_step: Threshold step. Datapoints whose step is less than or
            equal to this value are left out.

    Returns:
        A new list holding the remaining datapoints in their original order.
    """
    return [x for x in plot if int(x["step"]) > latest_step]


def _cast_to_numbers(datapoints):
for datapoint in datapoints:
for k, v in datapoint.items():
Expand All @@ -39,11 +35,6 @@ def _adapt_path(live, name):
return name


def _adapt_plot_datapoints(live, plot):
    """Filter ``plot`` down to its unsent datapoints and cast their values.

    The datapoints are first narrowed with ``_get_unsent_datapoints`` using
    ``live._latest_studio_step`` as the threshold, and the result is then
    passed through ``_cast_to_numbers``.

    Args:
        live: Object exposing a ``_latest_studio_step`` attribute.
        plot: Iterable of datapoint mappings, each with a ``"step"`` entry.

    Returns:
        Whatever ``_cast_to_numbers`` returns for the filtered datapoints.
    """
    datapoints = _get_unsent_datapoints(plot, live._latest_studio_step)
    return _cast_to_numbers(datapoints)


def _adapt_image(image_path):
    """Return the content of the file at ``image_path`` as base64 text.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        The file's bytes encoded with base64 and decoded to a ``str``
        using UTF-8.
    """
    with open(image_path, "rb") as fobj:
        return base64.b64encode(fobj.read()).decode("utf-8")
Expand Down Expand Up @@ -71,15 +62,18 @@ def get_studio_updates(live):
metrics_file = _adapt_path(live, metrics_file)
metrics = {metrics_file: {"data": metrics}}

plots = {
_adapt_path(live, name): _adapt_plot_datapoints(live, plot)
for name, plot in plots.items()
}
plots = {k: {"data": v} for k, v in plots.items()}
plots_to_send = {}
live._num_points_read_from_file = {}
for name, plot in plots.items():
path = _adapt_path(live, name)
nb_points_already_sent = live._num_points_sent_to_studio.get(path, 0)
plots_to_send[path] = _cast_to_numbers(plot[nb_points_already_sent:])
live._num_points_read_from_file.update({path: len(plot)})

plots.update(_adapt_images(live))
plots_to_send = {k: {"data": v} for k, v in plots_to_send.items()}
plots_to_send.update(_adapt_images(live))

return metrics, params, plots
return metrics, params, plots_to_send


def get_dvc_studio_config(live):
Expand Down Expand Up @@ -120,6 +114,8 @@ def post_to_studio(live, event):
live._studio_events_to_skip.add("data")
live._studio_events_to_skip.add("done")
elif event == "data":
for path, num_points in live._num_points_read_from_file.items():
live._num_points_sent_to_studio[path] = num_points
AlexandreKempf marked this conversation as resolved.
Show resolved Hide resolved
live._latest_studio_step = live.step

if event == "done":
Expand Down
Loading