【Paddle.Fleet】Fix tensor table (#30075)

* add tensor table

@ -30,7 +30,8 @@ struct CommContext {
const std::vector<int64_t> &sections,
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false, int table_id_ = -1)
bool is_distributed_ = false, int table_id_ = -1,
bool is_tensor_table_ = false)
: var_name(name),
splited_varnames(names),
epmap(emap),
@ -40,7 +41,8 @@ struct CommContext {
merge_add(merge_add_),
is_sparse(is_sparse_),
is_distributed(is_distributed_),
table_id(table_id_) {}
table_id(table_id_),
is_tensor_table(is_tensor_table_) {}
CommContext(const CommContext &ctx) {
var_name = ctx.var_name;
@ -53,6 +55,7 @@ struct CommContext {
origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed;
table_id = ctx.table_id;
is_tensor_table = ctx.is_tensor_table;
}
std::string print() const {
@ -75,6 +78,7 @@ struct CommContext {
ss << " is_sparse: " << is_sparse;
ss << " is_distributed: " << is_distributed << "\n";
ss << " table_id: " << table_id << "\n";
ss << " is_tensor_table: " << is_tensor_table << "\n";
return ss.str();
}
@ -89,6 +93,7 @@ struct CommContext {
bool is_sparse;
bool is_distributed;
int table_id;
bool is_tensor_table;
};
} // namespace distributed

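For readers following the new flag, here is a hedged sketch of how a send context for the step counter could be registered. The leading parameter types (var_name, splited_varnames, epmap) are inferred from the initializer list above; the endpoint and table id are purely illustrative, and the snippet is a fragment, not a complete program.

```cpp
// Fragment only; assumes the usual Paddle headers for CommContext and VLOG.
paddle::distributed::CommContext step_ctx(
    "@PS_STEP_COUNTER@",        // var_name
    {"@PS_STEP_COUNTER@"},      // splited_varnames
    {"127.0.0.1:6170"},         // epmap (illustrative endpoint)
    {1},                        // sections
    {"@PS_STEP_COUNTER@"},      // origin_varnames
    /*id=*/0,
    /*merge_add_=*/true,
    /*is_sparse_=*/false,
    /*is_distributed_=*/false,
    /*table_id_=*/3,            // hypothetical id of the GlobalStepTable
    /*is_tensor_table_=*/true);
VLOG(3) << step_ctx.print();
```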
@ -53,15 +53,16 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path,
pserver_ptr_->_server_ptr->table(table_id)->load(path, meta);
}
void FleetWrapper::InitServer(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list,
int index) {
void FleetWrapper::InitServer(
const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index,
const std::vector<framework::ProgramDesc>& server_sub_program) {
if (!is_initialized_) {
VLOG(3) << "Going to init server";
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
new paddle::distributed::PSCore());
pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(),
index);
index, server_sub_program);
is_initialized_ = true;
} else {
VLOG(3) << "Server can be initialized only once";

@ -154,8 +154,10 @@ class FleetWrapper {
// init server
// void InitServer(const std::string& dist_desc,
// const std::vector<uint64_t>& host_sign_list, int index);
void InitServer(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index);
void InitServer(
const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index,
const std::vector<framework::ProgramDesc>& server_sub_program = {});
// init trainer
void InitWorker(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, Scope* scope,

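A hedged usage fragment for the widened InitServer: the new optional argument carries the server-side ProgramDescs that back the tensor tables, for example the learning-rate decay program. The FleetWrapper instantiation and the empty ProgramDescs below are placeholders, not the real setup.

```cpp
// Fragment only; dist_desc, host_sign_list and the real ProgramDescs come from
// the existing fleet initialization path and are not reproduced here.
std::vector<paddle::framework::ProgramDesc> server_sub_program;
server_sub_program.emplace_back();  // [0] startup program of the tensor table (placeholder)
server_sub_program.emplace_back();  // [1] main program, e.g. the LR-decay step (placeholder)

paddle::distributed::FleetWrapper fleet;  // however the caller normally obtains it
fleet.InitServer(dist_desc, host_sign_list, /*index=*/0, server_sub_program);
```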
@ -126,12 +126,11 @@ message TableAccessorParameter {
}
message TensorAccessorParameter {
optional string tensor_class = 1;
optional uint32 fea_dim = 2;
optional uint32 emb_dim = 3;
optional string param = 4;
optional string grad = 5;
optional string common_block_map = 6;
optional string feed_var_name = 1;
optional string fetch_var_name = 2;
optional int64 startup_program_id = 3;
optional int64 main_program_id = 4;
optional string tensor_table_class = 6;
}
message CommonAccessorParameter {

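A minimal sketch of populating the reworked TensorAccessorParameter from text format, the same mechanism PSCore::init_server uses for the full PSParameter. The generated header path, the message namespace, and every field value are assumptions for illustration only.

```cpp
#include <google/protobuf/text_format.h>

#include <iostream>

#include "paddle/fluid/distributed/ps.pb.h"  // assumed generated header for the proto above

int main() {
  // All values are illustrative; the program ids would index into the
  // server_sub_program list handed to InitServer.
  const char* conf = R"(
    feed_var_name: "@LR_DECAY_COUNTER@"
    fetch_var_name: "learning_rate_0"
    startup_program_id: 0
    main_program_id: 1
    tensor_table_class: "GlobalStepTable"
  )";

  paddle::distributed::TensorAccessorParameter param;  // namespace assumed
  if (!google::protobuf::TextFormat::ParseFromString(conf, &param)) {
    std::cerr << "failed to parse TensorAccessorParameter" << std::endl;
    return 1;
  }
  std::cout << param.tensor_table_class() << std::endl;
  return 0;
}
```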
@ -719,6 +719,34 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
return fut;
}
std::future<int32_t> BrpcPsClient::push_global_step(int table_id,
int64_t *total_send_data,
void *done) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_GLOBAL_STEP);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
auto *push_data = closure->request(i)->mutable_data();
push_data->clear();
int32_t num_per_shard = 1;
push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(int64_t));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
memcpy(push_data_ptr + sizeof(uint32_t), total_send_data,
num_per_shard * sizeof(int64_t));
PsService_Stub rpc_stub(get_dense_channel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys,

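The request payload built above is just a uint32 count followed by that many int64 values. A self-contained sketch of the same packing, together with the server-side view that mirrors PsService::push_global_step further down:

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>

int main() {
  int64_t total_send_data[] = {128};  // the trainer's current global step
  uint32_t num_per_shard = 1;

  // Client side: [uint32 count][count x int64 values], as in push_global_step.
  std::string buf;
  buf.resize(sizeof(uint32_t) + num_per_shard * sizeof(int64_t));
  char* p = &buf[0];
  std::memcpy(p, &num_per_shard, sizeof(uint32_t));
  std::memcpy(p + sizeof(uint32_t), total_send_data,
              num_per_shard * sizeof(int64_t));

  // Server side: mirrors the parsing in PsService::push_global_step.
  uint32_t num = *reinterpret_cast<const uint32_t*>(buf.data());
  const int64_t* values =
      reinterpret_cast<const int64_t*>(buf.data() + sizeof(uint32_t));
  std::cout << "num=" << num << " step=" << values[0] << std::endl;
  return 0;
}
```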
@ -140,7 +140,9 @@ class BrpcPsClient : public PSClient {
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> push_global_step(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> flush();
virtual std::future<int32_t> send_client2client_msg(

@ -100,6 +100,7 @@ int32_t PsService::initialize() {
_service_handler_map[PS_BARRIER] = &PsService::barrier;
_service_handler_map[PS_START_PROFILER] = &PsService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &PsService::push_global_step;
// Shard initialization: the shard info of server_list can only be obtained from env after the server starts
initialize_shard_info();
@ -526,5 +527,26 @@ int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::push_global_step(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response);
auto req_buffer_size = request.data().size();
if (req_buffer_size < 1) {
set_response_code(response, 0, "push_global_step data is empty");
return 0;
}
uint32_t num = *(const uint32_t *)(request.data().data());
const int64_t *values =
(const int64_t *)(request.data().data() + sizeof(uint32_t));
auto trainer_id = request.client_id();
if (table->push_dense(values, trainer_id) != 0) {
set_response_code(response, -1, "push_global_step failed");
}
return 0;
}
} // namespace distributed
} // namespace paddle

@ -110,6 +110,9 @@ class PsService : public PsBaseService {
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_global_step(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;

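The new handler is wired in through the same command-id dispatch map as the other services. A self-contained sketch of that pattern, with illustrative names and only the PS_PUSH_GLOBAL_STEP entry (value 29, per the enum change below):

```cpp
#include <cstdint>
#include <iostream>
#include <unordered_map>

class MiniService {
 public:
  using Handler = int32_t (MiniService::*)();
  MiniService() {
    // Only the new entry is shown; 29 is PS_PUSH_GLOBAL_STEP in the enum below.
    handlers_[29] = &MiniService::PushGlobalStep;
  }
  int32_t Dispatch(int32_t cmd_id) {
    auto it = handlers_.find(cmd_id);
    return it == handlers_.end() ? -1 : (this->*(it->second))();
  }

 private:
  int32_t PushGlobalStep() {
    std::cout << "handle PS_PUSH_GLOBAL_STEP" << std::endl;
    return 0;
  }
  std::unordered_map<int32_t, Handler> handlers_;
};

int main() { return MiniService().Dispatch(29); }
```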
@ -34,6 +34,9 @@ limitations under the License. */
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
namespace paddle {
namespace distributed {
@ -377,6 +380,37 @@ void Communicator::RpcProfilerControl() {
}
}
void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
Scope *send_scope) {
if (batches == 0) {
return;
}
auto &table_id = ctx.table_id;
size_t request_call_num = _worker_ptr->get_server_nums();
auto &var_name = STEP_COUNTER;
auto *out_var = send_scope->Var(var_name);
auto *out_t = out_var->GetMutable<framework::LoDTensor>();
auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
data[0] = static_cast<int64_t>(batches);
VLOG(3) << "Communicator::SendGlobalStep send: " << batches;
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_GLOBAL_STEP) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto status = _worker_ptr->push_global_step(table_id, data, closure);
status.wait();
return;
}
void AsyncCommunicator::RecvThread() {
if (!independent_recv_) return;
VLOG(3) << "Independent RecvThread Start and Wait";
@ -465,10 +499,16 @@ void AsyncCommunicator::SendByCommunicator() {
for (size_t i = 0; i < var_nums; i++) {
auto &var_name = varnames[i];
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
if (var_name == STEP_COUNTER) {
MergeVars<int64_t>(var_name, vars[i], send_scope_.get(), 1);
} else {
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
}
}
if (ctx.is_sparse) {
if (ctx.is_tensor_table) {
SendGlobalStep(ctx, merged_var_num, send_scope_.get());
} else if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
platform::errors::InvalidArgument(
@ -599,8 +639,18 @@ bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0];
if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end())
if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) {
return false;
}
if (table_name == STEP_COUNTER) {
VLOG(3) << "send step_counter into queue";
auto tmp_var = std::make_shared<Variable>();
auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({1}));
auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
out_d[0] = 1;
send_varname_to_queue_[table_name]->Push(tmp_var);
}
return true;
}

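The flow above is: Check() enqueues a LoDTensor holding 1 for every mini-batch, the send loop merges the queue with MergeVars<int64_t> (a sum, which is my reading of the merge_add argument), and SendGlobalStep ships the summed batch count via push_global_step. A trivial self-contained sketch of that accumulation:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Each Check("@PS_STEP_COUNTER@") call enqueues a tensor holding 1.
  std::vector<int64_t> queued(5, 1);  // five mini-batches since the last send

  // MergeVars<int64_t> with merge_add: simple element-wise sum.
  int64_t batches = 0;
  for (int64_t v : queued) batches += v;

  // SendGlobalStep then sends this count to every pserver.
  std::cout << "batches to send: " << batches << std::endl;
  return 0;
}
```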
@ -223,6 +223,9 @@ class Communicator {
// 6. recv sparse param
virtual void RpcRecvSparse(const std::string &varname, int table_id,
Scope *scope);
// 7. send global step
virtual void SendGlobalStep(const CommContext &ctx, int batches,
Scope *send_scope);
virtual ~Communicator() {}
virtual void RpcProfilerControl();
@ -376,8 +379,6 @@ class AsyncCommunicator : public Communicator {
virtual void SendByCommunicator();
virtual void SendGlobalStep(int batches) {}
virtual void RecvByCommunicator();
virtual void RecvNoBarrier();
@ -527,8 +528,6 @@ class GeoCommunicator : public AsyncCommunicator {
void SendByCommunicator() { return; }
void SendGlobalStep(int batches) override { return; }
void RecvByCommunicator() override { return; }
inline std::string GradToParam(const std::string var_name) {

@ -131,6 +131,9 @@ class PSClient {
std::vector<uint64_t> *keys,
int pserver_idx) = 0;
virtual std::future<int32_t> push_global_step(int table_id,
int64_t *total_send_data,
void *done) = 0;
virtual void finalize_worker() = 0;
// client-to-client message sending
virtual std::future<int32_t> send_client2client_msg(int msg_type,

@ -47,6 +47,7 @@ enum PsCmdID {
PS_PUSH_SPARSE_PARAM = 26;
PS_START_PROFILER = 27;
PS_STOP_PROFILER = 28;
PS_PUSH_GLOBAL_STEP = 29;
}
message PsRequestMessage {

@ -53,8 +53,10 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
return server;
}
int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
size_t server_rank) {
int32_t PSServer::configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program) {
scope_.reset(new framework::Scope());
_config = config.server_param();
_rank = server_rank;
_environment = &env;
@ -65,6 +67,7 @@ int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
const auto &downpour_param = _config.downpour_server_param();
uint32_t barrier_table = UINT32_MAX;
uint32_t global_step_table = UINT32_MAX;
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_CLASS(
@ -74,6 +77,12 @@ int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
"BarrierTable") {
barrier_table = downpour_param.downpour_table_param(i).table_id();
}
if (downpour_param.downpour_table_param(i).table_class() ==
"GlobalStepTable") {
global_step_table = downpour_param.downpour_table_param(i).table_id();
}
table->set_program_env(scope_.get(), place_, &server_sub_program);
table->set_shard(_rank, shard_num);
table->initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
@ -83,6 +92,9 @@ int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->set_table_map(&_table_map);
}
if (global_step_table != UINT32_MAX) {
_table_map[global_step_table]->set_table_map(&_table_map);
}
return initialize();
}

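Passing the table map to the GlobalStepTable mirrors what is already done for the BarrierTable: the tensor table needs a handle on the other tables so it can, presumably, broadcast a decayed learning rate through the new set_global_lr hook after running its sub-program. A self-contained sketch of that idea, not the actual GlobalStepTable implementation:

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>

// A table that only exposes the new set_global_lr hook; the sketch copies the
// value for simplicity, while the real hook stores the pointer.
struct MiniTable {
  float global_lr = 1.0f;
  void set_global_lr(float* lr) { global_lr = *lr; }
};

// Stand-in for GlobalStepTable: set_table_map gives it access to every other
// table, so a decayed rate can be pushed to all of them.
struct MiniGlobalStepTable {
  using TableMap = std::unordered_map<uint32_t, std::shared_ptr<MiniTable>>;
  void set_table_map(TableMap* m) { table_map_ = m; }
  void decay(float new_lr) {  // imagine the LR-decay sub-program produced new_lr
    for (auto& kv : *table_map_) kv.second->set_global_lr(&new_lr);
  }
  TableMap* table_map_ = nullptr;
};

int main() {
  MiniGlobalStepTable::TableMap tables;
  tables[0] = std::make_shared<MiniTable>();
  tables[1] = std::make_shared<MiniTable>();

  MiniGlobalStepTable global_step_table;
  global_step_table.set_table_map(&tables);
  global_step_table.decay(0.5f);

  std::cout << tables[0]->global_lr << " " << tables[1]->global_lr << std::endl;
  return 0;
}
```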
@ -27,6 +27,20 @@
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace distributed {
@ -40,8 +54,9 @@ class PSServer {
PSServer(PSServer &&) = delete;
PSServer(const PSServer &) = delete;
virtual int32_t configure(const PSParameter &config, PSEnvironment &env,
size_t server_rank) final;
virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) final;
// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
@ -86,6 +101,10 @@ class PSServer {
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
protected:
std::shared_ptr<framework::Scope> scope_;
platform::Place place_ = platform::CPUPlace();
};
REGISTER_REGISTERER(PSServer);

@ -66,9 +66,10 @@ void PSCore::init_gflag(const std::string& gflags) {
::google::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
int PSCore::init_server(const std::string& dist_desc,
const std::vector<std::string>* host_sign_list,
int node_num, int index) {
int PSCore::init_server(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index,
const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
@ -76,7 +77,7 @@ int PSCore::init_server(const std::string& dist_desc,
int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param));
ret = _server_ptr->configure(_ps_param, _ps_env, index);
ret = _server_ptr->configure(_ps_param, _ps_env, index, server_sub_program);
CHECK(ret == 0) << "failed to configure server";
return ret;
}

@ -33,9 +33,10 @@ class PSCore {
explicit PSCore() {}
virtual ~PSCore() {}
virtual int init_server(const std::string& dist_desc,
const std::vector<std::string>* host_sign_list,
int node_num, int index);
virtual int init_server(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index,
const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int init_worker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>&

@ -11,8 +11,9 @@ cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse
set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(tensor_accessor SRCS tensor_accessor.cc DEPS ${TABLE_DEPS} eigen3 ps_framework_proto device_context)
cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto executor scope device_context tensor ${TABLE_DEPS})
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(table SRCS table.cc DEPS common_table tensor_accessor ps_framework_proto string_helper device_context gflags glog boost)
cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)

@ -42,6 +42,7 @@ int32_t CommonDenseTable::initialize() {
sync = _config.common().sync();
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
_global_lr = new float(1.0);
initialize_value();
initialize_optimizer();
@ -81,8 +82,10 @@ int32_t CommonDenseTable::initialize_optimizer() {
if (name == "sgd") {
optimizer_ = std::make_shared<DSGD>(common, &values_);
optimizer_->set_global_lr(_global_lr);
} else if (name == "adam") {
optimizer_ = std::make_shared<DAdam>(common, &values_);
optimizer_->set_global_lr(_global_lr);
} else if (name == "sum") {
optimizer_ = std::make_shared<DSUM>(common, &values_);
} else {
@ -92,6 +95,12 @@ int32_t CommonDenseTable::initialize_optimizer() {
return 0;
}
int32_t CommonDenseTable::set_global_lr(float* lr) {
_global_lr = lr;
optimizer_->set_global_lr(_global_lr);
return 0;
}
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values);

@ -42,6 +42,7 @@ class CommonDenseTable : public DenseTable {
virtual int32_t push_dense_param(const float* values, size_t num) override;
virtual int32_t push_dense(const float* values, size_t num) override;
virtual int32_t pour() override;
virtual int32_t set_global_lr(float* lr) override;
int32_t load(const std::string& path, const std::string& param) override {
VLOG(0) << "Dense table may load by "

@ -175,6 +175,8 @@ int32_t CommonSparseTable::initialize() {
sync = _config.common().sync();
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
_global_lr = new float(1.0);
auto common = _config.common();
int size = static_cast<int>(common.params().size());
@ -249,9 +251,11 @@ int32_t CommonSparseTable::initialize_optimizer() {
if (name == "sgd") {
optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_,
value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr);
} else if (name == "adam") {
optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_,
value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr);
} else if (name == "sum") {
optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_,
value_offsets_, value_idx_);
@ -263,6 +267,12 @@ int32_t CommonSparseTable::initialize_optimizer() {
return 0;
}
int32_t CommonSparseTable::set_global_lr(float* lr) {
_global_lr = lr;
optimizer_->set_global_lr(_global_lr);
return 0;
}
int32_t CommonSparseTable::load(const std::string& path,
const std::string& param) {
rwlock_->WRLock();

@ -69,6 +69,8 @@ class CommonSparseTable : public SparseTable {
virtual int32_t push_sparse_param(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t set_global_lr(float* lr) override;
virtual int32_t pour();
virtual int32_t flush();
virtual int32_t shrink();

@ -36,6 +36,10 @@ class DenseOptimizer {
std::vector<std::vector<float>>* values) {}
virtual void update(const float* update_values, size_t num, int begin,
int end) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; }
protected:
float* global_learning_rate_;
};
// sum calc for dense tensor
@ -84,8 +88,10 @@ class DSGD : public DenseOptimizer {
grads.resize(update_numel);
auto blas = GetBlas<float>();
float lr = *(global_learning_rate_) * (*learning_rate);
VLOG(4) << "DSGD LearningRate: " << lr;
blas.VCOPY(update_numel, update_values + begin, grads.data());
blas.SCAL(update_numel, *learning_rate, grads.data());
blas.SCAL(update_numel, lr, grads.data());
blas.VSUB(update_numel, param + begin, grads.data(), param + begin);
}
@ -150,7 +156,8 @@ class DAdam : public DenseOptimizer {
beta1_pow[0] = beta1_pow[0] * beta1;
beta2_pow[0] = beta2_pow[0] * beta2;
float lr_ = learning_rate[0];
float lr_ = *(global_learning_rate_)*learning_rate[0];
VLOG(4) << "DAdam LearningRate: " << lr_;
lr_ *= sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);
float* tmp_ = tmp.data();

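The dense optimizers now multiply the stored per-parameter learning rate by the shared global rate that set_global_lr re-targets. A self-contained sketch of the resulting SGD update:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  float global_lr = 0.5f;                    // what the table's _global_lr points at
  float* global_learning_rate_ = &global_lr;
  float slot_lr = 0.1f;                      // the stored per-parameter learning_rate

  std::vector<float> param = {1.0f, 2.0f};
  std::vector<float> grad = {0.2f, 0.4f};

  // Same composition as DSGD above: lr = *(global_learning_rate_) * learning_rate.
  float lr = (*global_learning_rate_) * slot_lr;
  for (std::size_t i = 0; i < param.size(); ++i) {
    param[i] -= lr * grad[i];  // VCOPY + SCAL + VSUB in the real code
  }

  std::cout << param[0] << " " << param[1] << std::endl;  // 0.99 1.98
  return 0;
}
```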
@ -44,12 +44,17 @@ class SparseOptimizer {
size_t num, const std::vector<uint64_t>& offsets,
ValueBlock* block) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; }
const std::vector<std::string>& value_names_;
const std::vector<int>& value_dims_;
const std::vector<int>& value_offsets_;
const std::unordered_map<std::string, int>& value_idx_;
int param_offset = 0;
int update_numel = 0;
protected:
float* global_learning_rate_;
};
// sum calc for sparse tensor
@ -102,13 +107,14 @@ class SSGD : public SparseOptimizer {
auto id = keys[x];
auto* value = block->Get(id);
float* learning_rate = value + lr_offset;
float learning_rate = *(global_learning_rate_) * (value + lr_offset)[0];
VLOG(4) << "SSGD LearningRate: " << learning_rate;
float* param = value + param_offset;
std::vector<float> grads;
grads.resize(update_numel);
blas.VCOPY(update_numel, update_values + x * update_numel, grads.data());
blas.SCAL(update_numel, learning_rate[0], grads.data());
blas.SCAL(update_numel, learning_rate, grads.data());
blas.VSUB(update_numel, param, grads.data(), param);
}
}
@ -156,7 +162,8 @@ class SAdam : public SparseOptimizer {
for (auto x : offsets) {
auto id = keys[x];
auto* values = block->Get(id);
float* learning_rate = values + lr_offset;
float lr_ = *(global_learning_rate_) * (values + lr_offset)[0];
VLOG(4) << "SAdam LearningRate: " << lr_;
float* param = values + param_offset;
float* moment1 = values + m1_offset;
float* moment2 = values + m2_offset;
@ -166,7 +173,6 @@ class SAdam : public SparseOptimizer {
beta1_pow[0] = beta1_pow[0] * beta1;
beta2_pow[0] = beta2_pow[0] * beta2;
float lr_ = learning_rate[0];
lr_ *= sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);
std::vector<float> grad, grad2, tmp;

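The Adam paths compose the rate the same way before applying the usual bias correction. A self-contained numeric sketch (all hyper-parameter values are illustrative):

```cpp
#include <cmath>
#include <iostream>

int main() {
  float global_lr = 0.5f;    // value behind global_learning_rate_
  float stored_lr = 0.001f;  // learning_rate[0] kept in the value block
  float beta1_pow = 0.9f;    // illustrative accumulated beta powers
  float beta2_pow = 0.999f;

  // Same composition as DAdam/SAdam above, then Adam's bias correction.
  float lr_ = global_lr * stored_lr;
  lr_ *= std::sqrt(1 - beta2_pow) / (1 - beta1_pow);

  std::cout << "effective adam lr: " << lr_ << std::endl;
  return 0;
}
```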
@ -22,6 +22,7 @@
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/tensor_accessor.h"
#include "paddle/fluid/distributed/table/tensor_table.h"
namespace paddle {
namespace distributed {
@ -30,7 +31,9 @@ REGISTER_CLASS(Table, CommonDenseTable);
REGISTER_CLASS(Table, CommonSparseTable);
REGISTER_CLASS(Table, SparseGeoTable);
REGISTER_CLASS(Table, BarrierTable);
REGISTER_CLASS(Table, TensorTable);
REGISTER_CLASS(Table, DenseTensorTable);
REGISTER_CLASS(Table, GlobalStepTable);
REGISTER_CLASS(ValueAccessor, CommMergeAccessor);
int32_t TableManager::initialize() {

@ -20,8 +20,11 @@
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
@ -35,6 +38,10 @@ class Table {
virtual int32_t pull_dense(float *values, size_t num) = 0;
virtual int32_t push_dense(const float *values, size_t num) = 0;
// for push global_step
virtual int32_t push_dense(const int64_t *values, const int32_t trainer_id) {
return 0;
}
virtual int32_t push_dense_param(const float *values, size_t num) {
return 0;
}
@ -67,6 +74,18 @@ class Table {
return 0;
}
// only for tensor table
virtual int32_t set_program_env(
framework::Scope *scope, platform::Place place,
const std::vector<framework::ProgramDesc> *sub_program) {
return 0;
}
virtual int32_t set_global_lr(float *lr) {
_global_lr = lr;
return 0;
}
virtual int32_t pour() { return 0; }
virtual void clear() = 0;
@ -105,6 +124,7 @@ class Table {
size_t _shard_idx; // shard index of this table
size_t _shard_num; // total number of shards of this table
TableParameter _config;
float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor;
};
REGISTER_REGISTERER(Table);

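The int64 push_dense overload added to the Table base class is the hook the global-step path lands on: ordinary tables keep the no-op default, while a tensor table overrides it. A self-contained sketch of that shape, using mini classes rather than the real Table hierarchy:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

struct MiniTable {
  virtual ~MiniTable() = default;
  // The pre-existing float push path.
  virtual int32_t push_dense(const float* values, size_t num) = 0;
  // The new hook for the global step; the base keeps it a no-op.
  virtual int32_t push_dense(const int64_t* values, const int32_t trainer_id) {
    return 0;
  }
};

struct MiniGlobalStepTable : public MiniTable {
  int32_t push_dense(const float* values, size_t num) override { return 0; }
  int32_t push_dense(const int64_t* values, const int32_t trainer_id) override {
    std::cout << "trainer " << trainer_id << " pushed step " << values[0]
              << std::endl;
    return 0;
  }
};

int main() {
  MiniGlobalStepTable table;
  int64_t step = 128;
  return table.push_dense(&step, /*trainer_id=*/0);
}
```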