Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[emulator] feat: veScale correctness emulator #45

Merged
merged 1 commit into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/pictures/emulator.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pictures/emulator_distributed.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pictures/emulator_dtensor.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pictures/emulator_mesh_collectives.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pictures/training_losses_diff_in_bf16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added test/emulator/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions test/emulator/common_emulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from typing import Callable, Tuple, Dict, Any
from functools import wraps

TestFunc = Callable[[object], object]


# wrapper to initialize comms (process group) within emulator
def with_comms_emulator(func: TestFunc) -> TestFunc:
assert func is not None

@wraps(func) # pyre-ignore[6]
def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None: # type: ignore[misc]
# launch
self.init_emulator_pg()
func(self, *args, **kwargs) # type: ignore[misc]
self.destroy_emulator_pg()

return wrapper
212 changes: 212 additions & 0 deletions test/emulator/test_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################


import os
import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)

import vescale
from vescale.emulator.distributed import ProcessGroup, dump_nccl_graph_for_pg
from vescale.emulator.reduce_kernel import ReduceOp

from vescale.emulator.all_gather import expand_tensor_list
from vescale.emulator.reduce_scatter import contract_tensor_list
from common_dtensor import DTensorTestBase, with_comms
from emulator.common_emulator import with_comms_emulator
from vescale.emulator.utils import emulator_reduce_op_to_torch


class TestDistributed(DTensorTestBase):
@property
def world_size(self) -> int:
return 4

def init_emulator_pg(self):
torch.manual_seed(0)
backend = "nccl"
world_size = self.world_size

vescale.emulator.distributed.init_process_group(backend=backend, world_size=world_size, rank=0)
vescale.emulator.distributed.set_rank(0)
self.pg: ProcessGroup = vescale.emulator.distributed._world.default_pg
self.torch_pg = torch.distributed.distributed_c10d._get_default_group()
dump_nccl_graph_for_pg(self.pg, self.torch_pg, self.rank)

def destroy_emulator_pg(self):
vescale.emulator.distributed.destroy_process_group()

@with_comms
@with_comms_emulator
def test_process_group(self):
ground_truth_pg_group_ranks = [{0: 0, 1: 1, 2: 2, 3: 3}, {0: 0, 2: 1}, {1: 0, 3: 1}, {0: 0, 1: 1}, {2: 0, 3: 1}]
for count, value in enumerate(vescale.emulator.distributed._world.pg_group_ranks.values()):
self.assertEqual(value, ground_truth_pg_group_ranks[count])

@with_comms
@with_comms_emulator
# @parametrize("reduce_op", [ReduceOp.SUM, ReduceOp.PRODUCT, ReduceOp.MAX, ReduceOp.MIN])
@parametrize("reduce_op", [ReduceOp.SUM])
@parametrize("nelement", [1, 1024, 1024 * 1024])
def test_all_reduce(self, nelement, reduce_op):
nranks = self.pg.size()
tree_structure = [[0, 1], [2, 3]]
torch_rank = self.rank
device = f"cuda:{torch_rank}"

input_file = "input_distributed.pt"
if self.rank == 0:
# To ensure all ranks have the same input
input_list = []
for i in range(nranks):
input_list.append(torch.randn((nelement,), device="cuda"))
torch.save(input_list, input_file)
dist.barrier()

data_list = torch.load(input_file)
data_list = [data.to(device) for data in data_list]
ground_truth = [data_list[rank].clone().to(device) if rank == torch_rank else [] for rank in range(nranks)]
torch_reduce_op = emulator_reduce_op_to_torch(reduce_op)

torch.distributed.all_reduce(ground_truth[torch_rank], torch_reduce_op)
self.pg.all_reduce(data_list, op=reduce_op, tree_structure=tree_structure)

self.assertTrue(torch.equal(data_list[torch_rank], ground_truth[torch_rank]))

if self.rank == 0:
if os.path.exists(input_file):
os.remove(input_file)

