-
Notifications
You must be signed in to change notification settings - Fork 160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TorchDISC] compile disc nodes with a fake cluster algorithm #173
Merged
Yancey1989
merged 8 commits into
alibaba:features/torch_disc_devel
from
Yancey1989:compile_disc_nodes
Mar 18, 2022
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
b9ff1d4
compile disc node with a fake cluster algo
Yancey1989 94c2e63
fix mhlo conversion failure
Yancey1989 6f4bdd8
using pre-build pytorch wheel
Yancey1989 6a2115c
update
Yancey1989 30898d5
cleanup code
Yancey1989 f1f1b0b
update
Yancey1989 a8aac34
update
Yancey1989 41260f2
polish code
Yancey1989 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
// Copyright 2022 The BladeDISC Authors. All rights reserved. | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "torch_disc/csrc/backend_impl.h" | ||
|
||
#include <ATen/Functions.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>

#include <memory>
#include <set>
|
||
#include "lazy_tensor_core/csrc/ts_backend/backend_impl.h" | ||
#include "torch_disc/csrc/disc_jit.h" | ||
|
||
namespace torch_disc { | ||
namespace compiler { | ||
|
||
using BackendDeviceType = torch::lazy::BackendDeviceType; | ||
using TSData = torch_lazy_tensors::compiler::TSData; | ||
|
||
struct TSBackendDeviceType : public BackendDeviceType { | ||
TSBackendDeviceType() = delete; | ||
TSBackendDeviceType(c10::DeviceType deviceType) { | ||
TORCH_CHECK(supported_device_types_.find((int8_t)deviceType) != | ||
supported_device_types_.end()); | ||
type = (int8_t)deviceType; | ||
} | ||
|
||
std::string toString() const override { | ||
return c10::DeviceTypeName((c10::DeviceType)type); | ||
} | ||
|
||
c10::DeviceType c10Type() const { return (c10::DeviceType)type; } | ||
|
||
private: | ||
static const std::set<int8_t> supported_device_types_; | ||
}; | ||
const std::set<int8_t> TSBackendDeviceType::supported_device_types_ = { | ||
(int8_t)at::kCPU, (int8_t)at::kCUDA}; | ||
|
||
// DISC implementation of torch::lazy::BackendImplInterface.
//
// Lowering goes through the stock TorchScript (TS) lowering context; the
// DISC-specific work happens at execution time (ExecuteComputation, defined
// out-of-line below, hands the lowered graph to DiscJIT).  Methods declared
// but not defined inline here have out-of-line definitions below the class.
class DISCBackendImpl : public torch::lazy::BackendImplInterface {
 public:
  // Starts on CPU; SetDefaultDeviceType() may later switch it (e.g. to CUDA).
  DISCBackendImpl() : default_device_type_(at::kCPU) {
    auto type = at::kCPU;
    default_device_type_ = TSBackendDeviceType(type);
  }
  // Builds a TS lowering context seeded with an already-computed post-order
  // and emission map (used when resuming a partially lowered graph).
  std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
      const std::string& name, torch::lazy::BackendDevice device,
      c10::ArrayRef<torch::lazy::Node*> post_order,
      torch::lazy::Util::EmissionMap emit_status) const override {
    return std::make_unique<torch::lazy::TSLoweringContext>(
        name, device, post_order, emit_status);
  }

  // Builds an empty TS lowering context.
  std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
      const std::string& name,
      torch::lazy::BackendDevice device) const override {
    return std::make_unique<torch::lazy::TSLoweringContext>(name, device);
  }

  // Pass-through: every requested device is accepted as a compilation
  // device; the `device` argument itself is ignored.
  std::vector<std::string> GetCompilationDevices(
      const std::string& device,
      c10::ArrayRef<std::string> devices) const override {
    return std::vector<std::string>(devices.begin(), devices.end());
  }

  // Unwraps a TSData handle back into its at::Tensor.
  // `logical_scalar_type` is ignored; the stored tensor is returned as-is.
  at::Tensor MakeTensorFromComputationData(
      const torch::lazy::BackendDataPtr data,
      c10::optional<at::ScalarType> logical_scalar_type) const override {
    const auto ts_data = std::static_pointer_cast<TSData>(data);
    return ts_data->data();
  }

