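"""Plotting callbacks for PyDGN experiments with UDN models.

``UDNPlotter`` renders the model's variational distribution q(ell) as a
TensorBoard image; ``StatsPlotter`` additionally logs its named parameters
as scalars and checkpoints them to ``parameters.torch`` after every epoch.
"""
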
import os
from io import BytesIO

import numpy as np
import PIL
import torch
from matplotlib import pyplot as plt

from pydgn.training.callback.plotter import Plotter
from pydgn.training.event.state import State


class UDNPlotter(Plotter):
    """Extends the base PyDGN ``Plotter`` to log the model's variational
    distribution q(ell) as a bar plot on TensorBoard at the end of each epoch."""

    def on_epoch_end(self, state: State):
        super().on_epoch_end(state)

        # Render q(ell) as a bar plot into an in-memory PNG buffer
        qL_probs = state.model.variational_L.compute_probability_vector()
        plt.figure()
        plt.bar(np.arange(qL_probs.shape[0]), qL_probs.detach().cpu().numpy())
        buffer = BytesIO()
        plt.savefig(buffer, format="png")
        buffer.seek(0)
        plt.close()

        # Convert the image buffer to a CHW uint8 tensor, as expected by add_image
        image = PIL.Image.open(buffer)
        image = image.convert("RGB")  # drop the alpha channel if present
        image = np.array(image)  # PIL image -> NumPy array (HWC)
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image_tensor = torch.ByteTensor(image)

        self.writer.add_image(
            tag="q(ell)", img_tensor=image_tensor, global_step=state.epoch
        )
        # self.writer.add_histogram(tag='q(ell)',
        #                           values=c.sample((1000,)),
        #                           global_step=state.epoch)
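
# A minimal, self-contained sketch of the same figure-to-TensorBoard conversion,
# assuming a plain SummaryWriter from torch.utils.tensorboard (the base Plotter's
# `self.writer` is presumed to expose the same `add_image` interface):
#
#     from torch.utils.tensorboard import SummaryWriter
#
#     writer = SummaryWriter()
#     plt.figure()
#     plt.bar(np.arange(3), [0.2, 0.5, 0.3])
#     buf = BytesIO()
#     plt.savefig(buf, format="png")
#     buf.seek(0)
#     plt.close()
#     img = np.transpose(np.array(PIL.Image.open(buf).convert("RGB")), (2, 0, 1))
#     writer.add_image(tag="q(ell)", img_tensor=torch.ByteTensor(img), global_step=0)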


class StatsPlotter(UDNPlotter):
    """On top of ``UDNPlotter``, logs the named parameters of q(ell) as scalars
    on TensorBoard and checkpoints them per epoch to ``parameters.torch``."""

    def on_fit_start(self, state: State):
        super().on_fit_start(state)
        self.distribution_params = {}

        # Reload previously saved parameters (if any) so new epochs append to them
        params_filepath = os.path.join(self.exp_path, "parameters.torch")
        if os.path.exists(params_filepath):
            try:
                self.distribution_params = torch.load(params_filepath)
            except Exception as e:
                print(e)
                self.distribution_params = {}

    def on_epoch_end(self, state: State):
        super().on_epoch_end(state)
        named_parameters = state.model.get_q_ell_named_parameters()

        # Detach and move to CPU before logging and serializing
        named_parameters = {k: v.detach().cpu() for k, v in named_parameters.items()}

        # Log each parameter on TensorBoard: scalars directly, vectors element-wise
        for k, v in named_parameters.items():
            assert len(v.shape) == 1
            if v.shape[0] == 1:  # scalar parameter
                self.writer.add_scalar(
                    tag=k, scalar_value=v.item(), global_step=state.epoch
                )
            elif v.shape[0] > 1:  # vector parameter: one scalar per component
                for i in range(v.shape[0]):
                    self.writer.add_scalar(
                        tag=f"{k}_{i}",
                        scalar_value=v[i].item(),
                        global_step=state.epoch,
                    )
            # self.writer.add_histogram(
            #     tag=k,
            #     values=v,
            #     global_step=state.epoch,
            #     bins="auto",
            # )
        self.distribution_params[f"{int(state.epoch)}"] = named_parameters

        # Checkpoint the accumulated parameters after every epoch
        try:
            params_filepath = os.path.join(self.exp_path, "parameters.torch")
            torch.save(self.distribution_params, params_filepath)
        except Exception as e:
            print(e)

    def on_fit_end(self, state: State):
        super().on_fit_end(state)

        # Save the final set of parameters once training is over
        params_filepath = os.path.join(self.exp_path, "parameters.torch")
        try:
            torch.save(self.distribution_params, params_filepath)
        except Exception as e:
            print(e)
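
# Hypothetical usage sketch (not part of the callbacks): reload the per-epoch
# parameters saved above and print them, assuming `exp_path` points at the
# experiment folder used by StatsPlotter:
#
#     params = torch.load(os.path.join(exp_path, "parameters.torch"))
#     for epoch, named in sorted(params.items(), key=lambda kv: int(kv[0])):
#         print(epoch, {k: v.tolist() for k, v in named.items()})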