diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index e720a1608f1..f85546a6c41 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -5,7 +5,7 @@ namespace ops { namespace mps { -static const char* METAL_VISION = R"VISION_METAL( +static at::native::mps::MetalShaderLibrary lib(R"VISION_METAL( #include #include @@ -26,46 +26,15 @@ inline T ceil_div(T n, T m) { return (n + m - 1) / m; } -template -inline void atomic_add_float( device T* data_ptr, const T val) +inline void atomic_add_float(device float* data_ptr, const float val) { -#if __METAL_VERSION__ >= 300 - // atomic_float is supported in Metal 3 (macOS Ventura) onward. - device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); -#else - // Custom atomic addition implementation - // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 - // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 - // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) - - // Create an atomic uint pointer for atomic transaction. - device atomic_uint* atom_var = (device atomic_uint*)data_ptr; - // Create necessary storage. - uint fetched_uint, assigning_uint; - T fetched_float, assigning_float; - - // Replace the value in atom_var with 0 and return the previous value in atom_var. - fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); - // Read out the previous value as float. - fetched_float = *( (thread T*) &fetched_uint ); - - // Do addition and represent the addition result in uint for atomic transaction. - assigning_float = fetched_float + val; - assigning_uint = *((thread uint*) &assigning_float); - - // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). - while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { - // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. - // Try to assign 0 and get the previously assigned addition result. - uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); - T fetched_float_again = *( (thread T*) &fetched_uint_again ); - // Re-add again - fetched_float = *((thread T*) &(fetched_uint)); - // Previously assigned addition result + addition result from other threads. - assigning_float = fetched_float_again + fetched_float; - assigning_uint = *( (thread uint*) &assigning_float); - } -#endif + atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); +} + + +inline void atomic_add_float(device half* data_ptr, const half val) +{ + atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast(val), memory_order_relaxed); } template @@ -1061,40 +1030,12 @@ REGISTER_PS_ROI_POOL_OP(half, int64_t); REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); -)VISION_METAL"; - -static id compileVisionOpsLibrary(id device) { - static id visionLibrary = nil; - if (visionLibrary) { - return visionLibrary; - } - - NSError* error = nil; - MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion:MTLLanguageVersion2_3]; - visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding] - options:options - error:&error]; - TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]); - return visionLibrary; -} - -static id visionPipelineState(id device, const std::string& kernel) { - static std::unordered_map> psoCache; - id pso = psoCache[kernel]; - if (pso) { - return pso; - } - - NSError* error = nil; - id visionLib = compileVisionOpsLibrary(device); - id visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; - TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel); - pso = [device newComputePipelineStateWithFunction:visionFunc error:&error]; - TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); +)VISION_METAL"); - psoCache[kernel] = pso; - return pso; +static id visionPipelineState( + id device, + const std::string& kernel) { + return lib.getPipelineStateForFunc(kernel); } } // namespace mps diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm index fc24f6990fa..75d0ff4845f 100644 --- a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -123,7 +123,6 @@ float spatial_scale_f = static_cast(spatial_scale); - auto num_rois = rois.size(0); auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); if (grad.numel() == 0) {