Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchDISC] compile disc nodes with a fake cluster algorithm #173

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ jobs:
git submodule sync
git submodule update --init --depth=1
docker pull bladedisc/torch-disc:devel-cuda11.0
docker pull bladedisc/bladedisc:latest-devel-cuda11.0
docker build --cache-from bladedisc/torch-disc:devel-cuda11.0 -t torch-disc-devel-cuda11.0 \
--build-arg PYTORCH_COMMIT=$(git submodule status torch_disc/pytorch | awk '{print $1}') \
-f docker/dev/Dockerfile.torch-disc ./docker/
Expand Down
2 changes: 1 addition & 1 deletion docker/dev/Dockerfile.torch-disc
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ RUN apt-get update -y && \
ln -s /usr/bin/python3.8 /usr/bin/python && \
python -m pip install --upgrade pip && \
python -m pip install cpython pyyaml typing_extensions virtualenv numpy
RUN bash /opt/scripts/install-pytorch-ltc.sh
RUN python -m pip install https://bladedisc-ci.oss-cn-hongkong.aliyuncs.com/download/torch-ltc/torch-1.12.0a0%2Bgitd9896b8-cp38-cp38-linux_x86_64.whl
22 changes: 0 additions & 22 deletions docker/scripts/install-pytorch-ltc.sh

This file was deleted.

8 changes: 6 additions & 2 deletions scripts/ci/test_torch_disc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@ set -e

# 1. configure tensorflow
python scripts/python/tao_build.py /opt/venv_disc -s configure --bridge-gcc default --compiler-gcc default
python scripts/python/tao_build.py /opt/venv_disc -s build_mlir_ral
ln -s /workspace/tf_community/bazel-bin/tensorflow/compiler/mlir/disc/disc_compiler_main torch_disc/disc_compiler_main
# 2. using a virtualenv to avoid permission issue
python -m virtualenv --system-site-packages myenv && source myenv/bin/activate
# 3. call LTC code generator, that's used in ts lowering
cd torch_disc
bash pytorch/lazy_tensor_core/scripts/generate_code.sh
# 4. build "_torch_disc.so"
python setup.py develop
# 5. an easy way to test torch_disc: just try to import the pybind library
(cd bazel-bin/torch_disc && python -c "import torch; import _torch_disc")
# 5. test an e2e demo
ln -s bazel-bin/torch_disc/_torch_disc.so ./_torch_disc.so
python disc_demo.py

deactivate
16 changes: 2 additions & 14 deletions torch_disc/disc_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import _torch_disc as disc
Expand All @@ -32,28 +31,17 @@ def forward(self, x):
device = 'lazy'
model = SimpleNet().to(device)

transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])

ds = datasets.MNIST('../data', train=True, download=True,
transform=transform)

train_loader = torch.utils.data.DataLoader(ds, batch_size=32, num_workers=1, pin_memory=True, shuffle=True)

optimizer = optim.Adadelta(model.parameters(), lr=1.0)

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

model.train().to(device)

data, target = next(iter(train_loader))
data, target = torch.rand(32, 1, 28, 28), torch.randint(9, (32,))
data, target = data.to(device), target.to(device)

optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
#disc.mark_step()
disc._step_marker()
13 changes: 11 additions & 2 deletions torch_disc/torch_disc/BUILD
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")

