Update code to transfer pylist to xla::OpSharding
wonjoolee95 committed Nov 6, 2023
1 parent 08d6296 commit 4b453c7
Showing 2 changed files with 58 additions and 38 deletions.
torch_xla/csrc/init_python_bindings.cpp (47 additions, 30 deletions)
@@ -718,35 +718,19 @@ void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) {
   XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
 }
 
-void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment, c10::List<at::IntArrayRef> group_assignment, c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
+void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List<at::IntArrayRef>& tile_assignment, c10::List<at::IntArrayRef>& group_assignment, c10::List<at::IntArrayRef>& replication_groups, int64_t sharding_type) {
   std::cout << "at xla_mark_sharding_dynamo_custom_op0" << std::endl;
 
   std::cout << "input: " << input << std::endl;
-  // std::cout << "tile_assignment: " << tile_assignment << std::endl;
   std::cout << "tile_assignment.size(): " << tile_assignment.size() << std::endl;
-  std::cout << "converting tile_assignment_py" << std::endl;
-  // const py::list& tile_assignment_py = py::cast(tile_assignment[0]);
-  const py::list& tile_assignment_py = py::cast(torch::lazy::ToVector<int64_t>(tile_assignment[0]));
-
-  // std::cout << "group_assignment: " << group_assignment << std::endl;
-  std::cout << "converting group_assignment_py" << std::endl;
-  const py::list& group_assignment_py = py::cast(group_assignment);
-
-  // std::cout << "replication_groups: " << replication_groups << std::endl;
-  std::cout << "converting replication_groups_py" << std::endl;
-  const py::list& replication_groups_py = py::cast(replication_groups);
-
-  std::cout << "at xla_mark_sharding_dynamo_custom_op1" << std::endl;
-
-  const xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
-      tile_assignment_py, group_assignment_py, replication_groups_py,
-      ShardingUtil::ShardingType(sharding_type));
-
-  std::cout << "at xla_mark_sharding_dynamo_custom_op2" << std::endl;
-
-  xla_mark_sharding(input, op_sharding);
+  py::list tile_assignment_py = py::list();
+  for (const at::IntArrayRef t : tile_assignment) {
+    // auto t_vec = XlaHelpers::I64List(t);
+    // tile_assignment_py.append(py::cast(t_vec));
+  }
 
+  std::cout << "at xla_mark_sharding_dynamo_custom_op3" << std::endl;
+  std::cout << "tile_assignment_py.size(): " << tile_assignment_py.size() << std::endl;
 }
 
 // Macro for defining a function that will be run at static initialization time to define a library of operators in the namespace.
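Aside: a minimal sketch (not part of the commit) of the conversion the empty loop above appears to be working toward, following its commented-out XlaHelpers::I64List lines. It keeps the commit's own c10::List<at::IntArrayRef> parameter type; the helper name is hypothetical. Copying each at::IntArrayRef into an owning std::vector<int64_t> before py::cast keeps the resulting Python list independent of the memory the ref only views.

// Hypothetical sketch: c10::List<at::IntArrayRef> -> py::list of Python int lists.
py::list IntArrayRefListToPyList(const c10::List<at::IntArrayRef>& lists) {
  py::list result;
  for (const at::IntArrayRef ref : lists) {
    std::vector<int64_t> owned(ref.begin(), ref.end());  // copy into owned storage
    result.append(py::cast(owned));  // appends a Python list of ints
  }
  return result;
}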
@@ -1671,6 +1655,12 @@ void InitXlaModuleBindings(py::module m) {
       .def(py::init([](const py::list& tile_assignment,
                        const py::list& group_assignment,
                        const py::list& replication_groups, int sharding_type) {
+        std::cout << "at OpSharding" << std::endl;
+        // std::cout << "tile_assignment:" << tile_assignment << std::endl;
+        // auto vec = tile_assignment.cast<std::vector<std::vector<int64_t>>>();
+        // std::cout << "casted: " << vec << std::endl;
+        // std::cout << "group_assignment:" << group_assignment << std::endl;
+        // std::cout << "replication_groups:" << replication_groups << std::endl;
         return ShardingUtil::CreateOpSharding(
             tile_assignment, group_assignment, replication_groups,
             ShardingUtil::ShardingType(sharding_type));
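Aside: a minimal, self-contained sketch of the nested cast the commented-out debug line above relies on. The helper name is hypothetical; pybind11/stl.h supplies the std::vector conversions, and .cast<T>() takes no arguments (the original commented line passed one, which would not compile).

#include <cstdint>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// Convert a Python list of lists of ints into owned nested vectors, the
// shape of the tile-assignment data CreateOpSharding consumes.
std::vector<std::vector<int64_t>> ToNestedVector(const py::list& nested) {
  return nested.cast<std::vector<std::vector<int64_t>>>();
}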
@@ -1679,12 +1669,39 @@
              xla::OpSharding sharding) {
             xla_mark_sharding(input, sharding);
           });
-  // m.def("_xla_mark_sharding_dynamo_custom_op",
-  //       [](const at::Tensor& input, xla::OpSharding sharding) {
-  //         // xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment, replication_groups, sharding_type);
-  //         // at::IntArrayRef tile_assignment, at::IntArrayRef group_assignment, c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type
-  //         at::IntArrayRef tile_assignment =
-  //       });
+  m.def("_xla_mark_sharding_dynamo_custom_op",
+        [](const at::Tensor& input, const py::list& tile_assignment,
+           const py::list& group_assignment,
+           const py::list& replication_groups, int sharding_type) {
+          // xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment, replication_groups, sharding_type);
+          // at::IntArrayRef tile_assignment, at::IntArrayRef group_assignment, c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type
+          // auto sharding_type = sharding.type();
+          std::cout << "WONJOO: in pybind::_xla_mark_sharding_dynamo_custom_op" << std::endl;
+          std::cout << "WONJOO: tile_assignment: " << tile_assignment << std::endl;
+          std::cout << "WONJOO: tile_assignment.size(): " << tile_assignment.size() << std::endl;
+          std::cout << "WONJOO: group_assignment: " << group_assignment << std::endl;
+          std::cout << "WONJOO: group_assignment.size(): " << group_assignment.size() << std::endl;
+          std::cout << "WONJOO: replication_groups: " << replication_groups << std::endl;
+          std::cout << "WONJOO: replication_groups.size(): " << replication_groups.size() << std::endl;
+          std::cout << "WONJOO: sharding_type: " << sharding_type << std::endl;
+
+          c10::List<at::IntArrayRef> tile_assignment_list = c10::List<at::IntArrayRef>();
+          for (auto t : tile_assignment) {
+            tile_assignment_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          c10::List<at::IntArrayRef> group_assignment_list = c10::List<at::IntArrayRef>();
+          for (auto t : group_assignment) {
+            group_assignment_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          c10::List<at::IntArrayRef> replication_groups_list = c10::List<at::IntArrayRef>();
+          for (auto t : replication_groups) {
+            replication_groups_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          xla_mark_sharding_dynamo_custom_op(input, tile_assignment_list, group_assignment_list, replication_groups_list, sharding_type);
+        });
   m.def("_xla_clear_sharding", [](const at::Tensor& input) {
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
     xtensor->ClearShardingSpec();
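Aside: a hypothetical helper (not in the commit) that would factor out the three identical conversion loops in the new binding above, keeping the commit's c10::List<at::IntArrayRef> element type. Because at::IntArrayRef is a non-owning view, refs built from the temporary vectors returned by .cast<std::vector<int64_t>>() are only safe if the list copies the data on insertion; keeping the owning vectors alive in a caller-provided storage vector avoids relying on that.

// Hypothetical sketch: py::list of int lists -> c10::List<at::IntArrayRef>,
// where `storage` owns the integer data each IntArrayRef views.
c10::List<at::IntArrayRef> PyListToIntArrayRefList(
    const py::list& nested, std::vector<std::vector<int64_t>>& storage) {
  storage.reserve(storage.size() + nested.size());  // no reallocation, so views stay valid
  c10::List<at::IntArrayRef> result;
  for (auto item : nested) {
    storage.push_back(item.cast<std::vector<int64_t>>());  // own the ints
    result.push_back(at::IntArrayRef(storage.back()));     // view into storage
  }
  return result;
}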
torch_xla/experimental/xla_sharding.py (11 additions, 8 deletions)
@@ -107,9 +107,12 @@ def get_op_sharding(self,

     # If flatten = True, return the flattened version of OpSharding
     # print each return type to debug
-    print(tile_assignment.tolist())
-    print(group_assignment)
-    print(replication_groups)
+    print(f'tile_assignment.tolist(): {tile_assignment.tolist()}')
+    print(f'type(tile_assignment.tolist()): {type(tile_assignment.tolist())}')
+    print(f'group_assignment: {group_assignment}')
+    print(f'type(group_assignment): {type(group_assignment)}')
+    print(f'replication_groups: {replication_groups}')
+    print(f'type(replication_groups): {type(replication_groups)}')
     if flatten:
       return (tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type))

@@ -520,7 +523,8 @@ def mark_sharding_dynamo_custom_op(
     t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
   """
-  Same functionality as `mark_sharding` above, except this variant uses the custom mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it.
+  Same functionality as `mark_sharding` above, except this variant uses the custom
+  mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it.
   """
   num_devices = xr.global_runtime_device_count()
   assert num_devices > 0, "This requires XLA supported device(s)."
@@ -535,12 +539,11 @@

   tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(partition_spec, flatten = True)
 
-  print('about to call xla_mark_sharding_dynamo_custom_op')
   if isinstance(t, XLAShardedTensor):
     torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type)
-    # torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, op_sharding)
+    torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type)
     return t
   torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type)
-  print('xla_mark_sharding_dynamo_custom_op call finished')
+  torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type)
   return XLAShardedTensor(t)


