Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add value() feature (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
kexinzhao authored Jun 19, 2018
1 parent 10a42ea commit a3ce862
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/tape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ class ScopeWrapper : public framework::Scope {
for (auto &v : in_vars) {
for (auto &vv : v.second) {
if (!vars_.count(vv->Name())) {
vars_[vv->Name()].reset(vv->Var());
vars_[vv->Name()].reset(vv->MutableVar());
}
}
}
for (auto &v : out_vars) {
for (auto &vv : v.second) {
if (!vars_.count(vv->Name())) {
vars_[vv->Name()].reset(vv->Var());
vars_[vv->Name()].reset(vv->MutableVar());
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/test_tape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ TEST(Tape, TestMLP) {
filler(input);

auto loss = mean(linear2(linear1(input)));
LOG(INFO) << loss->value();

get_global_tape().Backward(loss);

Expand Down
18 changes: 18 additions & 0 deletions src/variable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@

#include "src/variable.h"

#include "src/tape.h"

namespace paddle {
namespace tape {

std::ostream& operator<<(std::ostream& os, const Variable& var) {
LOG(INFO) << "Printing " << var.Name();
framework::proto::VarType::Type var_type = var.Desc().GetType();
if (var_type == framework::proto::VarType::LOD_TENSOR) {
os << var.Var().Get<framework::LoDTensor>();
} else {
PADDLE_THROW("Variable type is not LOD_TENSOR");
}
return os;
}

void Variable::InitializeVariable() {
LOG(INFO) << "Initialzing " << desc_.Name() << " as " << desc_.GetType();
framework::proto::VarType::Type var_type = desc_.GetType();
Expand All @@ -30,5 +43,10 @@ void Variable::InitializeVariable() {
}
}

const Variable& Variable::value() {
get_global_tape().Forward();
return *this;
}

} // namespace tape
} // namespace paddle
12 changes: 8 additions & 4 deletions src/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace tape {
class Variable;
using VariableHandle = std::shared_ptr<Variable>;

std::ostream& operator<<(std::ostream&, const Variable&);

/*
* Combination of
* framework::VarDesc desc_;
Expand Down Expand Up @@ -61,15 +63,17 @@ class Variable {
// void init(const std::string& initializer,
// const framework::AttributeMap& attrs);

// void value() {};
// Evaluate a variable by running Forward() on the global tape
const Variable& value();

const framework::VarDesc &Desc() const { return desc_; }
framework::VarDesc *MutableDesc() { return &desc_; }
const framework::VarDesc& Desc() const { return desc_; }
framework::VarDesc* MutableDesc() { return &desc_; }

// TODO(tonyyang-svail): No need to expose name
std::string Name() const { return desc_.Name(); }

framework::Variable *Var() { return &var_; }
const framework::Variable& Var() const { return var_; }
framework::Variable* MutableVar() { return &var_; }

private:
int count() {
Expand Down

0 comments on commit a3ce862

Please sign in to comment.