diff --git a/kernel_tuner/observers/pmt.py b/kernel_tuner/observers/pmt.py index bb1d76bd..4f85b1b3 100644 --- a/kernel_tuner/observers/pmt.py +++ b/kernel_tuner/observers/pmt.py @@ -30,7 +30,7 @@ class PMTObserver(BenchmarkObserver): """ - def __init__(self, observable=None): + def __init__(self, observable=None, use_continuous_observer=False, continuous_duration=1): if not pmt: raise ImportError("could not import pmt") @@ -54,6 +54,9 @@ def __init__(self, observable=None): self.begin_states = [None] * len(self.pms) self.initialize_results(self.pm_names) + if use_continuous_observer: + self.continuous_observer = ContinuousObserver("pmt", [], self, continuous_duration=continuous_duration) + def initialize_results(self, pm_names): self.results = dict() for pm_name in pm_names: @@ -82,3 +85,37 @@ def get_results(self): averages = {key: np.average(values) for key, values in self.results.items()} self.initialize_results(self.pm_names) return averages + + +class PMTContinuousObserver(ContinuousObserver): + """Generic observer that measures power while and continuous benchmarking. + + To support continuous benchmarking an Observer should support: + a .read_power() method, which the ContinuousObserver can call to read power in Watt + """ + def before_start(self): + pass + + def after_start(self): + self.parent.after_start() + + def during(self): + pass + + def after_finish(self): + self.parent.after_finish() + + def get_results(self): + average_kernel_execution_time_ms = self.results["time"] + + averages = {key: np.average(values) for key, values in self.results.items()} + self.parent.initialize_results(self.pm_names) + + # correct energy measurement, because current _energy number is collected over the entire duration + # we estimate energy as the average power over the continuous duration times the kernel execution time + for pm_name in pm_names: + energy_result_name = f"{pm_name}_energy" + power_result_name = f"{pm_name}_power" + averages[energy_result_name] = averages[power_result_name] * (average_kernel_execution_time_ms / 1e3) + + return averages