pybind_library (
name = "torch_disc_pybind11",
srcs = ["csrc/init_python_bindings.cpp"],
name = "torch_disc_backend",
srcs = ["csrc/backend_impl.cpp", "csrc/disc_jit.cpp"],
hdrs = ["csrc/backend_impl.h", "csrc/disc_jit.h"],
deps = [
"@local_org_torch//:ltc_ts_backend",
# TODO(yancey1989): depends on one module
Expand All @@ -12,6 +13,14 @@ pybind_library (
]
)

pybind_library (
name = "torch_disc_pybind11",
srcs = ["csrc/init_python_bindings.cpp"],
deps = [
":torch_disc_backend",
]
)

pybind_extension(
name = "_torch_disc",
linkopts = ["-Wl,-rpath,$$ORIGIN"],
Expand Down
259 changes: 259 additions & 0 deletions torch_disc/torch_disc/csrc/backend_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
// Copyright 2022 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "torch_disc/csrc/backend_impl.h"

#include <ATen/Functions.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>

#include "lazy_tensor_core/csrc/ts_backend/backend_impl.h"
#include "torch_disc/csrc/disc_jit.h"

namespace torch_disc {
namespace compiler {

using BackendDeviceType = torch::lazy::BackendDeviceType;
using TSData = torch_lazy_tensors::compiler::TSData;

struct TSBackendDeviceType : public BackendDeviceType {
TSBackendDeviceType() = delete;
TSBackendDeviceType(c10::DeviceType deviceType) {
TORCH_CHECK(supported_device_types_.find((int8_t)deviceType) !=
supported_device_types_.end());
type = (int8_t)deviceType;
}

std::string toString() const override {
return c10::DeviceTypeName((c10::DeviceType)type);
}

c10::DeviceType c10Type() const { return (c10::DeviceType)type; }

private:
static const std::set<int8_t> supported_device_types_;
};
const std::set<int8_t> TSBackendDeviceType::supported_device_types_ = {
(int8_t)at::kCPU, (int8_t)at::kCUDA};

class DISCBackendImpl : public torch::lazy::BackendImplInterface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should discuss this with PyTorch LTC. We could propose our roadmap and our requirements on LTC?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point; maybe we should write a detailed design document before communicating with the LTC team.

public:
// Creates the backend with CPU as the initial default device type.
DISCBackendImpl() : default_device_type_(at::kCPU) {}
std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status) const override {
return std::make_unique<torch::lazy::TSLoweringContext>(
name, device, post_order, emit_status);
}

std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
const std::string& name,
torch::lazy::BackendDevice device) const override {
return std::make_unique<torch::lazy::TSLoweringContext>(name, device);
}

// Returns a copy of `devices` as a vector; the `device` argument is not used.
std::vector<std::string> GetCompilationDevices(
    const std::string& device,
    c10::ArrayRef<std::string> devices) const override {
  std::vector<std::string> compilation_devices(devices.begin(), devices.end());
  return compilation_devices;
}

// Returns the tensor held by `data`, which is treated as a TSData without a
// runtime type check. `logical_scalar_type` is not used.
at::Tensor MakeTensorFromComputationData(
    const torch::lazy::BackendDataPtr data,
    c10::optional<at::ScalarType> logical_scalar_type) const override {
  return std::static_pointer_cast<TSData>(data)->data();
}

// Wraps `tensor` in a TSData for `device`. The tensor is first placed on the
// backend's default device type at index `device.ordinal()`; how it gets there
// depends on where it currently lives (see the three cases below).
torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
    const at::Tensor& tensor, const torch::lazy::Shape& shape,
    const torch::lazy::BackendDevice& device) const override {
  const c10::DeviceType target_type = default_device_type_.c10Type();
  at::TensorOptions options =
      tensor.options().device(target_type, device.ordinal());

  // Case 1: the default device type is CUDA and the tensor is already on a
  // CUDA device, so a non-blocking conversion is requested.
  if (tensor.device().type() == target_type && target_type == at::kCUDA) {
    return std::make_shared<TSData>(tensor.to(options, /*non_blocking=*/true),
                                    shape, device);
  }

  // Case 2: single-element CPU tensor.
  if (tensor.device().type() == at::kCPU && tensor.numel() == 1) {
    // calling .item() on singleton cpu tensor is fast, and using fill is a
    // safe, async way to copy cpu to cuda for a single value
    auto filled = at::full(tensor.sizes(), tensor.item(), options);
    return std::make_shared<TSData>(filled, shape, device);
  }

  // Case 3: everything else goes through a blocking conversion.
  return std::make_shared<TSData>(tensor.to(options, /*non_blocking=*/false),
                                  shape, device);
}

// Wraps a scalar value in a TSData associated with `device`.
torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
    const at::Scalar& scalar,
    const torch::lazy::BackendDevice& device) const override {
  torch::lazy::BackendDataPtr scalar_data =
      std::make_shared<TSData>(scalar, device);
  return scalar_data;
}

