Skip to content

Commit

Permalink
Fixes #316, regards #313: Multiple changes involving the current cont…
Browse files Browse the repository at this point in the history
…ext, primary contexts, and ensuring their existence in various circumstances:

* Renamed: `context::current::detail_::scoped_current_device_fallback_t` -> `scoped_existence_ensurer_t` `context::current::detail_::scoped_context_existence_ensurer`
* `context::current::scoped_override_t` now has a ctor which accepts `primary_context_t&&`'s - to hold on to their PC reference which they are about to let go of.
* Moved: `context::current::scoped_override_t` is now implemented in the multi-wrapper implementations directory; consequently
    * Moved the implementations of  `module_t::get_kernel()` and `module::create<Creator>` to the multi-wrapper directory, since they use `context::current::scoped_override_t`.
    * Added inclusion of `cuda/api/multi_wrapper_impls/module.hpp` to some example code.
* Made a device current in some examples to avoid having no current context when executing certain operations with no wrappers (e.g. memcpy with host-side addresses)
* When allocating managed or pinned-host memory, now increasing the reference count of some context by 1 (choosing the primary context of device 0 since that's the safest), and decreasing it again on destruction. That guarantees that operations involving that allocated memory will not occur with no constructed contexts.
    * Corresponding comment changes on the `allocate()` and `free()` methods for pinned-host and managed memory.
* Factored out the code in `context_t::is_primary()` to a function, `cuda::context::current::detail_::is_primary`, which can now also be used via `cuda::context::current::is_primary()`.
* Kernel launch functions now ensure a launch only occurs / is enqueued within a current context (any context).
* Getting the current device now ensures its primary context is also active (which getting an arbitrary device does not do).
* Added doxygen comment for `device::detail_::wrap()` mentioning the primary context reference behavior.
  • Loading branch information
eyalroz committed Apr 15, 2022
1 parent 475441e commit f090b91
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <cuda/api/virtual_memory.hpp>
#include <cuda/api/multi_wrapper_impls/memory.hpp>
#include <cuda/api/multi_wrapper_impls/module.hpp>

#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
#ifndef WIN32_LEAN_AND_MEAN
Expand Down
10 changes: 7 additions & 3 deletions examples/modified_cuda_samples/simpleStreams/simpleStreams.cu
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ int main(int argc, char **argv)

std::cout << "\n> ";
auto device = chooseCudaDevice(argc, (const char **)argv);
device.make_current();
// This is "necessary", for now, for the memory operations whose API is context-unaware,
// but which would actually fail if the appropriate context is not the current one

// Checking for compute capabilities
auto properties = device.properties();
Expand Down Expand Up @@ -279,11 +282,12 @@ int main(int argc, char **argv)
threads=dim3(512,1);
blocks=dim3(n/(nstreams*threads.x),1);
launch_config = cuda::make_launch_config(blocks, threads);
memset(h_a.get(), 255, nbytes); // set host memory bits to all 1s, for testing correctness
// TODO: Avoid need to push and pop here
cuda::context::current::push(device.primary_context());
memset(h_a.get(), 255, nbytes); // set host memory bits to all 1s, for testing correctness
// This instruction is actually the only one in our program
// for which the device.make_current() command was necessary.
// TODO: Avoid having to do that altogether...
cuda::memory::device::zero(cuda::memory::region_t{d_a.get(), nbytes}); // set device memory to all 0s, for testing correctness
cuda::context::current::pop();
start_event.record();

for (int k = 0; k < nreps; k++)
Expand Down
48 changes: 24 additions & 24 deletions src/cuda/api/current_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,31 +139,38 @@ inline void set(handle_t context_handle)
*
*/
class scoped_override_t {
protected:
public:
explicit scoped_override_t(handle_t context_handle) { push(context_handle); }
~scoped_override_t() { pop(); }

// explicit scoped_context_override_t(handle_t context_handle_) :
// did_push(push_if_not_on_top(context_handle_)) { }
// scoped_context_override_terride_t() { if (did_push) { pop(); } }
//
//protected:
// bool did_push;
bool hold_primary_context_ref_unit_;
device::id_t device_id_or_0_;

explicit scoped_override_t(handle_t context_handle) : scoped_override_t(false, 0, context_handle) {}
scoped_override_t(device::id_t device_for_which_context_is_primary, handle_t context_handle)
: scoped_override_t(true, device_for_which_context_is_primary, context_handle) {}
explicit scoped_override_t(bool hold_primary_context_ref_unit, device::id_t device_id, handle_t context_handle);
~scoped_override_t();
};

/**
* @note See also the more complex @ref cuda::context::current::scoped_existence_ensurer_t ,
* which does _not_ take a fallback context handle, and rather obtains a reference to
* a primary context on its own.
*/
class scoped_ensurer_t {
public:
bool push_needed;
bool context_was_pushed_on_construction;

explicit scoped_ensurer_t(handle_t fallback_context_handle) : push_needed(not exists())
explicit scoped_ensurer_t(bool force_push, handle_t fallback_context_handle)
: context_was_pushed_on_construction(force_push)
{
if (push_needed) { push(fallback_context_handle); }
if (force_push) { push(fallback_context_handle); }
}
~scoped_ensurer_t() { if (push_needed) { pop(); } }
};

class scoped_current_device_fallback_t;
explicit scoped_ensurer_t(handle_t fallback_context_handle)
: scoped_ensurer_t(not exists(), fallback_context_handle)
{}

~scoped_ensurer_t() { if (context_was_pushed_on_construction) { pop(); } }
};

} // namespace detail_

