Skip to content

Commit

Permalink
Merge pull request #261 from KernelTuner/update-pmt
Browse files Browse the repository at this point in the history
Update PMTObserver for latest PMT changes
  • Loading branch information
benvanwerkhoven authored Jun 6, 2024
2 parents 99b5c90 + db9fc45 commit 8ebf8b8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/cuda/vector_add_observers_pmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def tune():
tune_params = dict()
tune_params["block_size_x"] = [128+64*i for i in range(15)]

pmtobserver = PMTObserver(["nvml", "rapl"])
pmtobserver = PMTObserver([("nvidia", 0), "rapl"])

metrics = OrderedDict()
metrics["GPU W"] = lambda p: p["nvml_power"]
metrics["GPU W"] = lambda p: p["nvidia_power"]
metrics["CPU W"] = lambda p: p["rapl_power"]

results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, observers=[pmtobserver], metrics=metrics, iterations=32)
Expand Down
8 changes: 4 additions & 4 deletions kernel_tuner/observers/pmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ def __init__(self, observable=None):
if type(observable) is dict:
pass
elif type(observable) is list:
# user specifies a list of platforms as observable
observable = dict([(obs, 0) for obs in observable])
# user specifies a list of platforms as observable, optionally with an argument
observable = dict([obs if isinstance(obs, tuple) else (obs, None) for obs in observable])
else:
# User specifices a string (single platform) as observable
observable = {observable: None}
supported = ["arduino", "jetson", "likwid", "nvml", "rapl", "rocm", "xilinx"]
supported = ["powersensor2", "powersensor3", "nvidia", "likwid", "rapl", "rocm", "xilinx"]
for obs in observable.keys():
if not obs in supported:
raise ValueError(f"Observable {obs} not in supported: {supported}")

self.pms = [pmt.get_pmt(obs[0], obs[1]) for obs in observable.items()]
self.pms = [pmt.create(obs[0], obs[1]) for obs in observable.items()]
self.pm_names = list(observable.keys())

self.begin_states = [None] * len(self.pms)
Expand Down

0 comments on commit 8ebf8b8

Please sign in to comment.