Skip to content

Commit

Permalink
Set XLA_USE_SPMD for spmd cpp tests. (#6273)
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 authored Jan 11, 2024
1 parent 68f4750 commit a728afe
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
}
} // namespace

class XLAShardingTest : public AtenXlaTensorTestBase {};
class XLAShardingTest : public AtenXlaTensorTestBase {
protected:
static void SetUpTestCase() {
setenv("XLA_USE_SPMD", "1", /*overwrite=*/true);
CommonSetup();
}
};

TEST_F(XLAShardingTest, GetShardShape) {
auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
Expand Down

0 comments on commit a728afe

Please sign in to comment.