Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Add var->value() feature #18

Merged
merged 1 commit into from
Jun 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/tape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ class ScopeWrapper : public framework::Scope {
for (auto &v : in_vars) {
for (auto &vv : v.second) {
if (!vars_.count(vv->Name())) {
vars_[vv->Name()].reset(vv->Var());
vars_[vv->Name()].reset(vv->MutableVar());
}
}
}
for (auto &v : out_vars) {
for (auto &vv : v.second) {
if (!vars_.count(vv->Name())) {
vars_[vv->Name()].reset(vv->Var());
vars_[vv->Name()].reset(vv->MutableVar());
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/test_tape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ TEST(Tape, TestMLP) {
filler(input);

auto loss = mean(linear2(linear1(input)));
LOG(INFO) << loss->value();

get_global_tape().Backward(loss);

Expand Down
18 changes: 18 additions & 0 deletions src/variable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@

#include "src/variable.h"

#include "src/tape.h"

namespace paddle {
namespace tape {

std::ostream& operator<<(std::ostream& os, const Variable& var) {
LOG(INFO) << "Printing " << var.Name();
framework::proto::VarType::Type var_type = var.Desc().GetType();
if (var_type == framework::proto::VarType::LOD_TENSOR) {
os << var.Var().Get<framework::LoDTensor>();
} else {
PADDLE_THROW("Variable type is not LOD_TENSOR");
}
return os;
}

void Variable::InitializeVariable() {
LOG(INFO) << "Initialzing " << desc_.Name() << " as " << desc_.GetType();
framework::proto::VarType::Type var_type = desc_.GetType();
Expand All @@ -30,5 +43,10 @@ void Variable::InitializeVariable() {
}
}

const Variable& Variable::value() {
get_global_tape().Forward();
return *this;
}

} // namespace tape
} // namespace paddle
12 changes: 8 additions & 4 deletions src/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace tape {
class Variable;
using VariableHandle = std::shared_ptr<Variable>;

std::ostream& operator<<(std::ostream&, const Variable&);

/*
* Combination of
* framework::VarDesc desc_;
Expand Down Expand Up @@ -61,15 +63,17 @@ class Variable {
// void init(const std::string& initializer,
// const framework::AttributeMap& attrs);

// void value() {};
// Evaluate a variable by running Forward() on the global tape
const Variable& value();

const framework::VarDesc &Desc() const { return desc_; }
framework::VarDesc *MutableDesc() { return &desc_; }
const framework::VarDesc& Desc() const { return desc_; }
framework::VarDesc* MutableDesc() { return &desc_; }

// TODO(tonyyang-svail): No need to expose name
std::string Name() const { return desc_.Name(); }

framework::Variable *Var() { return &var_; }
const framework::Variable& Var() const { return var_; }
framework::Variable* MutableVar() { return &var_; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch.


private:
int count() {
Expand Down