  // Wraps `tensor` into a TSData on `device`, moving it to the backend's
  // default device type when needed.  Three paths:
  //   1. already on CUDA and backend is CUDA: async (non-blocking) copy;
  //   2. single-element CPU tensor: .item() + at::full scalar upload;
  //   3. everything else: synchronous .to() copy.
  torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
      const at::Tensor& tensor, const torch::lazy::Shape& shape,
      const torch::lazy::BackendDevice& device) const override {
    at::TensorOptions options = tensor.options().device(
        default_device_type_.c10Type(), device.ordinal());
    if (tensor.device().type() == default_device_type_.c10Type() &&
        default_device_type_.c10Type() == at::kCUDA) {
      return std::make_shared<TSData>(tensor.to(options, /*non_blocking=*/true),
                                      shape, device);
    } else if (tensor.device().type() == at::kCPU && tensor.numel() == 1) {
      // calling .item() on singleton cpu tensor is fast, and using fill is a
      // safe, async way to copy cpu to cuda for a single value
      auto device_tensor = at::full(tensor.sizes(), tensor.item(), options);
      return std::make_shared<TSData>(device_tensor, shape, device);
    } else {
      return std::make_shared<TSData>(
          tensor.to(options, /*non_blocking=*/false), shape, device);
    }
  }

  // Wraps a host scalar; no device transfer happens here.
  torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
      const at::Scalar& scalar,
      const torch::lazy::BackendDevice& device) const override {
    return std::make_shared<TSData>(scalar, device);
  }

  // Debug helper: pretty-prints the lowered TorchScript graph.
  std::string GetComputationBackendText(
      const torch::lazy::ComputationPtr computation) const override {
    auto ts_computation =
        static_cast<torch::lazy::TSComputation*>(computation.get());
    return ts_computation->graph()->toString();
  }

  //////////////computation client interfaces///////////////////////

 public:
  // Defined out-of-line below the class.
  torch::lazy::BackendDataPtr CreateDataPlaceholder(
      const torch::lazy::BackendDevice& device,
      const torch::lazy::Shape& shape) const override;

  std::vector<torch::lazy::ComputationPtr> Compile(
      std::vector<torch::lazy::ComputationPtr> instances) const override;

  std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
      torch::lazy::Computation& computation,
      c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
      const torch::lazy::BackendDevice& device) const override;

  // NOTE(review): copy-constructs a plain BackendDeviceType from the
  // derived TSBackendDeviceType, so the derived toString() override is
  // sliced away -- callers only keep the int8_t `type` tag.  Confirm this
  // is intended.
  std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType()
      const override {
    return std::make_shared<BackendDeviceType>(default_device_type_);
  }

  at::DeviceType EagerFallbackDeviceType() const override;

  void SetDefaultDeviceType(std::string type) override {
    default_device_type_ = TSBackendDeviceType(c10::Device(type).type());
    // The first CUDA usage could happen via lazy tensors. Initialize CUDA here
    // to account for that, at::scalar_tensor constructor triggers everything we
    // need.
    // NOTE(review): `init_cuda` is a function-local static, so only the type
    // passed on the *first* call decides whether CUDA is initialized here;
    // a later switch to CUDA will not re-run this initialization.
    static auto init_cuda = default_device_type_.c10Type() == at::kCUDA
                                ? c10::optional<at::Tensor>(at::scalar_tensor(
                                      0, at::TensorOptions().device(at::kCUDA)))
                                : c10::nullopt;
  }

  std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;

  torch::lazy::BackendDevice GetBackendDevice(
      c10::Device device) const override;

  // Not supported: aborts the process via LOG(FATAL).
  void SetRngSeed(size_t seed) const override {
    LOG(FATAL) << "Not implemented yet.";
  }

  // std::map<std::string, Metric> GetMetrics() const override { return {}; }

  // MemoryInfo GetMemoryInfo(const std::string& device) override {
  //   LOG(FATAL) << "Not implemented yet.";
  // }

  void PrepareToExit() const override;

 private:
  // Device kind (CPU or CUDA) newly created backend data defaults to.
  TSBackendDeviceType default_device_type_;
};
|
||
// Allocates an empty TSData handle recording only shape and device; the
// tensor payload is attached later by the runtime.
torch::lazy::BackendDataPtr DISCBackendImpl::CreateDataPlaceholder(
    const torch::lazy::BackendDevice& device,
    const torch::lazy::Shape& shape) const {
  auto placeholder = std::make_shared<TSData>(shape, device);
  return placeholder;
}
|
||
std::vector<torch::lazy::ComputationPtr> DISCBackendImpl::Compile( | ||
std::vector<torch::lazy::ComputationPtr> instances) const { | ||
for (const auto& instance : instances) { | ||
auto ts_computation = | ||
static_cast<torch::lazy::TSComputation*>(instance.get()); | ||
} | ||
return instances; | ||
} | ||
|
||
// Runs one lowered TS computation: pushes `arguments` onto a TorchScript
// interpreter stack, lets DiscJIT rewrite the graph, executes it, and wraps
// every produced tensor back into a TSData handle on `device`.
std::vector<torch::lazy::BackendDataPtr> DISCBackendImpl::ExecuteComputation(
    torch::lazy::Computation& computation,
    c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
    const torch::lazy::BackendDevice& device) const {
  // Bind by reference: the original `auto` deduced a by-value copy, so
  // DiscJIT and the graph executor operated on a temporary TSComputation.
  auto& ts_computation = static_cast<torch::lazy::TSComputation&>(computation);
  try {
    DiscJIT(ts_computation, arguments);
  } catch (const std::exception& e) {
    // LOG(FATAL) would abort before any rethrow could run, and `throw(e)`
    // would rethrow a sliced copy; log at ERROR and propagate the original
    // exception unchanged.
    LOG(ERROR) << e.what();
    throw;
  }
  torch::jit::GraphExecutor& graph_executor = ts_computation.graph_executor();

  std::vector<torch::jit::IValue> stack;
  stack.reserve(arguments.size());
  // const&: avoid a shared_ptr refcount bump per argument.
  for (const auto& argument : arguments) {
    const auto ts_data = std::static_pointer_cast<TSData>(argument);
    if (ts_data->scalar.has_value()) {
      stack.emplace_back(ts_data->scalar.value());
    } else {
      // TODO(whc) should this check be made more general? it's written somewhat
      // oddly
      CHECK(static_cast<c10::DeviceType>(default_device_type_.type) !=
                at::kCUDA ||
            ts_data->data().device().type() == at::kCUDA);
      stack.emplace_back(ts_data->data());
    }
  }
  // GraphExecutor::run consumes the inputs and leaves the outputs on `stack`.
  graph_executor.run(stack);
  std::vector<torch::lazy::BackendDataPtr> results;
  results.reserve(stack.size());
  for (const torch::jit::IValue& component : stack) {
    at::Tensor result = component.toTensor();
    at::IntArrayRef result_sizes = result.sizes();
    torch::lazy::Shape shape(
        result.scalar_type(),
        std::vector<int64_t>(result_sizes.begin(), result_sizes.end()));
    results.push_back(std::make_shared<TSData>(result, shape, device));
  }
  return results;
}
|
||
std::vector<torch::lazy::BackendDevice> DISCBackendImpl::GetBackendDevices() | ||
const { | ||
std::vector<torch::lazy::BackendDevice> devices; | ||
// TODO(whc) figure out how to query available devices from pytorch | ||
devices.emplace_back(GetBackendDevice(c10::Device(c10::kCPU, 0))); | ||
devices.emplace_back(GetBackendDevice(c10::Device(c10::kCUDA, 0))); | ||
return devices; | ||
} | ||
|
||
// Maps a c10::Device onto a backend device handle.
torch::lazy::BackendDevice DISCBackendImpl::GetBackendDevice(
    c10::Device device) const {
  // Note, we ignore the device type specified by the c10::Device since it is
  // expected to be a virtual device (lazy::), but we need to change this when
  // we support lazy as a mode
  auto device_type = GetDefaultDeviceType();
  return torch::lazy::BackendDevice(std::move(device_type), device.index());
}
|
||
void DISCBackendImpl::PrepareToExit() const {} | ||
|
||
// For the TS/DISC backend the hardware device _is_ the eager device, so the
// eager fallback targets whatever the default backend device resolves to.
c10::DeviceType DISCBackendImpl::EagerFallbackDeviceType() const {
  // static_cast over C-style cast: greppable, intent-revealing conversion
  // from the stored int8_t tag back to the enum.
  return static_cast<c10::DeviceType>(GetDefaultDeviceType()->type);
}
|
||
// Returns the process-wide DISC backend singleton.  The instance is
// intentionally leaked (never deleted) so it cannot be torn down before
// late static-destruction-order users at shutdown.
torch::lazy::BackendImplInterface* GetDISCBackendImpl() {
  static compiler::DISCBackendImpl* disc_backend_impl =
      new compiler::DISCBackendImpl();
  return disc_backend_impl;
}
|
||
void InitTorchScriptBackend() { | ||
static std::unique_ptr<torch::lazy::BackendRegistrar> s_registrar; | ||
s_registrar.reset( | ||
new torch::lazy::BackendRegistrar(compiler::GetDISCBackendImpl())); | ||
} | ||
|
||
} // namespace compiler | ||
} // namespace torch_disc |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should discuss this with PyTorch LTC. We could propose our roadmap and our requirements on LTC?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, maybe we should write a detailed design document before communicating with LTC team.