Change ATEN generator argument type to const std::optional<Generator>& (pytorch#6686)

Co-authored-by: cyy <cyyever@outlook.com>

Summary:
Adopt the changes in pytorch/pytorch#120076.

Note that the original author is @cyyever; this commit is copied from pytorch#6595.

Test Plan:
CI
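
For context on what the new signature buys: passing the generator as const std::optional<at::Generator>& binds the caller's optional by reference instead of copying it (and the reference-counted at::Generator handle inside it) on every dispatch. Below is a minimal, self-contained sketch of the before/after pattern; the Generator stand-in and the function names are illustrative only, not taken from this file:

#include <optional>

// Illustrative stand-in for at::Generator, which holds a reference-counted
// pointer to its implementation, so copying it is not free.
struct Generator {
  bool defined() const { return has_impl; }
  bool has_impl = false;
};

// Before: each call copies the optional, and the Generator it may hold.
bool old_style(std::optional<Generator> generator) {
  return generator.has_value() && generator->defined();
}

// After: the caller's optional is bound by const reference; nothing is copied.
bool new_style(const std::optional<Generator>& generator) {
  return generator.has_value() && generator->defined();
}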
alanwaketan authored Mar 7, 2024
1 parent 46e2230 commit 8078b8f
41 changes: 23 additions & 18 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -893,7 +893,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
 }
 
 at::Tensor XLANativeFunctions::bernoulli(
-    const at::Tensor& self, c10::optional<at::Generator> generator) {
+    const at::Tensor& self, const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -905,7 +905,8 @@ at::Tensor XLANativeFunctions::bernoulli(
 }
 
 at::Tensor XLANativeFunctions::bernoulli(
-    const at::Tensor& self, double p, c10::optional<at::Generator> generator) {
+    const at::Tensor& self, double p,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -917,7 +918,7 @@ at::Tensor XLANativeFunctions::bernoulli(
 
 at::Tensor& XLANativeFunctions::bernoulli_(
     at::Tensor& self, const at::Tensor& p,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -1308,7 +1309,8 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self,
 }
 
 at::Tensor& XLANativeFunctions::exponential_(
-    at::Tensor& self, double lambd, c10::optional<at::Generator> generator) {
+    at::Tensor& self, double lambd,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -1991,7 +1993,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
 
 at::Tensor XLANativeFunctions::multinomial(
     const at::Tensor& self, int64_t num_samples, bool replacement,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   XLA_CHECK(num_samples > 0)
       << "Multinomial number of samples must be greater than 0";
   XLA_CHECK(at::isFloatingType(self.scalar_type()))
@@ -2305,8 +2307,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self,
       bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim));
 }
 
-at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    const at::Tensor& mean, double std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2317,8 +2320,9 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
       tensor_methods::normal(bridge::GetXlaTensor(mean), std));
 }
 
-at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    double mean, const at::Tensor& std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2329,9 +2333,9 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
       tensor_methods::normal(mean, bridge::GetXlaTensor(std)));
 }
 
-at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
-                                      const at::Tensor& std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    const at::Tensor& mean, const at::Tensor& std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2344,7 +2348,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
 
 at::Tensor& XLANativeFunctions::normal_(
     at::Tensor& self, double mean, double std,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2481,7 +2485,7 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::qr(
 // The value generated should be within (from, to].
 at::Tensor& XLANativeFunctions::random_(
     at::Tensor& self, int64_t from, c10::optional<int64_t> to,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2501,7 +2505,8 @@ at::Tensor& XLANativeFunctions::random_(
 
 // The value generated should be in (0, to].
 at::Tensor& XLANativeFunctions::random_(
-    at::Tensor& self, int64_t to, c10::optional<at::Generator> generator) {
+    at::Tensor& self, int64_t to,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2517,7 +2522,7 @@ at::Tensor& XLANativeFunctions::random_(
 
 // The value generated should be in (self_type_min, self_type_max).
 at::Tensor& XLANativeFunctions::random_(
-    at::Tensor& self, c10::optional<at::Generator> generator) {
+    at::Tensor& self, const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2690,7 +2695,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self,
 at::Tensor XLANativeFunctions::rrelu_with_noise(
     const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower,
     const at::Scalar& upper, bool training,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     // The fallback path for rrelu_with_noise when training=true is wrong
@@ -3229,7 +3234,7 @@ std::vector<at::Tensor> XLANativeFunctions::unbind_copy(const at::Tensor& self,
 
 at::Tensor& XLANativeFunctions::uniform_(
     at::Tensor& self, double from, double to,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
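One pattern worth noting, unchanged by this commit but repeated in every hunk above: generator.has_value() tests whether a generator argument was supplied at all, while generator->defined() tests whether the supplied at::Generator actually carries an implementation; only when both are true does the operator take the CPU fallback path. A small sketch of the three possible states, reusing the illustrative Generator stand-in from the earlier sketch:

#include <cassert>
#include <optional>

struct Generator {
  bool defined() const { return has_impl; }
  bool has_impl = false;
};

int main() {
  std::optional<Generator> absent;                   // no generator passed
  std::optional<Generator> undefined{Generator{}};   // passed, but undefined
  std::optional<Generator> usable{Generator{true}};  // passed and defined

  // Only the last state would route to the CPU fallback in the code above;
  // the other two stay on the XLA implementation.
  assert(!(absent.has_value() && absent->defined()));
  assert(!(undefined.has_value() && undefined->defined()));
  assert(usable.has_value() && usable->defined());
  return 0;
}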
