Merge pull request #208 from NREL/bnb/tensorboard_logging
tensorboard logging
bnb32 authored Apr 22, 2024
2 parents f36cc10 + e2d62ae commit a171300
Showing 5 changed files with 211 additions and 81 deletions.
157 changes: 112 additions & 45 deletions sup3r/models/abstract.py
@@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-
"""
Abstract class to define the required interface for Sup3r model subclasses
"""
"""Abstract class defining the required interface for Sup3r model subclasses"""
import json
import logging
import os
@@ -23,10 +21,83 @@
import sup3r.utilities.loss_metrics
from sup3r.preprocessing.data_handling.exogenous_data_handling import ExoData
from sup3r.utilities import VERSION_RECORD
from sup3r.utilities.utilities import Timer

logger = logging.getLogger(__name__)


class TensorboardMixIn:
"""MixIn class for tensorboard logging and profiling."""

def __init__(self):
self._tb_writer = None
self._tb_log_dir = None
self._write_tb_profile = False
self._total_batches = None
self._history = None
self.timer = Timer()

@property
def total_batches(self):
"""Record of total number of batches for logging."""
if self._total_batches is None and self._history is None:
self._total_batches = 0
elif self._total_batches is None and 'total_batches' in self._history:
self._total_batches = self._history['total_batches'].values[-1]
elif self._total_batches is None and self._history is not None:
self._total_batches = 0
return self._total_batches

@total_batches.setter
def total_batches(self, value):
"""Set total number of batches."""
self._total_batches = value

def dict_to_tensorboard(self, entry):
"""Write data to tensorboard log file. This is usually a loss_details
dictionary.
Parameters
----------
entry: dict
Dictionary of values to write to tensorboard log file
"""
if self._tb_writer is not None:
with self._tb_writer.as_default():
for name, value in entry.items():
if isinstance(value, str):
tf.summary.text(name, value, self.total_batches)
else:
tf.summary.scalar(name, value, self.total_batches)

def profile_to_tensorboard(self, name):
"""Write profile data to tensorboard log file.
Parameters
----------
name : str
Tag name to use for profile info
"""
if self._tb_writer is not None and self._write_tb_profile:
with self._tb_writer.as_default():
tf.summary.trace_export(name=name, step=self.total_batches,
profiler_outdir=self._tb_log_dir)

def _init_tensorboard_writer(self, out_dir):
"""Initialize the ``tf.summary.SummaryWriter`` to use for writing
tensorboard compatible log files.
Parameters
----------
out_dir : str
Standard out_dir where model epochs are saved. e.g. './gan_{epoch}'
"""
tb_log_pardir = os.path.abspath(os.path.join(out_dir, os.pardir))
self._tb_log_dir = os.path.join(tb_log_pardir, 'logs')
os.makedirs(self._tb_log_dir, exist_ok=True)
self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir)
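
For orientation, a minimal sketch of how this mixin is meant to be used (the ToyModel class and loss values below are hypothetical, not part of this diff): the writer is created next to the model output directory, and per-batch loss details are written as scalars keyed by the running total_batches counter.

class ToyModel(TensorboardMixIn):
    """Hypothetical model class using the tensorboard mixin (sketch)."""

    def train(self, batches, out_dir='./gan_{epoch}'):
        # creates <parent of out_dir>/logs and a tf.summary file writer
        self._init_tensorboard_writer(out_dir)
        for batch in batches:
            loss_details = {'train_loss_gen': 0.1, 'train_loss_disc': 0.2}
            self.dict_to_tensorboard(loss_details)  # one scalar per key
            self.total_batches += 1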


class AbstractInterface(ABC):
"""
Abstract class to define the required interface for Sup3r model subclasses
@@ -371,9 +442,8 @@ def hr_exo_features(self):
# pylint: disable=E1101
features = []
if hasattr(self, '_gen'):
for layer in self._gen.layers:
if isinstance(layer, (Sup3rAdder, Sup3rConcat)):
features.append(layer.name)
features = [layer.name for layer in self._gen.layers
if isinstance(layer, (Sup3rAdder, Sup3rConcat))]
return features

@property
@@ -465,13 +535,14 @@ def save_params(self, out_dir):


# pylint: disable=E1101,W0201,E0203
class AbstractSingleModel(ABC):
class AbstractSingleModel(ABC, TensorboardMixIn):
"""
Abstract class to define the required training interface
for Sup3r model subclasses
"""

def __init__(self):
super().__init__()
self.gpu_list = tf.config.list_physical_devices('GPU')
self.default_device = '/cpu:0'
self._version_record = VERSION_RECORD
@@ -743,13 +814,13 @@ def init_optimizer(optimizer, learning_rate):
"""
if isinstance(optimizer, dict):
class_name = optimizer['name']
OptimizerClass = getattr(optimizers, class_name)
sig = signature(OptimizerClass)
optimizer_class = getattr(optimizers, class_name)
sig = signature(optimizer_class)
optimizer_kwargs = {
k: v
for k, v in optimizer.items() if k in sig.parameters
}
optimizer = OptimizerClass.from_config(optimizer_kwargs)
optimizer = optimizer_class.from_config(optimizer_kwargs)
elif optimizer is None:
optimizer = optimizers.Adam(learning_rate=learning_rate)
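
An illustrative call of init_optimizer based on the behavior visible above (the config values are hypothetical, and this assumes the method can be called as a static method): keys matching the optimizer class signature are kept, anything else is dropped by the signature filter, and None falls back to Adam.

config = {'name': 'Adam', 'learning_rate': 1e-4, 'beta_1': 0.9,
          'not_a_kwarg': 'dropped by the signature filter'}
optimizer = AbstractSingleModel.init_optimizer(config, learning_rate=1e-4)
default_optimizer = AbstractSingleModel.init_optimizer(None, learning_rate=1e-4)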

@@ -915,10 +986,9 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None):
prior_n_obs = loss_details['n_obs']
new_n_obs = prior_n_obs + batch_len

for key, new_value in new_data.items():
key = key if prefix is None else prefix + key
new_value = (new_value if not isinstance(new_value, tf.Tensor) else
new_value.numpy())
for k, v in new_data.items():
key = k if prefix is None else prefix + k
new_value = (v if not isinstance(v, tf.Tensor) else v.numpy())

if key in loss_details:
saved_value = loss_details[key]
@@ -1061,9 +1131,7 @@ def finish_epoch(self,
stop : bool
Flag to early stop training.
"""

self.log_loss_details(loss_details)

self._history.at[epoch, 'elapsed_time'] = time.time() - t0
for key, value in loss_details.items():
if key != 'n_obs':
@@ -1135,20 +1203,19 @@ def run_gradient_descent(self,
loss_details : dict
Namespace of the breakdown of loss components
"""

t0 = time.time()
if optimizer is None:
optimizer = self.optimizer

if not multi_gpu or len(self.gpu_list) == 1:

grad, loss_details = self.get_single_grad(low_res, hi_res_true,
training_weights,
**calc_loss_kwargs)
optimizer.apply_gradients(zip(grad, training_weights))
t1 = time.time()
logger.debug(f'Finished single gradient descent step '
f'in {(t1 - t0):.3f}s')

else:
futures = []
lr_chunks = np.array_split(low_res, len(self.gpu_list))
@@ -1178,7 +1245,6 @@ def run_gradient_descent(self,
t1 = time.time()
logger.debug(f'Finished {len(futures)} gradient descent steps on '
f'{len(self.gpu_list)} GPUs in {(t1 - t0):.3f}s')

return loss_details
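
The multi-GPU branch above (partly collapsed in this view) splits the batch across the available devices and applies each chunk's gradients as it completes; a simplified sketch of that pattern, with model, gpu_list, optimizer, and the input arrays assumed to be in scope:

from concurrent.futures import ThreadPoolExecutor

import numpy as np

# one chunk of the batch per visible GPU, dispatched concurrently
lr_chunks = np.array_split(low_res, len(gpu_list))
hr_chunks = np.array_split(hi_res_true, len(gpu_list))
with ThreadPoolExecutor(max_workers=len(gpu_list)) as exe:
    futures = [exe.submit(model.get_single_grad, lr, hr, training_weights,
                          device_name=f'/gpu:{i}')
               for i, (lr, hr) in enumerate(zip(lr_chunks, hr_chunks))]
    for future in futures:
        grad, loss_details = future.result()
        optimizer.apply_gradients(zip(grad, training_weights))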

def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True):
Expand Down Expand Up @@ -1283,8 +1349,10 @@ def generate(self,
low_res = self.norm_input(low_res)

hi_res = self.generator.layers[0](low_res)
for i, layer in enumerate(self.generator.layers[1:]):
try:
layer_num = 1
try:
for i, layer in enumerate(self.generator.layers[1:]):
layer_num = i + 1
if isinstance(layer, (Sup3rAdder, Sup3rConcat)):
msg = (f'layer.name = {layer.name} does not match any '
'features in exogenous_data '
@@ -1299,11 +1367,11 @@ def generate(self,
hi_res = layer(hi_res, hi_res_exo)
else:
hi_res = layer(hi_res)
except Exception as e:
msg = ('Could not run layer #{} "{}" on tensor of shape {}'.
format(i + 1, layer, hi_res.shape))
logger.error(msg)
raise RuntimeError(msg) from e
except Exception as e:
msg = ('Could not run layer #{} "{}" on tensor of shape {}'.
format(layer_num, layer, hi_res.shape))
logger.error(msg)
raise RuntimeError(msg) from e

hi_res = hi_res.numpy()

@@ -1341,8 +1409,10 @@ def _tf_generate(self, low_res, hi_res_exo=None):
Synthetically generated high-resolution data
"""
hi_res = self.generator.layers[0](low_res)
for i, layer in enumerate(self.generator.layers[1:]):
try:
layer_num = 1
try:
for i, layer in enumerate(self.generator.layers[1:]):
layer_num = i + 1
if isinstance(layer, (Sup3rAdder, Sup3rConcat)):
msg = (f'layer.name = {layer.name} does not match any '
f'features in exogenous_data ({list(hi_res_exo)})')
@@ -1351,11 +1421,11 @@ def _tf_generate(self, low_res, hi_res_exo=None):
hi_res = layer(hi_res, hr_exo)
else:
hi_res = layer(hi_res)
except Exception as e:
msg = ('Could not run layer #{} "{}" on tensor of shape {}'.
format(i + 1, layer, hi_res.shape))
logger.error(msg)
raise RuntimeError(msg) from e
except Exception as e:
msg = ('Could not run layer #{} "{}" on tensor of shape {}'.
format(layer_num, layer, hi_res.shape))
logger.error(msg)
raise RuntimeError(msg) from e

return hi_res
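
For context, the exogenous input consumed in this loop is keyed by the names of the Sup3rAdder / Sup3rConcat layers in the generator; a hedged illustration with a hypothetical feature name and tensor shape:

hi_res_exo = {'topography': tf.ones((4, 20, 20, 24, 1))}  # key matches layer.name
hi_res = model._tf_generate(low_res, hi_res_exo)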

@@ -1398,16 +1468,13 @@ def get_single_grad(self,
loss_details : dict
Namespace of the breakdown of loss components
"""
with tf.device(device_name):
with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(training_weights)

hi_res_exo = self.get_high_res_exo_input(hi_res_true)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
loss_out = self.calc_loss(hi_res_true, hi_res_gen,
**calc_loss_kwargs)
loss, loss_details = loss_out

grad = tape.gradient(loss, training_weights)

with tf.device(device_name), tf.GradientTape(
watch_accessed_variables=False) as tape:
self.timer(tape.watch, training_weights)
hi_res_exo = self.timer(self.get_high_res_exo_input, hi_res_true)
hi_res_gen = self.timer(self._tf_generate, low_res, hi_res_exo)
loss_out = self.timer(self.calc_loss, hi_res_true, hi_res_gen,
**calc_loss_kwargs)
loss, loss_details = loss_out
grad = self.timer(tape.gradient, loss, training_weights)
return grad, loss_details
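
The Timer imported at the top of the file is used here as a callable wrapper around each step; its sup3r implementation is not shown in this diff, but a minimal sketch of that kind of interface (an assumption, for illustration only) is:

import time


class Timer:
    """Callable timer sketch (assumed interface): run func(*args, **kwargs)
    and record the elapsed time under the function name."""

    def __init__(self):
        self.log = {}

    def __call__(self, func, *args, **kwargs):
        start = time.time()
        out = func(*args, **kwargs)
        self.log[f'elapsed:{func.__name__}'] = time.time() - start
        return out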