@with_comms
@with_comms_emulator
@parametrize("nelement", [1, 1024, 1024 * 1024])
def test_all_gather(self, nelement):
nranks = self.pg.size()
torch_rank = self.rank
device = f"cuda:{torch_rank}"

input_file = "input_distributed.pt"
if self.rank == 0:
# To ensure all ranks have the same input
input_list = []
for i in range(nranks):
input_list.append(torch.randn((nelement,), device="cuda"))
torch.save(input_list, input_file)
dist.barrier()

data_list = torch.load(input_file)
data_list = [data.to(device) for data in data_list]
ground_truth_list = [torch.zeros(nelement).to(device) for _ in range(nranks)]
output_list = expand_tensor_list(data_list)

torch.distributed.all_gather(ground_truth_list, data_list[torch_rank])
self.pg.all_gather(output_list, data_list)

for gt, data in zip(ground_truth_list, data_list):
self.assertTrue(torch.equal(gt, data))

if self.rank == 0:
if os.path.exists(input_file):
os.remove(input_file)

@with_comms
@with_comms_emulator
# @parametrize("reduce_op", [ReduceOp.SUM, ReduceOp.PRODUCT, ReduceOp.MAX, ReduceOp.MIN])
@parametrize("reduce_op", [ReduceOp.SUM])
@parametrize("nelement", [1, 1024, 1024 * 1024])
def test_reduce_scatter(self, nelement, reduce_op):
nranks = self.pg.size()
torch_rank = self.rank
device = f"cuda:{torch_rank}"

input_file = "input_distributed.pt"
if self.rank == 0:
# To ensure all ranks have the same input
input_list = []
for i in range(nranks):
input_list.append([])
for j in range(nranks):
input_list[i].append(torch.randn((nelement,), device="cuda"))
torch.save(input_list, input_file)
dist.barrier()

data_list = torch.load(input_file)
data_list = [[elem.to(device) for elem in data] for data in data_list]
ground_truth = torch.zeros(nelement).to(device)
outputs = contract_tensor_list(data_list)
torch_reduce_op = emulator_reduce_op_to_torch(reduce_op)

torch.distributed.reduce_scatter(ground_truth, data_list[torch_rank], torch_reduce_op)

self.pg.reduce_scatter(outputs, data_list, op=reduce_op)

result = outputs[torch_rank]
self.assertTrue(torch.equal(result, ground_truth))

if self.rank == 0:
if os.path.exists(input_file):
os.remove(input_file)

@with_comms
@with_comms_emulator
@parametrize("nelement", [1, 1024, 1024 * 1024])
def test_all_to_all(self, nelement):
nranks = self.pg.size()
torch_rank = self.rank
device = f"cuda:{torch_rank}"

input_file = "input_distributed.pt"
if self.rank == 0:
# To ensure all ranks have the same input
input_list = []
for i in range(nranks):
input_list.append([])
for j in range(nranks):
input_list[i].append(torch.randn((nelement,), device="cuda"))
torch.save(input_list, input_file)
dist.barrier()

data_list = torch.load(input_file)
outputs_list = []
ground_truth_list = []
for i in range(nranks):
outputs_list.append([])
for j in range(nranks):
data_list[i][j] = data_list[i][j].to(device)
outputs_list[i].append((torch.zeros(nelement)).to(device))
ground_truth_list.append((torch.zeros(nelement)).to(device))

torch.distributed.all_to_all(ground_truth_list, data_list[torch_rank])
self.pg.all_to_all(outputs_list, data_list)

for gt, output in zip(ground_truth_list, outputs_list[torch_rank]):
self.assertTrue(torch.equal(gt, output))

if self.rank == 0:
if os.path.exists(input_file):
os.remove(input_file)


instantiate_parametrized_tests(TestDistributed)

if __name__ == "__main__":
run_tests()
118 changes: 118 additions & 0 deletions test/emulator/test_dtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
################################################################################
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
################################################################################
# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
################################################################################

import os

import numpy as np
from common_dtensor import (
DTensorTestBase, # skip_unless_torch_gpu,
with_comms,
)
from typing import List, cast

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.testing._internal.common_utils import run_tests

