This repository has been archived by the owner on Jul 1, 2024. It is now read-only.
forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 65
Extract metric from a test for benchmarkAI #74
Closed
Closed
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import keras | ||
|
||
|
||
class LoggingMetrics: | ||
"""Callback that save metrics to a logfile. | ||
|
||
# Arguments | ||
history_callback: instance of `keras.callbacks.History`. | ||
Training parameters | ||
(eg. batch size, number of epochs, loss, acc). | ||
time_callback: instance of `keras.callbacks.Callback`. | ||
Training parameters | ||
(eg. time, time-step, speed). | ||
|
||
# Raises | ||
TypeError: In case of invalid object instance. | ||
""" | ||
|
||
def __init__(self, history_callback, time_callback): | ||
self.num_iteration = None | ||
self.metric_list = [] | ||
self.pattern_list = [] | ||
self.retrieve_metrics(history_callback, time_callback) | ||
|
||
def retrieve_metrics(self, history_callback, time_callback): | ||
if not isinstance(history_callback, keras.callbacks.History): | ||
raise TypeError('`history_callback` should be an instance of ' | ||
'`keras.callbacks.History`') | ||
if not isinstance(time_callback, keras.callbacks.Callback): | ||
raise TypeError('`time_callback` should be an instance of ' | ||
'`keras.callbacks.Callback`') | ||
|
||
if hasattr(history_callback, 'epoch'): | ||
self.metric_list.append(history_callback.epoch) | ||
self.pattern_list.append('[Epoch %d]\t') | ||
|
||
if hasattr(time_callback, 'times'): | ||
self.metric_list.append(time_callback.get_time()) | ||
self.metric_list.append(time_callback.get_time_step()) | ||
self.metric_list.append(time_callback.get_speed()) | ||
self.pattern_list.append('time: %s\t') | ||
self.pattern_list.append('time_step: %s\t') | ||
self.pattern_list.append('speed: %s\t') | ||
|
||
if 'loss' in history_callback.history: | ||
self.metric_list.append(history_callback.history['loss']) | ||
self.pattern_list.append('train_loss: %.4f\t') | ||
|
||
if 'acc' in history_callback.history: | ||
self.metric_list.append(history_callback.history['acc']) | ||
self.pattern_list.append('train_acc: %.4f\t') | ||
|
||
if 'val_loss' in history_callback.history: | ||
self.metric_list.append(history_callback.history['val_loss']) | ||
self.pattern_list.append('val_loss: %.4f\t') | ||
|
||
if 'val_acc' in history_callback.history: | ||
self.metric_list.append(history_callback.history['val_acc']) | ||
self.pattern_list.append('val_acc: %.4f\t') | ||
|
||
self.num_iteration = history_callback.params['epochs'] | ||
|
||
def get_metrics_index(self, idx): | ||
idx_metric_list = [] | ||
for metric in self.metric_list: | ||
idx_metric_list.append(metric[idx]) | ||
return tuple(idx_metric_list) | ||
|
||
def save_metrics_to_log(self, logging): | ||
pattern_str = ''.join(self.pattern_list) | ||
for i in range(self.num_iteration): | ||
logging.info(pattern_str % self.get_metrics_index(i)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,11 +8,69 @@ | |
|
||
|
||
class TimeHistory(keras.callbacks.Callback): | ||
def on_train_begin(self, logs={}): | ||
"""Callback that extract execution time of every epoch, time-step, | ||
and speed in terms of sample per sec | ||
""" | ||
|
||
def __init__(self): | ||
super(TimeHistory, self).__init__() | ||
self.times = [] | ||
|
||
def on_train_begin(self, logs=None): | ||
self.times = [] | ||
|
||
def on_epoch_begin(self, batch, logs={}): | ||
def on_epoch_begin(self, batch, logs=None): | ||
self.epoch_time_start = time.time() | ||
|
||
def on_epoch_end(self, batch, logs={}): | ||
def on_epoch_end(self, batch, logs=None): | ||
self.times.append(time.time() - self.epoch_time_start) | ||
|
||
def get_num_samples(self): | ||
if 'samples' in self.params: | ||
return self.params['samples'] | ||
elif 'steps' in self.params: | ||
return self.params['steps'] | ||
else: | ||
raise ValueError('Incorrect metric parameter') | ||
|
||
def reformat(self, var): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please make this a private method & limit its scope or move it to utils such that LoggingMetrics class can take advantage of it while appending to pattern list |
||
if var >= 1: | ||
var = '%.2f ' % var | ||
time_format = 'sec' | ||
elif var >= 1e-3: | ||
var = '%.2f ' % (var * 1e3) | ||
time_format = 'msec' | ||
else: | ||
var = '%.2f ' % (var * 1e6) | ||
time_format = 'usec' | ||
return var, time_format | ||
|
||
def get_time_step(self): | ||
time_list = [] | ||
num_samples = self.get_num_samples() | ||
for t in self.times: | ||
speed = t / num_samples | ||
speed, time_format = self.reformat(speed) | ||
time_list.append(speed + time_format + '/step') | ||
return time_list | ||
|
||
def get_total_time(self): | ||
total_time = sum(self.times) | ||
total_time, time_format = self.reformat(total_time) | ||
return total_time + time_format | ||
|
||
def get_time(self): | ||
time_list = [] | ||
for t in self.times: | ||
time, time_format = self.reformat(t) | ||
time_list.append(time + time_format) | ||
return time_list | ||
|
||
def get_speed(self): | ||
samples_list = [] | ||
num_samples = self.get_num_samples() | ||
for t in self.times: | ||
sample_sec = num_samples / t | ||
sample_sec, time_format = self.reformat(sample_sec) | ||
samples_list.append(sample_sec + 'samples/' + time_format) | ||
return samples_list |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please move the metric constants to constant.py file & iterate over those to increase modularity