Revert "[MPS] Add logit op (pytorch#95162)"
This reverts commit d96aac8.
pruthvistony committed May 2, 2023
1 parent 46345ea commit 1021f5e
Showing 3 changed files with 0 additions and 162 deletions.
157 changes: 0 additions & 157 deletions aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -244,163 +244,6 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
});
}

void logit_mps_impl(const Tensor& self, c10::optional<double> eps, Tensor& output, const std::string op_name) {
  std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]";

  mps::unary_op(self, output, key,
                ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
                  MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
                                                                     shape:@[@1]
                                                                  dataType:inputTensor.dataType];
                  MPSGraphTensor* logitInputTensor;

                  if (eps.has_value()) {
                    MPSGraphTensor *lowTensor = [mpsGraph constantWithScalar:eps.value()
                                                                       shape:@[@1]
                                                                    dataType:inputTensor.dataType];
                    MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
                                                                        secondaryTensor: lowTensor
                                                                                   name: nil];
                    logitInputTensor = [mpsGraph clampWithTensor:inputTensor
                                                  minValueTensor:lowTensor
                                                  maxValueTensor:highTensor
                                                            name:nil];
                  } else {
                    logitInputTensor = inputTensor;
                  }

                  MPSGraphTensor *oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
                                                                               secondaryTensor: logitInputTensor
                                                                                          name: nil];
                  MPSGraphTensor *outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor
                                                                     secondaryTensor:oneMinusInputTensor
                                                                                name:nil];
                  return [mpsGraph logarithmWithTensor:outputTensor
                                                  name:nil];
                });
}
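
For reference, the reverted graph above computes the standard logit transform: when eps is supplied the input is clamped to [eps, 1 - eps], and then log(x / (1 - x)) is taken. Below is a minimal sketch of the same computation in plain PyTorch; the helper name logit_reference is hypothetical and is not part of this commit.

import torch

def logit_reference(x: torch.Tensor, eps=None) -> torch.Tensor:
    # Optionally clamp to [eps, 1 - eps], mirroring the clampWithTensor step above.
    if eps is not None:
        x = x.clamp(min=eps, max=1.0 - eps)
    # logit(x) = log(x / (1 - x))
    return torch.log(x / (1.0 - x))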

Tensor& logit_out_mps(const Tensor& self,
                      c10::optional<double> eps,
                      Tensor& result) {
  logit_mps_impl(self, eps, result, "logit_out_mps");
  return result;
}

Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
  Tensor result = at::native::empty_mps(
                      self.sizes(),
                      ScalarType::Float,
                      c10::nullopt,
                      kMPS,
                      c10::nullopt,
                      c10::nullopt);
  logit_mps_impl(self, eps, result, "logit_mps");
  return result;
}

TORCH_IMPL_FUNC(logit_backward_out_mps) (
  const Tensor& grad_output,
  const Tensor& input,
  c10::optional<double> eps,
  const Tensor& grad_input)
{
  using namespace mps;

  // Empty output
  if(grad_input.numel() == 0)
    return;

  double eps_ = eps ? eps.value() : -1.0;

  struct CachedGraph : public MPSCachedGraph
  {
    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor *gradOutputTensor_ = nil;
    MPSGraphTensor *inputTensor_ = nil;
    MPSGraphTensor *outputTensor_ = nil;
  };

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
    std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" +
                      "[" + (eps.has_value() ? std::to_string(eps.value()) : "-1" ) + "]";

    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
    if(!cachedGraph) {
      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {

        CachedGraph *newCachedGraph = nil;

        @autoreleasepool {
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);

          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
          MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
          MPSGraphTensor* outputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_input);
          MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
                                                              shape:@[@1]
                                                           dataType:inputTensor.dataType];
          MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
                                                             shape:@[@1]
                                                          dataType:inputTensor.dataType];
          MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_
                                                             shape:@[@1]
                                                          dataType:inputTensor.dataType];
          MPSGraphTensor *inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor: inputTensor
                                                                                secondaryTensor: lowTensor
                                                                                           name: nil];
          MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
                                                              secondaryTensor: lowTensor
                                                                         name: nil];
          MPSGraphTensor *inputGreaterThanHighPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor: inputTensor
                                                                                       secondaryTensor: highTensor
                                                                                                  name: nil];
          MPSGraphTensor* outOfIntervalTensor = [mpsGraph logicalORWithPrimaryTensor: inputLessThanLowPredicateTensor
                                                                     secondaryTensor: inputGreaterThanHighPredicateTensor
                                                                                name: nil];
          MPSGraphTensor *oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
                                                                       secondaryTensor: inputTensor
                                                                                  name: nil];
          outputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
                                                   secondaryTensor:oneMinusInputTensor
                                                              name:nil];
          outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor
                                             secondaryTensor:outputTensor
                                                        name:nil];
          outputTensor = [mpsGraph selectWithPredicateTensor: outOfIntervalTensor
                                         truePredicateTensor: zeroTensor
                                        falsePredicateTensor: outputTensor
                                                        name: nil];

          newCachedGraph->gradOutputTensor_ = gradOutputTensor;
          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->outputTensor_ = outputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
    }
    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
    Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input);

    // Create dictionary of inputs and outputs
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
    };
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
    };
    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }
}
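
The backward graph being removed computes grad_input = grad_output / (x * (1 - x)) and zeroes the result where the input lies outside [eps, 1 - eps] (eps_ falls back to -1.0 when eps is unset). Below is a minimal PyTorch sketch of that formula; the helper name logit_backward_reference is hypothetical and not part of this commit.

import torch

def logit_backward_reference(grad_output: torch.Tensor, x: torch.Tensor, eps=None) -> torch.Tensor:
    eps_ = eps if eps is not None else -1.0
    # d/dx logit(x) = 1 / (x * (1 - x))
    grad = grad_output / (x * (1.0 - x))
    # Zero the gradient where the input lies outside [eps, 1 - eps].
    out_of_interval = (x < eps_) | (x > 1.0 - eps_)
    return torch.where(out_of_interval, torch.zeros_like(grad), grad)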



TORCH_IMPL_FUNC(cumsum_out_mps)
3 changes: 0 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -4931,7 +4931,6 @@
  variants: function, method
  dispatch:
    CPU, CUDA: logit
    MPS: logit_mps
  tags: pointwise

- func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)
@@ -4943,7 +4942,6 @@
- func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: logit_out
    MPS: logit_out_mps
  tags: pointwise

- func: sin(Tensor self) -> Tensor
@@ -12137,7 +12135,6 @@
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: logit_backward_out
    MPS: logit_backward_out_mps
  tags: pointwise

- func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor
2 changes: 0 additions & 2 deletions test/test_mps.py
@@ -9277,7 +9277,6 @@ class TestConsistency(TestCaseMPS):
        'logical_not': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'logit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'],
        'logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9527,7 +9526,6 @@ class TestConsistency(TestCaseMPS):
        'log_softmax': ['f32'],
        'logaddexp': ['f32'],
        'logical_not': ['f16', 'f32'],
        'logit': ['f16', 'f32'],
        'logspace': ['f32'],
        'matmul': ['f32'],
        'mm': ['f32'],
