diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 0c6e5b06d0898..0a07470739088 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -244,6 +244,163 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
   });
 }
 
+void logit_mps_impl(const Tensor& self, c10::optional<double> eps, Tensor& output, const std::string op_name) {
+  std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]";
+
+  mps::unary_op(self, output, key,
+                ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+                  MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
+                                                                     shape:@[@1]
+                                                                  dataType:inputTensor.dataType];
+                  MPSGraphTensor* logitInputTensor;
+
+                  if (eps.has_value()) {
+                    MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps.value()
+                                                                       shape:@[@1]
+                                                                    dataType:inputTensor.dataType];
+                    MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
+                                                                        secondaryTensor:lowTensor
+                                                                                   name:nil];
+                    logitInputTensor = [mpsGraph clampWithTensor:inputTensor
+                                                  minValueTensor:lowTensor
+                                                  maxValueTensor:highTensor
+                                                            name:nil];
+                  } else {
+                    logitInputTensor = inputTensor;
+                  }
+
+                  MPSGraphTensor* oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
+                                                                              secondaryTensor:logitInputTensor
+                                                                                         name:nil];
+                  MPSGraphTensor* outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor
+                                                                     secondaryTensor:oneMinusInputTensor
+                                                                                name:nil];
+                  return [mpsGraph logarithmWithTensor:outputTensor
+                                                  name:nil];
+                });
+}
+
+Tensor& logit_out_mps(const Tensor& self,
+                      c10::optional<double> eps,
+                      Tensor& result) {
+  logit_mps_impl(self, eps, result, "logit_out_mps");
+  return result;
+}
+
+Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
+  Tensor result = at::native::empty_mps(
+                      self.sizes(),
+                      ScalarType::Float,
+                      c10::nullopt,
+                      kMPS,
+                      c10::nullopt,
+                      c10::nullopt);
+  logit_mps_impl(self, eps, result, "logit_mps");
+  return result;
+}
+
+TORCH_IMPL_FUNC(logit_backward_out_mps) (
+  const Tensor& grad_output,
+  const Tensor& input,
+  c10::optional<double> eps,
+  const Tensor& grad_input)
+  {
+    using namespace mps;
+
+    // Empty output
+    if(grad_input.numel() == 0)
+      return;
+
+    double eps_ = eps ? eps.value() : -1.0;
+
+    struct CachedGraph : public MPSCachedGraph
+    {
+      CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+      MPSGraphTensor *gradOutputTensor_ = nil;
+      MPSGraphTensor *inputTensor_ = nil;
+      MPSGraphTensor *outputTensor_ = nil;
+    };
+
+    MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+
+    MPSStream* stream = getCurrentMPSStream();
+
+    @autoreleasepool {
+      std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" +
+                        "[" + (eps.has_value() ? std::to_string(eps.value()) : "-1") + "]";
+
+      CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
+      if(!cachedGraph) {
+        MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+
+          CachedGraph *newCachedGraph = nil;
+
+          @autoreleasepool {
+            MPSGraph* mpsGraph = make_mps_graph();
+            newCachedGraph = new CachedGraph(mpsGraph);
+
+            MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
+            MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
+            MPSGraphTensor* outputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_input);
+            MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
+                                                                shape:@[@1]
+                                                             dataType:inputTensor.dataType];
+            MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
+                                                               shape:@[@1]
+                                                            dataType:inputTensor.dataType];
+            MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_
+                                                               shape:@[@1]
+                                                            dataType:inputTensor.dataType];
+            MPSGraphTensor* inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
+                                                                                  secondaryTensor:lowTensor
+                                                                                             name:nil];
+            MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
+                                                                secondaryTensor:lowTensor
+                                                                           name:nil];
+            MPSGraphTensor* inputGreaterThanHighPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor
+                                                                                         secondaryTensor:highTensor
+                                                                                                    name:nil];
+            MPSGraphTensor* outOfIntervalTensor = [mpsGraph logicalORWithPrimaryTensor:inputLessThanLowPredicateTensor
+                                                                       secondaryTensor:inputGreaterThanHighPredicateTensor
+                                                                                  name:nil];
+            MPSGraphTensor* oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
+                                                                         secondaryTensor:inputTensor
+                                                                                    name:nil];
+            outputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
+                                                     secondaryTensor:oneMinusInputTensor
+                                                                name:nil];
+            outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor
+                                               secondaryTensor:outputTensor
+                                                          name:nil];
+            outputTensor = [mpsGraph selectWithPredicateTensor:outOfIntervalTensor
+                                           truePredicateTensor:zeroTensor
+                                          falsePredicateTensor:outputTensor
+                                                          name:nil];
+
+            newCachedGraph->gradOutputTensor_ = gradOutputTensor;
+            newCachedGraph->inputTensor_ = inputTensor;
+            newCachedGraph->outputTensor_ = outputTensor;
+          }
+          return newCachedGraph;
+        });
+        cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
+      }
+      Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
+      Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
+      Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input);
+
+      // Create dictionary of inputs and outputs
+      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+        gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
+        inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
+      };
+      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+        gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
+      };
+      runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    }
+}
+
 TORCH_IMPL_FUNC(cumsum_out_mps)
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a666027379899..95a09f809f488 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4931,6 +4931,7 @@
   variants: function, method
   dispatch:
     CPU, CUDA: logit
+    MPS: logit_mps
   tags: pointwise
 
 - func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)
@@ -4942,6 +4943,7 @@
 - func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, CUDA: logit_out
+    MPS: logit_out_mps
   tags: pointwise
 
 - func: sin(Tensor self) -> Tensor
@@ -12135,6 +12137,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: logit_backward_out
+    MPS: logit_backward_out_mps
   tags: pointwise
 
 - func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index bed445ee37253..671ebd1eac5dc 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9277,6 +9277,7 @@ class TestConsistency(TestCaseMPS):
         'logical_not': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'logit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9526,6 +9527,7 @@ class TestConsistency(TestCaseMPS):
         'log_softmax': ['f32'],
         'logaddexp': ['f32'],
         'logical_not': ['f16', 'f32'],
+        'logit': ['f16', 'f32'],
         'logspace': ['f32'],
         'matmul': ['f32'],
         'mm': ['f32'],
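For context (not part of the patch): `logit_mps_impl` builds the forward graph as `log(x / (1 - x))`, first clamping `x` to `[eps, 1 - eps]` when `eps` is supplied, matching the CPU/CUDA semantics. A minimal Python sketch of those semantics using only standard `torch` ops (`logit_reference` is a hypothetical name for illustration):

```python
from typing import Optional

import torch

def logit_reference(x: torch.Tensor, eps: Optional[float] = None) -> torch.Tensor:
    # Mirrors the MPSGraph construction above: optionally clamp the input to
    # [eps, 1 - eps], then compute log(x / (1 - x)).
    if eps is not None:
        x = x.clamp(min=eps, max=1.0 - eps)
    return torch.log(x / (1.0 - x))
```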
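Likewise, a quick smoke test of the new dispatch entries might look like the following; this is illustrative only, and the real coverage lives in the `TestConsistency` allow-list updates above:

```python
import torch

if torch.backends.mps.is_available():
    x = torch.rand(8, device="mps", requires_grad=True)
    y = torch.logit(x, eps=1e-6)   # forward: dispatches to logit_mps
    y.sum().backward()             # backward: dispatches to logit_backward_out_mps
    ref = torch.logit(x.detach().cpu(), eps=1e-6)
    assert torch.allclose(y.detach().cpu(), ref, atol=1e-5)
```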