import vescale
from vescale.dtensor.dtensor import DTensor
from vescale.dtensor.placement_types import Placement, Replicate, Shard

from vescale.emulator.device_mesh import dump_nccl_graph_for_mesh
from vescale.emulator.distributed import ProcessGroup, dump_nccl_graph_for_pg
from vescale.emulator.comm_api import distribute_tensor, redistribute_dtensor
from vescale.emulator.device_mesh import DeviceMesh
from vescale.emulator.emulator_instrumentation import EmulatorInstrumentation
from emulator.common_emulator import with_comms_emulator


class DistMatrixOpsTest(DTensorTestBase):
@property
def world_size(self) -> int:
return 4

def init_emulator_pg(self):
torch.manual_seed(0)
backend = "nccl"
world_size = self.world_size

vescale.emulator.distributed.init_process_group(backend=backend, world_size=world_size, rank=0)
vescale.emulator.distributed.set_rank(0)
# dump default process group
self.pg: ProcessGroup = vescale.emulator.distributed._world.default_pg
self.torch_pg = torch.distributed.distributed_c10d._get_default_group()
dump_nccl_graph_for_pg(self.pg, self.torch_pg, self.rank)

# dump for other process groups
mesh_tensor = list(range(world_size))
self.vescale_mesh = vescale.dtensor.device_mesh.DeviceMesh(self.device_type, mesh_tensor)
self.mesh = DeviceMesh(self.device_type, mesh_tensor)
dump_nccl_graph_for_mesh(self.mesh, self.vescale_mesh)

def destroy_emulator_pg(self):
vescale.emulator.distributed.destroy_process_group()

@with_comms
@with_comms_emulator
def test_mm(self):
device_mesh = self.mesh
vescale_device_mesh = vescale.dtensor.device_mesh.DeviceMesh(self.device_type, list(range(self.world_size)))
device = f"cuda:{self.rank}"
replica_spec = Replicate()

input_file = "input_dtensors.pt"
if self.rank == 0:
t1 = torch.randn(12, 8, requires_grad=True).cuda()
t2 = torch.randn(8, 12, requires_grad=True).cuda()
torch.save((t1, t2), input_file)
dist.barrier()

t1, t2 = torch.load(input_file)
t1 = t1.to(device)
t2 = t2.to(device)
t1_list = [t1.clone().detach().requires_grad_() for _ in range(self.world_size)]
t2_list = [t2.clone().detach().requires_grad_() for _ in range(self.world_size)]

def test_placement_comb(placements1: List[Placement], placements2: List[Placement]) -> None:
dt1_list = distribute_tensor(t1_list, device_mesh, placements1)
dt2_list = distribute_tensor(t2_list, device_mesh, placements2)

# Emulator replace the given pytorch function to accpet lists of tensors as input
func_list = ["mm"]
indices = [(0, 1)]
with EmulatorInstrumentation(torch, func_list, indices):
dist_res_list = torch.mm(dt1_list, dt2_list)
dist_res_list = redistribute_dtensor(dist_res_list, device_mesh, [replica_spec])

dt1 = vescale.distribute_tensor(t1.clone().detach().requires_grad_(), vescale_device_mesh, placements1)
dt2 = vescale.distribute_tensor(t2.clone().detach().requires_grad_(), vescale_device_mesh, placements2)
dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute(vescale_device_mesh, [replica_spec])

for dist_res_emu in dist_res_list:
self.assertTrue(torch.equal(dist_res.to_local(), dist_res_emu.to_local()))

shard_specs_comb = [
(Shard(dim=0), Replicate()),
(Shard(dim=1), Shard(dim=0)),
(Replicate(), Shard(dim=1)),
(Replicate(), Replicate()),
]

for spec in shard_specs_comb:
test_placement_comb([spec[0]], [spec[1]])

if self.rank == 0:
if os.path.exists(input_file):
os.remove(input_file)


if __name__ == "__main__":
run_tests()
Loading