// Returns the textual dump of the graph held by `computation`, which is
// treated as a TSComputation without a runtime type check.
std::string GetComputationBackendText(
    const torch::lazy::ComputationPtr computation) const override {
  auto* const as_ts =
      static_cast<torch::lazy::TSComputation*>(computation.get());
  auto graph = as_ts->graph();
  return graph->toString();
}

//////////////computation client interfaces///////////////////////

public:
// Creates a TSData placeholder carrying only `shape` and `device`.
// Defined out of line later in this file.
torch::lazy::BackendDataPtr CreateDataPlaceholder(
const torch::lazy::BackendDevice& device,
const torch::lazy::Shape& shape) const override;

// Compilation hook; the out-of-line definition returns `instances` unchanged.
std::vector<torch::lazy::ComputationPtr> Compile(
std::vector<torch::lazy::ComputationPtr> instances) const override;

// Runs `computation` (treated as a TSComputation) on `arguments`: calls
// DiscJIT on it, executes its graph executor, and returns one TSData per
// output tensor, each tagged with `device`.
std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
torch::lazy::Computation& computation,
c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
const torch::lazy::BackendDevice& device) const override;

std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType()
const override {
return std::make_shared<BackendDeviceType>(default_device_type_);
}

// Device type used for eager fallback; the out-of-line definition returns the
// current default device type.
at::DeviceType EagerFallbackDeviceType() const override;

// Sets the backend's default device type from a device string. The string is
// parsed by c10::Device, and TSBackendDeviceType's constructor rejects device
// types it does not support.
void SetDefaultDeviceType(std::string type) override {
default_device_type_ = TSBackendDeviceType(c10::Device(type).type());
// The first CUDA usage could happen via lazy tensors. Initialize CUDA here
// to account for that, at::scalar_tensor constructor triggers everything we
// need.
// NOTE: `init_cuda` is a function-local static, so its initializer runs only
// on the first call to this method. If that first call selects a non-CUDA
// type, a later call that selects CUDA does not create the scalar tensor.
static auto init_cuda = default_device_type_.c10Type() == at::kCUDA
? c10::optional<at::Tensor>(at::scalar_tensor(
0, at::TensorOptions().device(at::kCUDA)))
: c10::nullopt;
}

// Lists the backend devices; the out-of-line definition returns one entry for
// index 0 of CPU and one for index 0 of CUDA.
std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;

// Maps a c10::Device to a BackendDevice; the out-of-line definition keeps only
// the device index and uses the default device type.
torch::lazy::BackendDevice GetBackendDevice(
c10::Device device) const override;

// Seeding is not supported by this backend yet: every call logs a fatal
// "Not implemented yet." message and `seed` is ignored.
void SetRngSeed(size_t seed) const override {
LOG(FATAL) << "Not implemented yet.";
}

// std::map<std::string, Metric> GetMetrics() const override { return {}; }

// MemoryInfo GetMemoryInfo(const std::string& device) override {
// LOG(FATAL) << "Not implemented yet.";
// }

// Exit hook; the out-of-line definition has an empty body.
void PrepareToExit() const override;

private:
TSBackendDeviceType default_device_type_;
};

// Builds a TSData that carries only `shape` and `device`; no tensor or scalar
// is passed to it.
torch::lazy::BackendDataPtr DISCBackendImpl::CreateDataPlaceholder(
    const torch::lazy::BackendDevice& device,
    const torch::lazy::Shape& shape) const {
  torch::lazy::BackendDataPtr placeholder =
      std::make_shared<TSData>(shape, device);
  return placeholder;
}

// Compilation hook of the backend interface. Nothing is compiled at this
// stage: the computations are handed back unchanged. DiscJIT is invoked later,
// from ExecuteComputation.
std::vector<torch::lazy::ComputationPtr> DISCBackendImpl::Compile(
    std::vector<torch::lazy::ComputationPtr> instances) const {
  return instances;
}

