Skip to content

Commit

Permalink
Merge branch 'master' into directives
Browse files Browse the repository at this point in the history
  • Loading branch information
isazi committed Jun 6, 2024
2 parents 8336cd0 + 41ae1d2 commit 99b5c90
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
</div>

---
[![Build Status](https://github.com/KernelTuner/kernel_tuner/actions/workflows/build-test-python-package.yml/badge.svg)](https://github.com/KernelTuner/kernel_tuner/actions/workflows/build-test-python-package.yml)
[![Build Status](https://github.com/KernelTuner/kernel_tuner/actions/workflows/test-python-package.yml/badge.svg)](https://github.com/KernelTuner/kernel_tuner/actions/workflows/test-python-package.yml)
[![CodeCov Badge](https://codecov.io/gh/KernelTuner/kernel_tuner/branch/master/graph/badge.svg)](https://codecov.io/gh/KernelTuner/kernel_tuner)
[![PyPi Badge](https://img.shields.io/pypi/v/kernel_tuner.svg?colorB=blue)](https://pypi.python.org/pypi/kernel_tuner/)
[![Zenodo Badge](https://zenodo.org/badge/54894320.svg)](https://zenodo.org/badge/latestdoi/54894320)
Expand Down
8 changes: 8 additions & 0 deletions doc/source/observers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,11 @@ More information about PMT can be found here: https://git.astron.nl/RD/pmt/



NCUObserver
~~~~~~~~~~~

The NCUObserver can be used to automatically extract performance counters during tuning using Nvidia's NsightCompute profiler.
The NCUObserver relies on an intermediate library, which can be found here: https://github.com/nlesc-recruit/nvmetrics

.. autoclass:: kernel_tuner.observers.ncu.NCUObserver

57 changes: 57 additions & 0 deletions examples/cuda/vector_add_observers_ncu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python
"""This is the minimal example from the README"""
import json

import numpy
from kernel_tuner import tune_kernel
from kernel_tuner.observers.ncu import NCUObserver

def tune():

kernel_string = """
__global__ void vector_add(float *c, float *a, float *b, int n) {
int i = blockIdx.x * block_size_x + threadIdx.x;
if (i<n) {
c[i] = a[i] + b[i];
}
}
"""

size = 80000000

a = numpy.random.randn(size).astype(numpy.float32)
b = numpy.random.randn(size).astype(numpy.float32)
c = numpy.zeros_like(b)
n = numpy.int32(size)

args = [c, a, b, n]

tune_params = dict()
tune_params["block_size_x"] = [128+64*i for i in range(15)]

ncu_metrics = ["dram__bytes.sum", # Counter byte # of bytes accessed in DRAM
"dram__bytes_read.sum", # Counter byte # of bytes read from DRAM
"dram__bytes_write.sum", # Counter byte # of bytes written to DRAM
"smsp__sass_thread_inst_executed_op_fadd_pred_on.sum", # Counter inst # of FADD thread instructions executed where all predicates were true
"smsp__sass_thread_inst_executed_op_ffma_pred_on.sum", # Counter inst # of FFMA thread instructions executed where all predicates were true
"smsp__sass_thread_inst_executed_op_fmul_pred_on.sum", # Counter inst # of FMUL thread instructions executed where all predicates were true
]

ncuobserver = NCUObserver(metrics=ncu_metrics)

def total_fp32_flops(p):
return p["smsp__sass_thread_inst_executed_op_fadd_pred_on.sum"] + 2 * p["smsp__sass_thread_inst_executed_op_ffma_pred_on.sum"] + p["smsp__sass_thread_inst_executed_op_fmul_pred_on.sum"]

metrics = dict()
metrics["GFLOP/s"] = lambda p: (total_fp32_flops(p) / 1e9) / (p["time"]/1e3)
metrics["Expected GFLOP/s"] = lambda p: (size / 1e9) / (p["time"]/1e3)
metrics["GB/s"] = lambda p: (p["dram__bytes.sum"] / 1e9) / (p["time"]/1e3)
metrics["Expected GB/s"] = lambda p: (size*4*3 / 1e9) / (p["time"]/1e3)

results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, observers=[ncuobserver], metrics=metrics, iterations=7)

return results


if __name__ == "__main__":
tune()
43 changes: 28 additions & 15 deletions kernel_tuner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from kernel_tuner.backends.opencl import OpenCLFunctions
from kernel_tuner.backends.hip import HipFunctions
from kernel_tuner.observers.nvml import NVMLObserver
from kernel_tuner.observers.observer import ContinuousObserver, OutputObserver
from kernel_tuner.observers.observer import ContinuousObserver, OutputObserver, PrologueObserver

try:
import torch
Expand Down Expand Up @@ -314,11 +314,13 @@ def __init__(
)
else:
raise ValueError("Sorry, support for languages other than CUDA, OpenCL, HIP, C, and Fortran is not implemented yet")
self.dev = dev

# look for NVMLObserver in observers, if present, enable special tunable parameters through nvml
self.use_nvml = False
self.continuous_observers = []
self.output_observers = []
self.prologue_observers = []
if observers:
for obs in observers:
if isinstance(obs, NVMLObserver):
Expand All @@ -328,49 +330,61 @@ def __init__(
self.continuous_observers.append(obs.continuous_observer)
if isinstance(obs, OutputObserver):
self.output_observers.append(obs)
if isinstance(obs, PrologueObserver):
self.prologue_observers.append(obs)

# Take list of observers from self.dev because Backends tend to add their own observer
self.benchmark_observers = [
obs for obs in self.dev.observers if not isinstance(obs, (ContinuousObserver, PrologueObserver))
]

self.iterations = iterations

self.lang = lang
self.dev = dev
self.units = dev.units
self.name = dev.name
self.max_threads = dev.max_threads
if not quiet:
print("Using: " + self.dev.name)

def benchmark_prologue(self, func, gpu_args, threads, grid, result):
"""Benchmark prologue one kernel execution per PrologueObserver"""

for obs in self.prologue_observers:
self.dev.synchronize()
obs.before_start()
self.dev.run_kernel(func, gpu_args, threads, grid)
self.dev.synchronize()
obs.after_finish()
result.update(obs.get_results())

def benchmark_default(self, func, gpu_args, threads, grid, result):
"""Benchmark one kernel execution at a time."""
observers = [
obs for obs in self.dev.observers if not isinstance(obs, ContinuousObserver)
]
"""Benchmark one kernel execution for 'iterations' at a time"""

self.dev.synchronize()
for _ in range(self.iterations):
for obs in observers:
for obs in self.benchmark_observers:
obs.before_start()
self.dev.synchronize()
self.dev.start_event()
self.dev.run_kernel(func, gpu_args, threads, grid)
self.dev.stop_event()
for obs in observers:
for obs in self.benchmark_observers:
obs.after_start()
while not self.dev.kernel_finished():
for obs in observers:
for obs in self.benchmark_observers:
obs.during()
time.sleep(1e-6) # one microsecond
self.dev.synchronize()
for obs in observers:
for obs in self.benchmark_observers:
obs.after_finish()

for obs in observers:
for obs in self.benchmark_observers:
result.update(obs.get_results())

def benchmark_continuous(self, func, gpu_args, threads, grid, result, duration):
"""Benchmark continuously for at least 'duration' seconds"""
iterations = int(np.ceil(duration / (result["time"] / 1000)))
# print(f"{iterations=} {(result['time']/1000)=}")
self.dev.synchronize()
for obs in self.continuous_observers:
obs.before_start()
Expand Down Expand Up @@ -420,9 +434,8 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett

result = {}
try:
self.benchmark_default(
func, gpu_args, instance.threads, instance.grid, result
)
self.benchmark_prologue(func, gpu_args, instance.threads, instance.grid, result)
self.benchmark_default(func, gpu_args, instance.threads, instance.grid, result)

if self.continuous_observers:
duration = 1
Expand Down
2 changes: 1 addition & 1 deletion kernel_tuner/observers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .observer import BenchmarkObserver, IterationObserver, ContinuousObserver, OutputObserver
from .observer import BenchmarkObserver, IterationObserver, ContinuousObserver, OutputObserver, PrologueObserver
41 changes: 41 additions & 0 deletions kernel_tuner/observers/ncu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from kernel_tuner.observers import PrologueObserver

try:
import nvmetrics
except (ImportError):
nvmetrics = None

class NCUObserver(PrologueObserver):
"""``NCUObserver`` measures performance counters.
The exact performance counters supported differ per GPU, some examples:
* "dram__bytes.sum", # Counter byte # of bytes accessed in DRAM
* "dram__bytes_read.sum", # Counter byte # of bytes read from DRAM
* "dram__bytes_write.sum", # Counter byte # of bytes written to DRAM
* "smsp__sass_thread_inst_executed_op_fadd_pred_on.sum", # Counter inst # of FADD thread instructions executed where all predicates were true
* "smsp__sass_thread_inst_executed_op_ffma_pred_on.sum", # Counter inst # of FFMA thread instructions executed where all predicates were true
* "smsp__sass_thread_inst_executed_op_fmul_pred_on.sum", # Counter inst # of FMUL thread instructions executed where all predicates were true
:param metrics: The metrics to observe. This should be a list of strings.
You can use ``ncu --query-metrics`` to get a list of valid metrics.
:type metrics: list[str]
"""

def __init__(self, metrics=None, device=0):
if not nvmetrics:
raise ImportError("could not import nvmetrics")

self.metrics = metrics
self.device = device
self.results = dict()

def before_start(self):
nvmetrics.measureMetricsStart(self.metrics, self.device)

def after_finish(self):
self.results = nvmetrics.measureMetricsStop()

def get_results(self):
return dict(zip(self.metrics, self.results))
11 changes: 11 additions & 0 deletions kernel_tuner/observers/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,15 @@ def process_output(self, answer, output):
"""
pass

class PrologueObserver(BenchmarkObserver):
"""Observer that measures something in a seperate kernel invocation prior to the normal benchmark."""

@abstractmethod
def before_start(self):
"""prologue start is called before the kernel starts"""
pass

@abstractmethod
def after_finish(self):
"""prologue finish is called after the kernel has finished execution"""
pass

0 comments on commit 99b5c90

Please sign in to comment.