Skip to content

Commit

Permalink
Release pod tensor (#8552)
Browse files Browse the repository at this point in the history
* ThreadLocalGuard

* split ReleaseTensor into ReleasePodTensor and ReleaseNonPodTensor.

* rename

Co-authored-by: luyang <flowingsun007@163.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 11, 2022
1 parent 44886c1 commit bfaa258
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 33 deletions.
96 changes: 65 additions & 31 deletions oneflow/core/eager/release_tensor_instruction_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,10 @@ namespace vm {
class ReleaseTensorInstructionType : public vm::InstructionType {
public:
ReleaseTensorInstructionType() = default;
~ReleaseTensorInstructionType() override = default;
virtual ~ReleaseTensorInstructionType() = default;

InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; }

std::string DebugName(const vm::Instruction& instruction) const override {
return "ReleaseTensor";
}
Maybe<void> Prepare(vm::Instruction* instruction) const override {
const auto& eager_blob_object = GetEagerBlobObject(*instruction);
DataType data_type = eager_blob_object->data_type();
if (IsPODDataType(data_type)) { Release(eager_blob_object); }
return Maybe<void>::Ok();
}
void Compute(vm::Instruction* instruction) const override {
const auto& eager_blob_object = GetEagerBlobObject(*instruction);
DataType data_type = eager_blob_object->data_type();
if (!IsPODDataType(data_type)) { Release(eager_blob_object); }
}
void InitInstructionStatus(Instruction* instruction) const override {
auto* status_buffer = instruction->mut_status_buffer();
auto* stream = instruction->mut_stream();
Expand All @@ -57,7 +43,7 @@ class ReleaseTensorInstructionType : public vm::InstructionType {
EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(nullptr);
}

private:
protected:
const std::shared_ptr<vm::EagerBlobObject>& GetEagerBlobObject(
const vm::Instruction& instruction) const {
const auto& phy_instr_operand = instruction.phy_instr_operand();
Expand All @@ -72,35 +58,83 @@ class ReleaseTensorInstructionType : public vm::InstructionType {
}
};

class FastReleaseTensorInstructionType final : public ReleaseTensorInstructionType {
public:
FastReleaseTensorInstructionType() = default;
~FastReleaseTensorInstructionType() override = default;

std::string DebugName(const vm::Instruction& instruction) const override {
return "ReleasePodTensor";
}

Maybe<void> Prepare(vm::Instruction* instruction) const override {
const auto& eager_blob_object = GetEagerBlobObject(*instruction);
DataType data_type = eager_blob_object->data_type();
CHECK(IsPODDataType(data_type));
Release(eager_blob_object);
return Maybe<void>::Ok();
}

void Compute(vm::Instruction* instruction) const override {}
};

class SlowReleaseTensorInstructionType final : public ReleaseTensorInstructionType {
public:
SlowReleaseTensorInstructionType() = default;
~SlowReleaseTensorInstructionType() override = default;

std::string DebugName(const vm::Instruction& instruction) const override {
return "ReleaseNonPodTensor";
}

Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }

void Compute(vm::Instruction* instruction) const override {
const auto& eager_blob_object = GetEagerBlobObject(*instruction);
DataType data_type = eager_blob_object->data_type();
CHECK(!IsPODDataType(data_type));
Release(eager_blob_object);
}
};

} // namespace vm

struct GetReleaseInstructionType : public StreamRoleVisitor<GetReleaseInstructionType> {
static Maybe<const vm::InstructionType*> VisitCompute(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
static Maybe<const vm::InstructionType*> VisitCompute(DataType data_type) {
return GetReleaseTensorInstructionType(data_type);
}
static Maybe<const vm::InstructionType*> VisitHost2Device(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
static Maybe<const vm::InstructionType*> VisitHost2Device(DataType data_type) {
return GetReleaseTensorInstructionType(data_type);
}
static Maybe<const vm::InstructionType*> VisitDevice2Host(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
static Maybe<const vm::InstructionType*> VisitDevice2Host(DataType data_type) {
return GetReleaseTensorInstructionType(data_type);
}
static Maybe<const vm::InstructionType*> VisitSyncedLaunchedCommNet(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
static Maybe<const vm::InstructionType*> VisitSyncedLaunchedCommNet(DataType data_type) {
return GetReleaseTensorInstructionType(data_type);
}
static Maybe<const vm::InstructionType*> VisitAsyncedLaunchedCommNet(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
static Maybe<const vm::InstructionType*> VisitAsyncedLaunchedCommNet(DataType data_type) {
return GetReleaseTensorInstructionType(data_type);
}
static Maybe<const vm::InstructionType*> VisitBarrier(DeviceType device_type) {
static Maybe<const vm::InstructionType*> VisitBarrier(DataType data_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitCriticalSection(DeviceType device_type) {
static Maybe<const vm::InstructionType*> VisitCriticalSection(DataType data_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitLazyJobLauncher(DeviceType device_type) {
static Maybe<const vm::InstructionType*> VisitLazyJobLauncher(DataType data_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitPinnedCompute(DeviceType device_type) {
return VisitCompute(device_type);
static Maybe<const vm::InstructionType*> VisitPinnedCompute(DataType data_type) {
return VisitCompute(data_type);
}

private:
static Maybe<const vm::InstructionType*> GetReleaseTensorInstructionType(DataType data_type) {
if (IsPODDataType(data_type)) {
return SingletonPtr<vm::FastReleaseTensorInstructionType>();
} else {
return SingletonPtr<vm::SlowReleaseTensorInstructionType>();
}
}
};

Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/framework/instructions_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,10 @@ Maybe<void> InstructionsBuilder::ReleaseTensor(
const auto& phy_instr_operand =
std::make_shared<vm::ReleaseTensorArgPhyInstrOperand>(eager_blob_object, vm_stream);
StreamRole stream_role = producer_stream->stream_role();
DeviceType device_type = producer_stream->device()->enum_type();
DataType data_type = eager_blob_object->data_type();
auto instruction = intrusive::make_shared<vm::Instruction>(
JUST(Singleton<VirtualMachine>::Get()->GetVmStream(producer_stream)),
JUST(GetReleaseInstructionType::Visit(stream_role, device_type)), phy_instr_operand);
JUST(GetReleaseInstructionType::Visit(stream_role, data_type)), phy_instr_operand);
instruction_list_->EmplaceBack(std::move(instruction));
return Maybe<void>::Ok();
}
Expand Down

0 comments on commit bfaa258

Please sign in to comment.