// Executes a computation and returns its outputs as backend data.
//
// `computation` is treated as a torch::lazy::TSComputation and every entry of
// `arguments` as a TSData (neither is checked at runtime). The steps are:
//   1. call DiscJIT on the computation and its arguments,
//   2. build the interpreter stack from the arguments (scalars are pushed as
//      scalars, everything else as tensors),
//   3. run the computation's graph executor on that stack,
//   4. wrap every resulting tensor in a TSData tagged with `device`.
//
// Any std::exception raised by DiscJIT is reported through LOG(FATAL).
std::vector<torch::lazy::BackendDataPtr> DISCBackendImpl::ExecuteComputation(
    torch::lazy::Computation& computation,
    c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
    const torch::lazy::BackendDevice& device) const {
  // Bind by reference: a plain `auto` here would copy the whole TSComputation.
  auto& ts_computation = static_cast<torch::lazy::TSComputation&>(computation);
  try {
    DiscJIT(ts_computation, arguments);
  } catch (const std::exception& e) {
    LOG(FATAL) << e.what();
    // LOG(FATAL) is expected to terminate the process; if it ever returns,
    // rethrow the original exception object (a `throw e;` would slice it).
    throw;
  }
  torch::jit::GraphExecutor& graph_executor = ts_computation.graph_executor();

  std::vector<torch::jit::IValue> stack;
  stack.reserve(arguments.size());
  for (const auto& argument : arguments) {
    const auto ts_data = std::static_pointer_cast<TSData>(argument);
    if (ts_data->scalar.has_value()) {
      stack.emplace_back(ts_data->scalar.value());
    } else {
      // TODO(whc) should this check be made more general? it's written somewhat
      // oddly
      CHECK((c10::DeviceType)default_device_type_.type != at::kCUDA ||
            ts_data->data().device().type() == at::kCUDA);
      stack.emplace_back(ts_data->data());
    }
  }

  // After run() the stack holds the outputs of the graph.
  graph_executor.run(stack);

  std::vector<torch::lazy::BackendDataPtr> results;
  results.reserve(stack.size());
  for (const torch::jit::IValue& component : stack) {
    at::Tensor result = component.toTensor();
    at::IntArrayRef result_sizes = result.sizes();
    torch::lazy::Shape shape(
        result.scalar_type(),
        std::vector<int64_t>(result_sizes.begin(), result_sizes.end()));
    results.push_back(std::make_shared<TSData>(result, shape, device));
  }
  return results;
}

std::vector<torch::lazy::BackendDevice> DISCBackendImpl::GetBackendDevices()
const {
std::vector<torch::lazy::BackendDevice> devices;
// TODO(whc) figure out how to query available devices from pytorch
devices.emplace_back(GetBackendDevice(c10::Device(c10::kCPU, 0)));
devices.emplace_back(GetBackendDevice(c10::Device(c10::kCUDA, 0)));
return devices;
}

// Maps a c10::Device to a BackendDevice that uses the backend's default device
// type and the index of `device`.
torch::lazy::BackendDevice DISCBackendImpl::GetBackendDevice(
    c10::Device device) const {
  // Note, we ignore the device type specified by the c10::Device since it is
  // expected to be a virtual device (lazy::), but we need to change this when
  // we support lazy as a mode
  const auto ordinal = device.index();
  return torch::lazy::BackendDevice(GetDefaultDeviceType(), ordinal);
}

// No-op: this backend performs no work in PrepareToExit.
void DISCBackendImpl::PrepareToExit() const {}

// Returns the device type used for eager fallback, which is the backend's
// current default device type.
c10::DeviceType DISCBackendImpl::EagerFallbackDeviceType() const {
  // For TS backend, hardware device _is_ eager device
  const auto default_type = GetDefaultDeviceType();
  return static_cast<c10::DeviceType>(default_type->type);
}

// Returns the process-wide DISC backend implementation. The instance is
// created on first use and is never deleted.
torch::lazy::BackendImplInterface* GetDISCBackendImpl() {
  static auto* const backend_singleton = new compiler::DISCBackendImpl();
  return backend_singleton;
}

void InitTorchScriptBackend() {
static std::unique_ptr<torch::lazy::BackendRegistrar> s_registrar;
s_registrar.reset(
new torch::lazy::BackendRegistrar(compiler::GetDISCBackendImpl()));
}

} // namespace compiler
} // namespace torch_disc
Loading