Expand All @@ -180,14 +187,7 @@ class scoped_current_device_fallback_t;
* pushed.
*
*/
class scoped_override_t : private detail_::scoped_override_t {
protected:
using parent = detail_::scoped_override_t;
public:
explicit scoped_override_t(const context_t& device);
explicit scoped_override_t(context_t&& device);
~scoped_override_t() = default;
};
class scoped_override_t;

/**
* This macro will set the current device for the remainder of the scope in which it is
Expand Down
14 changes: 13 additions & 1 deletion src/cuda/api/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ using shared_memory_bank_size_t = context::shared_memory_bank_size_t;

namespace detail_ {

/**
* Construct a @ref device_t wrapper class instance for a given device ID
*
* @param id Numeric id (mostly an ordinal) of the device to wrap
* @param primary_context_handle if this is not the "none" value, the wrapper object
* will be owning a reference unit to the device's primary context, which it
* will release on destruction. Use this to allow runtime-API-style code, which does
* not explicitly construct contexts, to be able to function with a primary context
* being made and kept active.
*/
device_t wrap(id_t id, primary_context::handle_t primary_context_handle = context::detail_::none) noexcept;

} // namespace detail
Expand Down Expand Up @@ -679,7 +689,9 @@ namespace current {
inline device_t get()
{
ensure_driver_is_initialized();
return device::get(detail_::get_id());
auto id = detail_::get_id();
auto pc_handle = primary_context::detail_::obtain_and_increase_refcount(id);
return device::detail_::wrap(id, pc_handle);
}

/**
Expand Down
17 changes: 16 additions & 1 deletion src/cuda/api/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1499,13 +1499,18 @@ inline void* allocate(size_t size_in_bytes, cpu_write_combining cpu_wc)

/**
 * Free a region of pinned host memory which was allocated with @ref allocate.
 *
 * @note You can't just use @ref cuMemFreeHost - or you'll leak a primary context reference unit.
 *
 * @param host_ptr address of the pinned host memory to free
 */
inline void free(void* host_ptr)
{
    auto result = cuMemFreeHost(host_ptr);
    // Give up the primary context reference unit only after the memory itself has
    // been freed: releasing it first could leave the free to run with no constructed
    // context, if that unit was the last one. The unit is released whether or not the
    // free succeeded, and before any error is reported - the same sequence as is used
    // when freeing managed memory.
    cuda::device::primary_context::detail_::decrease_refcount(cuda::device::default_device_id);
    throw_if_error(result, "Freeing pinned host memory at " + cuda::detail_::ptr_as_hex(host_ptr));
}

/// Convenience overload: frees the memory starting at the address of @p region,
/// by forwarding to the pointer-taking @ref free
inline void free(region_t region)
{
    free(region.data());
}

namespace detail_ {

struct allocator {
Expand Down Expand Up @@ -1804,6 +1809,13 @@ inline region_t allocate_in_current_context(
device::address_t allocated = 0;
auto flags = (initial_visibility == initial_visibility_t::to_all_devices) ?
attachment_t::global : attachment_t::host;
// This is necessary because managed allocation requires at least one (primary)
// context to have been constructed. We could theoretically check what our current
// context is etc., but that would be brittle, since someone can managed-allocate,
// then change contexts, then de-allocate, and we can't be certain that whoever
// called us will call free
cuda::device::primary_context::detail_::increase_refcount(cuda::device::default_device_id);

// Note: Despite the templating by T, the size is still in bytes,
// not in number of T's
auto status = cuMemAllocManaged(&allocated, num_bytes, (unsigned) flags);
Expand All @@ -1817,12 +1829,15 @@ inline region_t allocate_in_current_context(
}

/**
* Free a region of pinned host memory which was allocated with @ref allocate.
* Free a region of managed memory which was allocated with @ref allocate_in_current_context.
*
* @note You can't just use @ref cuMemFree - or you'll leak a primary context reference unit.
*/
///@{
inline void free(void* ptr)
{
    // First release the memory itself...
    auto managed_address = device::address(ptr);
    auto status = cuMemFree(managed_address);

    // ... then give up a reference unit of the default device's primary context;
    // this happens whether or not the free succeeded, and before any error is reported
    cuda::device::primary_context::detail_::decrease_refcount(cuda::device::default_device_id);

    throw_if_error(status, "Freeing managed memory at 0x" + cuda::detail_::ptr_as_hex(ptr));
}
inline void free(region_t region)
Expand Down
24 changes: 2 additions & 22 deletions src/cuda/api/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,7 @@ class module_t {
device_t device() const;

// These API calls are not really the way you want to work.
cuda::kernel_t get_kernel(const char* name) const
{
context::current::detail_::scoped_override_t set_context_for_this_scope(context_handle_);
kernel::handle_t kernel_function_handle;
auto result = cuModuleGetFunction(&kernel_function_handle, handle_, name);
throw_if_error(result, ::std::string("Failed obtaining function ") + name
+ " from " + module::detail_::identify(*this));
return kernel::wrap(context::detail_::get_device_id(context_handle_), context_handle_, kernel_function_handle);
}
cuda::kernel_t get_kernel(const char* name) const;

memory::region_t get_global_region(const char* name) const
{
Expand Down Expand Up @@ -250,19 +242,7 @@ inline module_t construct(
}

template <typename Creator>
inline module_t create(const context_t& context, const void* module_data, Creator creator_function)
{
context::current::scoped_override_t set_context_for_this_scope(context);
handle_t new_module_handle;
auto status = creator_function(new_module_handle, module_data);
throw_if_error(status, ::std::string(
"Failed loading a module from memory location ") + cuda::detail_::ptr_as_hex(module_data) +
" within " + context::detail_::identify(context));
bool do_take_ownership { true };
// TODO: Make sure the default-constructed options correspond to what cuModuleLoadData uses as defaults
return detail_::construct(context.device_id(), context.handle(), new_module_handle,
link::options_t{}, do_take_ownership);
}
module_t create(const context_t& context, const void* module_data, Creator creator_function);

// TODO: Consider adding create_module() methods to context_t
inline module_t create(const context_t& context, const void* module_data, const link::options_t& link_options)
Expand Down
101 changes: 71 additions & 30 deletions src/cuda/api/multi_wrapper_impls/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ namespace current {

namespace detail_ {

/**
 * Checks whether a context is the primary context of its device.
 *
 * @param cc_handle handle of the context to check
 * @param current_context_device_id id of the device which the context is assumed
 * to belong to; this assumption is not verified here
 * @return true if @p cc_handle is equal to the primary context handle obtained
 * for the device with id @p current_context_device_id
 */
inline bool is_primary(handle_t cc_handle, device::id_t current_context_device_id)
{
    // The device id is taken on trust as being the one for cc_handle; were that not
    // the case, is_primary_for_device() would be the thing to use instead
    auto primary_context_handle =
        device::primary_context::detail_::get_handle(current_context_device_id);
    return cc_handle == primary_context_handle;
}

} // namespace detail_

/**
 * Checks whether the current context is the primary context of its device.
 *
 * @return the result of @ref detail_::is_primary for the handle and device id
 * of the context obtained with @ref get
 */
inline bool is_primary()
{
    auto current_ctx = get();
    auto current_ctx_handle = current_ctx.handle();
    auto current_ctx_device_id = current_ctx.device_id();
    return detail_::is_primary(current_ctx_handle, current_ctx_device_id);
}

namespace detail_ {

/**
 * Pushes a context, making it the current one; optionally also takes a
 * reference unit of a device's primary context first.
 *
 * @param hold_primary_context_ref_unit when true, the reference count of the primary
 * context of the device with id @p device_id is increased before the push
 * @param device_id id of the device whose primary context reference count may be
 * increased; stored regardless of whether it is used
 * @param context_handle handle of the context to push
 */
inline scoped_override_t::scoped_override_t(
    bool          hold_primary_context_ref_unit,
    device::id_t  device_id,
    handle_t      context_handle)
:
    hold_primary_context_ref_unit_(hold_primary_context_ref_unit),
    device_id_or_0_(device_id)
{
    if (hold_primary_context_ref_unit) {
        device::primary_context::detail_::increase_refcount(device_id);
    }
    push(context_handle);
}

/**
 * Undoes the effects of construction, in reverse order: the pushed context is
 * popped first, and only then is the primary context reference unit (if one was
 * taken on construction) given up. This way the reference is not dropped while
 * the context is still on this thread's context stack.
 */
inline scoped_override_t::~scoped_override_t()
{
    pop();
    if (hold_primary_context_ref_unit_) {
        device::primary_context::detail_::decrease_refcount(device_id_or_0_);
    }
}


/**
* @todo This function is a bit shady, consider dropping it.
*/
Expand All @@ -77,50 +108,62 @@ inline handle_t push_default_if_missing()

/**
* @note This specialized scope setter is used in API calls which aren't provided a context
* as a parameter, and when there is no context that's current. Such API calls are necessarily
* device-related (i.e. runtime-API-ish), and since there is always a current device, we can
* (and in fact must) fall back on that device's primary context as what the user assumes we
* would use. In these situations, we must also "leak" that device's primary context, in the
* sense of adding to its reference count without ever decreasing it again - since the only
* juncture at which we can decrease it is the scoped context setter's fallback; and if we
* do that, we will actually trigger the destruction of that primary context. As a consequence,
* if one _ever_ uses an API wrapper call which relies on this scoped context setter, the only
* way for them to destroy the primary context is either via @ref device_t::reset() (or
* manually decreasing the reference count to zero, which supposedly they will not do).
* as a parameter, and when it may be the case that no context is current. Such API calls
* are generally supposed to be independent of a specific context; but - CUDA still often
* expects some context to exist and be current to perform whatever it is we want it to do.
* It would be unreasonable to create new contexts for the purposes of such calls - as then,
* the caller would often need to maintain these contexts after the call. Instead, we fall
* back on a primary context of one of the devices - and since no particular device is
* specified, we choose that to be the default device. When we do want the caller to keep
* a context alive - we increase the primary context's reference count, keeping it alive
* automatically. In these situations, the ref unit "leaks" past the scope of the ensurer
* object - but the instantiator would be aware of this, having asked for such behavior
* explicitly; and would itself carry the onus of decreasing the ref unit at some point.
*
* @note not sure about how appropriate it is to pop the created primary context off
* @note See also the simpler @ref cuda::context::current::scoped_ensurer_t ,
* which takes the context handle to push in the first place.
*/
class scoped_current_device_fallback_t {
class scoped_existence_ensurer_t {
public:
context::handle_t maybe_pc_handle_;
device::id_t device_id_;
context::handle_t pc_handle_ { context::detail_::none };
bool decrease_pc_refcount_on_destruct_;

explicit scoped_current_device_fallback_t()
explicit scoped_existence_ensurer_t(bool decrease_pc_refcount_on_destruct = true)
: maybe_pc_handle_(get_handle()),
decrease_pc_refcount_on_destruct_(decrease_pc_refcount_on_destruct)
{
auto current_context_handle = get_handle();
if (current_context_handle == context::detail_::none) {
if (maybe_pc_handle_ == context::detail_::none) {
device_id_ = device::current::detail_::get_id();
pc_handle_ = device::primary_context::detail_::obtain_and_increase_refcount(device_id_);
context::current::detail_::push(pc_handle_);
maybe_pc_handle_ = device::primary_context::detail_::obtain_and_increase_refcount(device_id_);
context::current::detail_::push(maybe_pc_handle_);
}
else { decrease_pc_refcount_on_destruct_ = false; }
}

~scoped_current_device_fallback_t()
~scoped_existence_ensurer_t()
{
// if (pc_handle_ != context::detail_::none) {
// context::current::detail_::pop();
// device::primary_context::detail_::decrease_refcount(device_id_);
// }
if (maybe_pc_handle_ != context::detail_::none and decrease_pc_refcount_on_destruct_) {
context::current::detail_::pop();
device::primary_context::detail_::decrease_refcount(device_id_);
}
}
};

} // namespace detail_

inline scoped_override_t::scoped_override_t(const context_t &context) : parent(context.handle())
{}
class scoped_override_t : private detail_::scoped_override_t {
protected:
using parent = detail_::scoped_override_t;
public:

explicit scoped_override_t(device::primary_context_t&& primary_context)
: parent(primary_context.is_owning(), primary_context.device_id(), primary_context.handle()) {}
explicit scoped_override_t(const context_t& context) : parent(context.handle()) {}
explicit scoped_override_t(context_t&& context) : parent(context.handle()) {}
~scoped_override_t() = default;
};

inline scoped_override_t::scoped_override_t(context_t &&context) : parent(context.handle())
{}

} // namespace current

Expand Down Expand Up @@ -216,9 +259,7 @@ inline context_t context_t::global_memory_type::associated_context() const

inline bool context_t::is_primary() const
{
auto pc_handle = device::primary_context::detail_::obtain_and_increase_refcount(device_id_);
device::primary_context::detail_::decrease_refcount(device_id_);
return handle_ == pc_handle;
return context::current::detail_::is_primary(handle(), device_id());
}

template <typename ContiguousContainer,
Expand Down
Loading

0 comments on commit f090b91

Please sign in to comment.