Skip to content

Commit

Permalink
add for gpu cache (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
goodcoder-cnn authored Oct 30, 2020
1 parent ab2db95 commit 73858ce
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 4 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/framework/fleet/box_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "paddle/fluid/platform/gpu_info.h"

DECLARE_bool(use_gpu_replica_cache);

DECLARE_int32(gpu_replica_cache_dim);
namespace paddle {
namespace framework {

Expand Down Expand Up @@ -372,6 +372,7 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place,
EMBEDX_CASE(16, PULLSPARSE_CASE(0););
EMBEDX_CASE(256, PULLSPARSE_CASE(0););
EMBEDX_CASE(128, PULLSPARSE_CASE(0););
EMBEDX_CASE(280, PULLSPARSE_CASE(0););
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupport this embedding size [%d]", hidden_size - 3));
Expand Down Expand Up @@ -413,6 +414,7 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
EMBEDX_CASE(16, PUSHSPARSE_CASE(0););
EMBEDX_CASE(256, PUSHSPARSE_CASE(0););
EMBEDX_CASE(128, PUSHSPARSE_CASE(0););
EMBEDX_CASE(280, PUSHSPARSE_CASE(0););
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupport this embedding size [%d]", hidden_size - 3));
Expand Down Expand Up @@ -467,7 +469,7 @@ void BoxWrapper::FeedPass(int date,
void BoxWrapper::BeginFeedPass(int date, boxps::PSAgentBase** agent) {
int ret = boxps_ptr_->BeginFeedPass(date, *agent);
if(FLAGS_use_gpu_replica_cache){
int dim = BoxWrapper::embedx_dim_;
int dim = FLAGS_gpu_replica_cache_dim;
VLOG(3) << "gpu cache dim:" << dim;
gpu_replica_cache.emplace_back(dim);
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/fleet/box_wrapper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
EMBEDX_CASE(16, EXPAND_EMBED_PULL_CASE(0););
EMBEDX_CASE(256, EXPAND_EMBED_PULL_CASE(0););
EMBEDX_CASE(128, EXPAND_EMBED_PULL_CASE(0););
EMBEDX_CASE(280, EXPAND_EMBED_PULL_CASE(0););
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupport this embedding size [%d]", hidden_size - 3));
Expand Down Expand Up @@ -305,6 +306,7 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
EMBEDX_CASE(16, EXPAND_EMBED_PUSH_CASE(0););
EMBEDX_CASE(256, EXPAND_EMBED_PUSH_CASE(0););
EMBEDX_CASE(128, EXPAND_EMBED_PUSH_CASE(0););
EMBEDX_CASE(280, EXPAND_EMBED_PUSH_CASE(0););
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupport this embedding size [%d]", hidden_size - 3));
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/framework/fleet/box_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,6 @@ class BoxWrapper {
expand_embed_dim_ = expand_embed_dim;
is_quant_ = is_quant;
pull_embedx_scale_ = pull_embedx_scale;

if (boxps::MPICluster::Ins().size() > 1) {
data_shuffle_.reset(boxps::PaddleShuffler::New());
data_shuffle_->init(10);
Expand Down Expand Up @@ -844,7 +843,6 @@ class BoxWrapper {
static int expand_embed_dim_;
static bool is_quant_;
static float pull_embedx_scale_;

// Metric Related
int phase_ = 1;
int phase_num_ = 2;
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,5 @@ DEFINE_int32(padbox_slotpool_thread_num, 1,
"PadBoxSlotDataset slot pool thread num");
DEFINE_bool(use_gpu_replica_cache, false,
"if true ,will open use_gpu_replica_cache");
DEFINE_int32(gpu_replica_cache_dim, 8,
"use_gpu_replica_cache,the dim");
1 change: 1 addition & 0 deletions python/paddle/fluid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def __bootstrap__():
'padbox_dataset_shuffle_thread_num',
'padbox_dataset_merge_thread_num',
'use_gpu_replica_cache',
'gpu_replica_cache_dim',
]
core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
Expand Down

0 comments on commit 73858ce

Please sign in to comment.