diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index 0a07470739088..0c6e5b06d0898 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -244,163 +244,6 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una }); } -void logit_mps_impl(const Tensor& self, c10::optional eps, Tensor& output, const std::string op_name) { - std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]"; - - mps::unary_op(self, output, key, - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* logitInputTensor; - - if (eps.has_value()) { - MPSGraphTensor *lowTensor = [mpsGraph constantWithScalar:eps.value() - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor - secondaryTensor: lowTensor - name: nil]; - logitInputTensor = [mpsGraph clampWithTensor:inputTensor - minValueTensor:lowTensor - maxValueTensor:highTensor - name:nil]; - } else { - logitInputTensor = inputTensor; - } - - MPSGraphTensor *oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor - secondaryTensor: logitInputTensor - name: nil]; - MPSGraphTensor *outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor - secondaryTensor:oneMinusInputTensor - name:nil]; - return [mpsGraph logarithmWithTensor:outputTensor - name:nil]; - }); -} - -Tensor& logit_out_mps(const Tensor& self, - c10::optional eps, - Tensor& result) { - logit_mps_impl(self, eps, result, "logit_out_mps"); - return result; -} - -Tensor logit_mps(const Tensor& self, c10::optional eps) { - Tensor result = at::native::empty_mps( - self.sizes(), - ScalarType::Float, - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); - logit_mps_impl(self, eps, result, "logit_mps"); - return result; -} - -TORCH_IMPL_FUNC(logit_backward_out_mps) ( - const Tensor& grad_output, - const Tensor& input, - c10::optional eps, - const Tensor& grad_input) - { - using namespace mps; - - // Empty output - if(grad_input.numel() == 0) - return; - - double eps_ = eps ? eps.value() : -1.0; - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" + - "[" + (eps.has_value() ? std::to_string(eps.value()) : "-1" ) + "]"; - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* outputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_input); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_ - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor *inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor: inputTensor - secondaryTensor: lowTensor - name: nil]; - MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor - secondaryTensor: lowTensor - name: nil]; - MPSGraphTensor *inputGreaterThanHighPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor: inputTensor - secondaryTensor: highTensor - name: nil]; - MPSGraphTensor* outOfIntervalTensor = [mpsGraph logicalORWithPrimaryTensor: inputLessThanLowPredicateTensor - secondaryTensor: inputGreaterThanHighPredicateTensor - name: nil]; - MPSGraphTensor *oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor - secondaryTensor: inputTensor - name: nil]; - outputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:oneMinusInputTensor - name:nil]; - outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor - secondaryTensor:outputTensor - name:nil]; - outputTensor = [mpsGraph selectWithPredicateTensor: outOfIntervalTensor - truePredicateTensor: zeroTensor - falsePredicateTensor: outputTensor - name: nil]; - - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } -} - TORCH_IMPL_FUNC(cumsum_out_mps) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 95a09f809f488..a666027379899 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4931,7 +4931,6 @@ variants: function, method dispatch: CPU, CUDA: logit - MPS: logit_mps tags: pointwise - func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) @@ -4943,7 +4942,6 @@ - func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: logit_out - MPS: logit_out_mps tags: pointwise - func: sin(Tensor self) -> Tensor @@ -12137,7 +12135,6 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: logit_backward_out - MPS: logit_backward_out_mps tags: pointwise - func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor diff --git a/test/test_mps.py b/test/test_mps.py index 671ebd1eac5dc..bed445ee37253 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9277,7 +9277,6 @@ class TestConsistency(TestCaseMPS): 'logical_not': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'], 'logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], @@ -9527,7 +9526,6 @@ class TestConsistency(TestCaseMPS): 'log_softmax': ['f32'], 'logaddexp': ['f32'], 'logical_not': ['f16', 'f32'], - 'logit': ['f16', 'f32'], 'logspace': ['f32'], 'matmul': ['f32'], 'mm': ['f32'],