Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Extract metric from a test for benchmarkAI #74

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions benchmark/scripts/logging_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import keras


class LoggingMetrics:
"""Callback that save metrics to a logfile.

# Arguments
history_callback: instance of `keras.callbacks.History`.
Training parameters
(eg. batch size, number of epochs, loss, acc).
time_callback: instance of `keras.callbacks.Callback`.
Training parameters
(eg. time, time-step, speed).

# Raises
TypeError: In case of invalid object instance.
"""

def __init__(self, history_callback, time_callback):
self.num_iteration = None
self.metric_list = []
self.pattern_list = []
self.retrieve_metrics(history_callback, time_callback)

def retrieve_metrics(self, history_callback, time_callback):
if not isinstance(history_callback, keras.callbacks.History):
raise TypeError('`history_callback` should be an instance of '
'`keras.callbacks.History`')
if not isinstance(time_callback, keras.callbacks.Callback):
raise TypeError('`time_callback` should be an instance of '
'`keras.callbacks.Callback`')

if hasattr(history_callback, 'epoch'):
self.metric_list.append(history_callback.epoch)
self.pattern_list.append('[Epoch %d]\t')

if hasattr(time_callback, 'times'):
self.metric_list.append(time_callback.get_time())
self.metric_list.append(time_callback.get_time_step())
self.metric_list.append(time_callback.get_speed())
self.pattern_list.append('time: %s\t')
self.pattern_list.append('time_step: %s\t')
self.pattern_list.append('speed: %s\t')

if 'loss' in history_callback.history:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please move the metric constants to constant.py file & iterate over those to increase modularity

self.metric_list.append(history_callback.history['loss'])
self.pattern_list.append('train_loss: %.4f\t')

if 'acc' in history_callback.history:
self.metric_list.append(history_callback.history['acc'])
self.pattern_list.append('train_acc: %.4f\t')

if 'val_loss' in history_callback.history:
self.metric_list.append(history_callback.history['val_loss'])
self.pattern_list.append('val_loss: %.4f\t')

if 'val_acc' in history_callback.history:
self.metric_list.append(history_callback.history['val_acc'])
self.pattern_list.append('val_acc: %.4f\t')

self.num_iteration = history_callback.params['epochs']

def get_metrics_index(self, idx):
idx_metric_list = []
for metric in self.metric_list:
idx_metric_list.append(metric[idx])
return tuple(idx_metric_list)

def save_metrics_to_log(self, logging):
pattern_str = ''.join(self.pattern_list)
for i in range(self.num_iteration):
logging.info(pattern_str % self.get_metrics_index(i))
64 changes: 61 additions & 3 deletions benchmark/scripts/models/timehistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,69 @@


class TimeHistory(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
"""Callback that extract execution time of every epoch, time-step,
and speed in terms of sample per sec
"""

def __init__(self):
super(TimeHistory, self).__init__()
self.times = []

def on_train_begin(self, logs=None):
self.times = []

def on_epoch_begin(self, batch, logs={}):
def on_epoch_begin(self, batch, logs=None):
self.epoch_time_start = time.time()

def on_epoch_end(self, batch, logs={}):
def on_epoch_end(self, batch, logs=None):
self.times.append(time.time() - self.epoch_time_start)

def get_num_samples(self):
if 'samples' in self.params:
return self.params['samples']
elif 'steps' in self.params:
return self.params['steps']
else:
raise ValueError('Incorrect metric parameter')

def reformat(self, var):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please make this a private method & limit its scope or move it to utils such that LoggingMetrics class can take advantage of it while appending to pattern list

if var >= 1:
var = '%.2f ' % var
time_format = 'sec'
elif var >= 1e-3:
var = '%.2f ' % (var * 1e3)
time_format = 'msec'
else:
var = '%.2f ' % (var * 1e6)
time_format = 'usec'
return var, time_format

def get_time_step(self):
time_list = []
num_samples = self.get_num_samples()
for t in self.times:
speed = t / num_samples
speed, time_format = self.reformat(speed)
time_list.append(speed + time_format + '/step')
return time_list

def get_total_time(self):
total_time = sum(self.times)
total_time, time_format = self.reformat(total_time)
return total_time + time_format

def get_time(self):
time_list = []
for t in self.times:
time, time_format = self.reformat(t)
time_list.append(time + time_format)
return time_list

def get_speed(self):
samples_list = []
num_samples = self.get_num_samples()
for t in self.times:
sample_sec = num_samples / t
sample_sec, time_format = self.reformat(sample_sec)
samples_list.append(sample_sec + 'samples/' + time_format)
return samples_list