Skip to content

Commit

Permalink
[MPS] Lift MSL version to 3.0+ and use relevant helpers (#8719)
Browse files Browse the repository at this point in the history
Summary:
1. Remove the custom atomic add function and use the one provided by MSL 3.0+ instead.
2. Use `MetalShaderLibrary` class.
  • Loading branch information
qqaatw authored Nov 11, 2024
1 parent 66c5629 commit cb9fdbf
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 74 deletions.
87 changes: 14 additions & 73 deletions torchvision/csrc/ops/mps/mps_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace ops {

namespace mps {

static const char* METAL_VISION = R"VISION_METAL(
static at::native::mps::MetalShaderLibrary lib(R"VISION_METAL(
#include <metal_atomic>
#include <metal_stdlib>
Expand All @@ -26,46 +26,15 @@ inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
template <typename T>
inline void atomic_add_float( device T* data_ptr, const T val)
inline void atomic_add_float(device float* data_ptr, const float val)
{
#if __METAL_VERSION__ >= 300
// atomic_float is supported in Metal 3 (macOS Ventura) onward.
device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
#else
// Custom atomic addition implementation
// https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
// https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
// https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)
// Create an atomic uint pointer for atomic transaction.
device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
// Create necessary storage.
uint fetched_uint, assigning_uint;
T fetched_float, assigning_float;
// Replace the value in atom_var with 0 and return the previous value in atom_var.
fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed);
// Read out the previous value as float.
fetched_float = *( (thread T*) &fetched_uint );
// Do addition and represent the addition result in uint for atomic transaction.
assigning_float = fetched_float + val;
assigning_uint = *((thread uint*) &assigning_float);
// atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr).
while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
// If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads.
// Try to assign 0 and get the previously assigned addition result.
uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
T fetched_float_again = *( (thread T*) &fetched_uint_again );
// Re-add again
fetched_float = *((thread T*) &(fetched_uint));
// Previously assigned addition result + addition result from other threads.
assigning_float = fetched_float_again + fetched_float;
assigning_uint = *( (thread uint*) &assigning_float);
}
#endif
atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
}
inline void atomic_add_float(device half* data_ptr, const half val)
{
atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast<float>(val), memory_order_relaxed);
}
template <typename T, typename integer_t>
Expand Down Expand Up @@ -1061,40 +1030,12 @@ REGISTER_PS_ROI_POOL_OP(half, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
)VISION_METAL";

static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> visionLibrary = nil;
if (visionLibrary) {
return visionLibrary;
}

NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
return visionLibrary;
}

static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}

NSError* error = nil;
id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
)VISION_METAL");

psoCache[kernel] = pso;
return pso;
static id<MTLComputePipelineState> visionPipelineState(
id<MTLDevice> device,
const std::string& kernel) {
return lib.getPipelineStateForFunc(kernel);
}

} // namespace mps
Expand Down
1 change: 0 additions & 1 deletion torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@

float spatial_scale_f = static_cast<float>(spatial_scale);

auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

if (grad.numel() == 0) {
Expand Down

0 comments on commit cb9fdbf

Please sign in to comment.