-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: This is PoC for Pallas integration. Currently, it can run Pallas kernels that take arbitrary tensors as inputs and output a single tensor. The design doc is here: go/pytorch-xla-pallas. Test Plan: PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call
- Loading branch information
1 parent
de5d764
commit 56db8f2
Showing
10 changed files
with
150 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#include "torch_xla/csrc/ops/tpu_custom_call.h" | ||
|
||
#include "torch_xla/csrc/lowering_context.h" | ||
#include "torch_xla/csrc/ops/xla_ops.h" | ||
#include "torch_xla/csrc/xla_lower_util.h" | ||
|
||
namespace torch_xla { | ||
|
||
TpuCustomCall::TpuCustomCall(torch::lazy::OpList inputs, | ||
xla::Shape output_shape, | ||
const std::string& payload) | ||
: XlaNode(xla_tpu_custom_call, inputs, std::move(output_shape), | ||
/*num_outputs=*/1, torch::lazy::MHash(payload)), | ||
payload_(payload) {} | ||
|
||
torch::lazy::NodePtr TpuCustomCall::Clone(torch::lazy::OpList operands) const { | ||
return torch::lazy::MakeNode<TpuCustomCall>(operands, xla_shape(), payload_); | ||
} | ||
|
||
XlaOpVector TpuCustomCall::Lower(LoweringContext* loctx) const { | ||
std::vector<xla::XlaOp> inputs; | ||
inputs.reserve(operands().size()); | ||
for (auto& operand : operands()) { | ||
inputs.push_back(loctx->GetOutputOp(operand)); | ||
} | ||
xla::XlaOp output = BuildTpuCustomCall(inputs, xla_shape(), payload_); | ||
return ReturnOp(output, loctx); | ||
} | ||
|
||
std::string TpuCustomCall::ToString() const { | ||
std::stringstream ss; | ||
ss << XlaNode::ToString() << ", " << payload_; | ||
return ss.str(); | ||
} | ||
|
||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#ifndef XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ | ||
#define XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ | ||
|
||
#include "torch_xla/csrc/ir.h" | ||
|
||
namespace torch_xla { | ||
|
||
class TpuCustomCall : public XlaNode { | ||
public: | ||
// Make a TPU custom call with payload, e.g., Mosaic. | ||
TpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape, | ||
const std::string& payload); | ||
|
||
torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; | ||
|
||
XlaOpVector Lower(LoweringContext* loctx) const override; | ||
|
||
std::string ToString() const override; | ||
|
||
private: | ||
std::string payload_; | ||
}; | ||
|
||
} // namespace torch_xla | ||
|
||
#endif // XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters