Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA][Launcher] Ensure device context is valid before calling getPointer #5276

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions python/test/unit/runtime/test_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import sys
from concurrent.futures import ThreadPoolExecutor
import torch

import triton
import triton.language as tl


def test_is_lazy():
Expand All @@ -12,3 +15,27 @@ def test_is_lazy():
assert triton.runtime.driver.active._obj is None
utils = triton.runtime.driver.active.utils # noqa: F841
assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase"))


def test_kernel_in_thread(device):
    """Launching a kernel from a freshly created worker thread must still
    work: the launcher is expected to establish a valid device context
    on threads that have never touched the device before."""
    data = torch.zeros((38016 * 1024, ), dtype=torch.float32, device=device)

    @triton.jit
    def _kernel(P, BLOCK: tl.constexpr):
        bid = tl.program_id(0).to(tl.int64)
        offs = bid * BLOCK + tl.arange(0, BLOCK)

        vals = tl.load(P + offs)
        tl.store(P + offs, vals)

    def launch_and_sync():
        total = data.numel()

        def grid(meta):
            return (triton.cdiv(total, meta["BLOCK"]), )

        _kernel[grid](data, BLOCK=1024)
        getattr(torch, device).synchronize()

    # Warm up on the main thread first, then repeat the launch from a
    # brand-new thread that has no device context of its own yet.
    launch_and_sync()
    with ThreadPoolExecutor(1) as pool:
        pool.submit(launch_and_sync).result()
27 changes: 18 additions & 9 deletions third_party/nvidia/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,6 @@ def format_of(ty):
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
void *params[] = {{ {', '.join(params)} }};
if (gridX*gridY*gridZ > 0) {{
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {{
// Ensure device context.
CUdevice device;
CUDA_CHECK(cuDeviceGet(&device, 0));
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}}
if (num_ctas == 1) {{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}} else {{
Expand Down Expand Up @@ -288,6 +279,9 @@ def format_of(ty):
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}} else if (status != CUDA_SUCCESS) {{
CUDA_CHECK(status); // Catch any other cuda API errors
ptr_info.valid = false;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we get an invalid memory access error in the test is that cuPointerGetAttributes returns CUDA_ERROR_INVALID_CONTEXT, which wasn't being handled here, so we ended up in an invalid state.

}}
ptr_info.dev_ptr = dev_ptr;
Py_DECREF(ret); // Thanks ChatGPT!
Expand Down Expand Up @@ -344,7 +338,22 @@ def format_of(ty):
return (CUtensorMap*)(ptr_as_uint);
}}

// Make sure the calling thread has a current CUDA context before any other
// driver-API call is issued. A thread that has never used CUDA has no
// current context, and calls such as cuPointerGetAttributes then fail with
// CUDA_ERROR_INVALID_CONTEXT.
static void ensureCudaContext() {{
  CUcontext pctx;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {{
    // No context current on this thread: retain the primary context of
    // device 0 and make it current. (NOTE(review): always device 0 —
    // confirm this is the intended device in multi-GPU setups.)
    CUdevice device;
    CUDA_CHECK(cuDeviceGet(&device, 0));
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }}
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
// ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
ensureCudaContext();

int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
Expand Down
Loading