diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 116635c688b..27e148b3c4d 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -920,7 +920,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
 }
 
 at::Tensor XLANativeFunctions::bernoulli(
-    const at::Tensor& self, c10::optional<at::Generator> generator) {
+    const at::Tensor& self, const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -932,7 +932,8 @@ at::Tensor XLANativeFunctions::bernoulli(
 }
 
 at::Tensor XLANativeFunctions::bernoulli(
-    const at::Tensor& self, double p, c10::optional<at::Generator> generator) {
+    const at::Tensor& self, double p,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -944,7 +945,7 @@ at::Tensor XLANativeFunctions::bernoulli(
 
 at::Tensor& XLANativeFunctions::bernoulli_(
     at::Tensor& self, const at::Tensor& p,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -1347,7 +1348,8 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self,
 }
 
 at::Tensor& XLANativeFunctions::exponential_(
-    at::Tensor& self, double lambd, c10::optional<at::Generator> generator) {
+    at::Tensor& self, double lambd,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2030,7 +2032,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
 
 at::Tensor XLANativeFunctions::multinomial(
     const at::Tensor& self, int64_t num_samples, bool replacement,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   XLA_CHECK(num_samples > 0)
       << "Multinomial number of samples must be greater than 0";
   XLA_CHECK(at::isFloatingType(self.scalar_type()))
@@ -2344,8 +2346,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self,
       bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim));
 }
 
-at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    const at::Tensor& mean, double std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2356,8 +2359,9 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
       tensor_methods::normal(bridge::GetXlaTensor(mean), std));
 }
 
-at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    double mean, const at::Tensor& std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2368,9 +2372,9 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
       tensor_methods::normal(mean, bridge::GetXlaTensor(std)));
 }
 
-at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
-                                      const at::Tensor& std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    const at::Tensor& mean, const at::Tensor& std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2383,7 +2387,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
 
 at::Tensor& XLANativeFunctions::normal_(
     at::Tensor& self, double mean, double std,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2527,7 +2531,7 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::qr(
 // The value generated should be within (from, to].
 at::Tensor& XLANativeFunctions::random_(
     at::Tensor& self, int64_t from, c10::optional<int64_t> to,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2547,7 +2551,8 @@ at::Tensor& XLANativeFunctions::random_(
 
 // The value generated should be in (0, to].
 at::Tensor& XLANativeFunctions::random_(
-    at::Tensor& self, int64_t to, c10::optional<at::Generator> generator) {
+    at::Tensor& self, int64_t to,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2563,7 +2568,7 @@ at::Tensor& XLANativeFunctions::random_(
 
 // The value generated should be in (self_type_min, self_type_max).
 at::Tensor& XLANativeFunctions::random_(
-    at::Tensor& self, c10::optional<at::Generator> generator) {
+    at::Tensor& self, const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2736,7 +2741,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self,
 at::Tensor XLANativeFunctions::rrelu_with_noise(
     const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower,
     const at::Scalar& upper, bool training,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     // The fallback path for rrelu_with_noise when training=true is wrong
@@ -3275,7 +3280,7 @@ std::vector<at::Tensor> XLANativeFunctions::unbind_copy(const at::Tensor& self,
 
 at::Tensor& XLANativeFunctions::uniform_(
     at::Tensor& self, double from, double to,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
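
Note (not part of the diff): the change above is mechanical. Every generator argument moves from c10::optional<at::Generator> taken by value to const std::optional<at::Generator>& taken by const reference, while the defined-generator check that routes to the CPU fallback is unchanged. The following stand-alone C++ sketch illustrates only that calling convention; the Generator struct is a mock of at::Generator, and the XLA helpers (call_fallback_fn, xla_cpu_fallback) are replaced by prints.

#include <iostream>
#include <optional>

// Mock of at::Generator; only the defined() probe matters for this sketch.
struct Generator {
  bool defined_ = false;
  bool defined() const { return defined_; }
};

// New convention: take the optional by const reference (no copy at the call
// site), then check has_value() && ->defined() exactly as the diff does.
void bernoulli(const std::optional<Generator>& generator) {
  if (generator.has_value() && generator->defined()) {
    std::cout << "user generator set -> CPU fallback path\n";
    return;
  }
  std::cout << "default XLA RNG path\n";
}

int main() {
  bernoulli(std::nullopt);     // no generator: default path
  bernoulli(Generator{});      // present but undefined: default path
  bernoulli(Generator{true});  // defined generator: fallback path
}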