Integrated Trainer of Parameter Server (API add `fluid.contrib.layers.sparse_embedding` only) (#22957)

* Integrated Trainer of Parameter Server
fix_copy_if_different
tangwei12 authored 5 years ago, committed by GitHub
parent af74675b5a
commit caa90a6510

@@ -42,53 +42,18 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
   }
 }
 
-// get RpcContext and remote send and recv op
+// get CommContext and remote send and recv op
 void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-  using RpcCtxMap = operators::distributed::RpcCtxMap;
-  VLOG(3) << "ProcessGraph";
-  RpcCtxMap send_varname_to_ctx;
-  for (auto &node : graphs[0]->Nodes()) {
-    VLOG(3) << "node name " << node->Name();
-    if (node && node->IsOp()) {
-      if (node->Name() == "send") {
-        auto send_var_name = node->Op()->Input("X")[0];
-        auto send_varnames =
-            BOOST_GET_CONST(std::vector<std::string>,
-                            node->Op()->GetNullableAttr("send_varnames"));
-        auto epmap = BOOST_GET_CONST(std::vector<std::string>,
-                                     node->Op()->GetNullableAttr("epmap"));
-        auto height_section = BOOST_GET_CONST(
-            std::vector<int64_t>, node->Op()->GetNullableAttr("sections"));
-        auto trainer_id =
-            BOOST_GET_CONST(int, node->Op()->GetNullableAttr("trainer_id"));
-        auto merge_add =
-            BOOST_GET_CONST(bool, node->Op()->GetNullableAttr("merge_add"));
-        if (!merge_add) {
-          merge_add = FLAGS_communicator_is_sgd_optimizer;
-        }
-        auto use_send_handler = BOOST_GET_CONST(
-            bool, node->Op()->GetNullableAttr("use_send_handler"));
-        send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
-            send_var_name, send_varnames, epmap, height_section, trainer_id,
-            merge_add, use_send_handler);
-        VLOG(3) << "find and init an send op: "
-                << send_varname_to_ctx[send_var_name];
-      }
-    }
-  }
-
-  // init communicator here
-  if (send_varname_to_ctx.size() > 0) {
-    auto *instance = operators::distributed::Communicator::GetInstance();
-    auto initialized = instance ? true : false;
-    PADDLE_ENFORCE_EQ(initialized, true,
-                      platform::errors::InvalidArgument(
-                          "Communicator is not Initialized, you may use "
-                          "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
-                          "develop/markdown_doc/transpiler)"));
-  }
+  auto *instance = operators::distributed::Communicator::GetInstance();
+  auto initialized = instance ? true : false;
+  PADDLE_ENFORCE_EQ(initialized, true,
+                    platform::errors::InvalidArgument(
+                        "Communicator is not Initialized, you may use "
+                        "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
+                        "develop/markdown_doc/transpiler)"));
 #endif
 }
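
Note: the rewrite above removes the graph walk that built one RpcContext per `send` op; those per-op attributes (send_varnames, epmap, sections, trainer_id, merge_add) now travel in the CommContext introduced later in this diff, and ProcessGraph only asserts that the Communicator singleton already exists. A minimal sketch of the precondition this leaves with callers (the helper is hypothetical; Communicator::GetInstance is from the diff):

```cpp
#include "paddle/fluid/operators/distributed/communicator.h"

// Hypothetical helper, not part of this PR: the Fleet API is expected to have
// created the Communicator singleton before any graph is processed.
bool CommunicatorReady() {
  auto *instance = paddle::operators::distributed::Communicator::GetInstance();
  return instance != nullptr;
}
```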

@@ -122,7 +122,7 @@ class SelectedRows {
   /*
    * @brief Get the index of the key from id_to_index_ map.
    */
-  inline int64_t GetIndexFromId(int64_t key) {
+  inline int64_t GetIndexFromId(int64_t key) const {
     auto iter = id_to_index_.find(key);
     if (iter == id_to_index_.end()) {
       return -1;
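
Const-qualifying GetIndexFromId lets read-only code query the id-to-index map through a `const SelectedRows &`. A small sketch, assuming the usual header location:

```cpp
#include "paddle/fluid/framework/selected_rows.h"

// Compiles only with the const qualifier added above; returns -1 when absent.
int64_t LookupIndex(const paddle::framework::SelectedRows &rows, int64_t key) {
  return rows.GetIndexFromId(key);
}
```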

@@ -79,5 +79,6 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
     PADDLE_THROW("unknown var type to copy");
   }
 }
 
 }  // namespace framework
 }  // namespace paddle

@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
 
+#include <vector>
+
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/variable.h"
 
 namespace paddle {
 namespace framework {

@@ -13,6 +13,7 @@ cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_rec
 cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
 cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
+cc_library(large_scale_kv SRCS large_scale_kv.cc DEPS enforce simple_threadpool)
 cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
 
 # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
@@ -26,7 +27,7 @@ if(WITH_GRPC)
       collective_client.cc collective_server.cc
       ${GRPC_SRCS}
       PROTO send_recv.proto
-      DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
+      DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor large_scale_kv)
 
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
@@ -50,12 +51,12 @@ else()
   set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS})
   cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
-      DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_op)
+      DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_read_op)
 endif()
 
 cc_test(rpc_server_test SRCS rpc_server_test.cc
-    DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_op)
+    DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op)
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)

@@ -446,11 +446,12 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
 }
 
 VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
-                                               const std::string& dir,
+                                               const std::string& dirname,
+                                               const std::string& varname,
                                                int64_t time_out) {
   sendrecv::VariableMessage req;
-  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
-  req.set_out_varname(dir);
+  req.set_varname(varname);
+  req.set_out_varname(dirname);
 
   return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
 }

@@ -102,7 +102,8 @@ class BRPCClient : public RPCClient {
       const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncCheckpointNotify(
-      const std::string& ep, const std::string& dir,
+      const std::string& ep, const std::string& dirname,
+      const std::string& varname,
       int64_t time_out = FLAGS_rpc_deadline) override;
 
   bool Wait() override;
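
Every RPC client now reports the variable name alongside the checkpoint directory, instead of stuffing CHECKPOINT_SAVE_MESSAGE into the varname field. A hedged sketch of the new call shape (endpoint, path, and parameter name are illustrative):

```cpp
#include "paddle/fluid/operators/distributed/brpc/brpc_client.h"

void NotifyCheckpoint() {
  // GetInstance is the existing RPCClient factory; all values are examples.
  auto *client = paddle::operators::distributed::RPCClient::GetInstance<
      paddle::operators::distributed::BRPCClient>(/*trainer_id=*/0);
  client->AsyncCheckpointNotify("127.0.0.1:6174", "/tmp/ckpt_0",
                                "embedding.w_0");
  client->Wait();
}
```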

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@@ -0,0 +1,91 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+
+struct CommContext {
+  CommContext() = default;
+
+  CommContext(const std::string &name, const std::vector<std::string> &names,
+              const std::vector<std::string> &emap,
+              const std::vector<int64_t> &sections,
+              const std::vector<std::string> &origin_names, int id,
+              bool merge_add_ = true, bool is_sparse_ = true,
+              bool is_distributed_ = false)
+      : var_name(name),
+        splited_varnames(names),
+        epmap(emap),
+        height_sections(sections),
+        origin_varnames(origin_names),
+        trainer_id(id),
+        merge_add(merge_add_),
+        is_sparse(is_sparse_),
+        is_distributed(is_distributed_) {}
+
+  CommContext(const CommContext &ctx) {
+    var_name = ctx.var_name;
+    splited_varnames = ctx.splited_varnames;
+    epmap = ctx.epmap;
+    height_sections = ctx.height_sections;
+    trainer_id = ctx.trainer_id;
+    merge_add = ctx.merge_add;
+    is_sparse = ctx.is_sparse;
+    origin_varnames = ctx.origin_varnames;
+    is_distributed = ctx.is_distributed;
+  }
+
+  std::string print() const {
+    std::stringstream ss;
+    ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
+
+    for (size_t i = 0; i < splited_varnames.size(); i++) {
+      ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
+         << " section: " << height_sections[i] << " ";
+    }
+
+    ss << "origin varnames: ";
+    for (size_t i = 0; i < origin_varnames.size(); i++) {
+      ss << origin_varnames[i] << " ";
+    }
+
+    ss << " aggregation->add: " << merge_add << " ";
+    ss << " is_sparse: " << is_sparse << "\n";
+    ss << " is_distributed: " << is_distributed << "\n";
+
+    return ss.str();
+  }
+
+  std::string var_name;
+  std::vector<std::string> splited_varnames;
+  std::vector<std::string> epmap;
+  std::vector<int64_t> height_sections;
+  std::vector<std::string> origin_varnames;
+  int trainer_id;
+  bool merge_add;
+  bool is_sparse;
+  bool is_distributed;
+};
+
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
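
CommContext is a plain value type, so a standalone sketch can exercise it directly (all names, endpoints, and sections below are illustrative):

```cpp
#include <iostream>

#include "paddle/fluid/operators/distributed/communicator_common.h"

int main() {
  using paddle::operators::distributed::CommContext;

  // One sparse parameter split into two slices across two parameter servers.
  CommContext ctx("embedding.w_0",
                  {"embedding.w_0.block0", "embedding.w_0.block1"},
                  {"127.0.0.1:6174", "127.0.0.1:6175"},
                  /*sections=*/{4096, 4096},
                  /*origin_names=*/{"embedding.w_0"},
                  /*id=*/0, /*merge_add_=*/true, /*is_sparse_=*/true,
                  /*is_distributed_=*/false);
  std::cout << ctx.print();
  return 0;
}
```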

@@ -409,7 +409,8 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
 }
 
 VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
-                                               const std::string& dir,
+                                               const std::string& dirname,
+                                               const std::string& varname,
                                                int64_t time_out) {
   const auto ch = GetChannel(ep);
@@ -422,8 +423,8 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   s->Prepare(h, time_out);
 
   sendrecv::VariableMessage req;
-  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
-  req.set_out_varname(dir);
+  req.set_varname(varname);
+  req.set_out_varname(dirname);
 
   platform::RecordRPCEvent record_event(method);

@@ -222,7 +222,8 @@ class GRPCClient : public RPCClient {
       int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncCheckpointNotify(
-      const std::string& ep, const std::string& dir,
+      const std::string& ep, const std::string& dirname,
+      const std::string& varname,
       int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncDistributeNotify(

@@ -103,11 +103,13 @@ class RequestSend final : public RequestBase {
 
   void Process() override {
     std::string varname = GetReqName();
-    VLOG(4) << "RequestSend var_name:" << varname;
     auto scope = request_->GetMutableLocalScope();
     auto invar = request_->GetVar();
     int trainer_id = request_->GetTrainerId();
+
+    VLOG(4) << "RequestSend var_name:" << varname << " trainer: " << trainer_id;
+
     framework::Variable* outvar = nullptr;
     request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
     Finish(reply_, &responder_);
@@ -332,8 +334,9 @@ class RequestPrefetch final : public RequestBase {
     std::string out_var_name = request_->OutVarname();
     std::string table_name = request_->TableName();
     int trainer_id = request_->GetTrainerId();
+
     VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
-            << " out_var_name: " << out_var_name;
+            << " out_var_name: " << out_var_name << " trainer: " << trainer_id;
 
     auto scope = request_->GetMutableLocalScope();
     auto invar = scope->FindVar(in_var_name);

@@ -0,0 +1,26 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/distributed/large_scale_kv.h"
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+
+std::once_flag LargeScaleKV::init_flag_;
+std::shared_ptr<LargeScaleKV> LargeScaleKV::scale_kv_(nullptr);
+
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
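
The two statics above are the storage for a std::call_once singleton; the accessors live in large_scale_kv.h, which this view suppresses, so the shape below is an assumption sketched generically:

```cpp
#include <memory>
#include <mutex>

// Generic sketch of the pattern implied by init_flag_ and scale_kv_; the real
// LargeScaleKV interface is in the suppressed header, not shown here.
class SingletonSketch {
 public:
  static void Init() {
    std::call_once(init_flag_, [] { instance_.reset(new SingletonSketch()); });
  }
  static SingletonSketch *GetInstance() { return instance_.get(); }

 private:
  SingletonSketch() = default;
  static std::once_flag init_flag_;
  static std::shared_ptr<SingletonSketch> instance_;
};

std::once_flag SingletonSketch::init_flag_;
std::shared_ptr<SingletonSketch> SingletonSketch::instance_(nullptr);
```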

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@@ -31,7 +31,6 @@ void prefetchs(const std::vector<std::string>& id_var_names,
                const std::string& persistable_var_name, const bool backfill,
                const std::vector<std::string>& table_names,
                const std::vector<std::string>& endpoints,
-               const std::vector<int64_t>& height_sections,
                const framework::ExecutionContext& context,
                const framework::Scope& scope);
 
@@ -39,7 +38,6 @@ void prefetch(const std::string& id_name, const std::string& out_name,
               const std::string& persistable_var_name, const bool backfill,
               const std::vector<std::string>& table_names,
               const std::vector<std::string>& endpoints,
-              const std::vector<int64_t>& height_sections,
               const framework::ExecutionContext& context,
               const framework::Scope& scope);

File diff suppressed because it is too large.

@@ -18,7 +18,7 @@
 #include <vector>
 
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/distributed/rpc_common.h"
+#include "paddle/fluid/operators/distributed/communicator_common.h"
 
 namespace paddle {
 namespace operators {
@@ -26,7 +26,10 @@ namespace distributed {
 
 template <typename T>
 struct ParameterRecv {
-  void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
+  void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
+                  bool barrier);
+
+  void operator()(const CommContext &rpc_ctx, const framework::Scope &scope);
 };
 
 };  // namespace distributed

File diff suppressed because it is too large.

@@ -18,7 +18,7 @@
 #include <vector>
 
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/distributed/rpc_common.h"
+#include "paddle/fluid/operators/distributed/communicator_common.h"
 
 namespace paddle {
 namespace operators {
@@ -26,7 +26,7 @@ namespace distributed {
 
 template <typename T>
 struct ParameterSend {
-  void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope,
+  void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
                   bool sync, int multi_parts);
 };
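
Both functors are now keyed on CommContext rather than the old RpcContext. An illustrative caller for the send side (not from this PR; it issues real RPCs to the endpoints named in the context):

```cpp
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"

// ctx describes one parameter's slices and endpoints; scope must hold the
// variables named in ctx. Values and call site are illustrative only.
void PushParam(const paddle::operators::distributed::CommContext &ctx,
               const paddle::framework::Scope &scope) {
  paddle::operators::distributed::ParameterSend<float> send;
  send(ctx, scope, /*sync=*/true, /*multi_parts=*/1);
}
```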

@@ -65,6 +65,7 @@ constexpr int64_t kPrefetchTimeout = 60000;
 #define COMPLETE_MESSAGE "COMPLETE@RECV"
 #define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
 #define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
+#define STEP_COUNTER "@PS_STEP_COUNTER@"
 
 #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
 #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

File diff suppressed because it is too large.

@@ -19,6 +19,7 @@
 #include <functional>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
@@ -98,6 +99,21 @@ class RequestPrefetchHandler final : public RequestHandler {
                       const std::string& table_name = "") override;
 
  private:
+  std::unique_ptr<paddle::framework::OperatorBase> PullLargeScaleOp(
+      const std::string& table_name, const std::string& id_name,
+      const std::string& out_name) {
+    framework::OpDesc desc;
+    desc.SetType("lookup_sparse_table_read");
+    desc.SetInput("Ids", {id_name});
+    desc.SetOutput("Out", std::vector<std::string>({out_name}));
+    desc.SetAttr("tablename", {table_name});
+    desc.SetAttr("init", true);
+    desc.SetAttr("value_names", std::vector<std::string>({"Param"}));
+
+    auto op = paddle::framework::OpRegistry::CreateOp(desc);
+    return op;
+  }
+
   std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
       const std::string& table_name, const std::string& id_name,
       const std::string& out_name) {
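
PullLargeScaleOp builds a one-off `lookup_sparse_table_read` op and hands it back for the prefetch path to run. A self-contained sketch of that build-and-run pattern (variable and table names are illustrative, and the attribute types follow my reading of the helper above):

```cpp
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"

// Sketch: mirror the helper's OpDesc, then run it on a scope that already
// holds the "ids" variable; the op writes the fetched rows into "out".
void RunPrefetchRead(paddle::framework::Scope *scope) {
  paddle::framework::OpDesc desc;
  desc.SetType("lookup_sparse_table_read");
  desc.SetInput("Ids", {"ids"});
  desc.SetOutput("Out", std::vector<std::string>({"out"}));
  desc.SetAttr("tablename", std::vector<std::string>({"embedding.w_0"}));
  desc.SetAttr("init", true);
  desc.SetAttr("value_names", std::vector<std::string>({"Param"}));

  auto op = paddle::framework::OpRegistry::CreateOp(desc);
  op->Run(*scope, paddle::platform::CPUPlace());
}
```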
@@ -114,11 +130,9 @@ class RequestPrefetchHandler final : public RequestHandler {
 class RequestCheckpointHandler final : public RequestHandler {
  public:
-  explicit RequestCheckpointHandler(int distributed_mode,
-                                    int checkpoint_notify_id)
-      : RequestHandler(distributed_mode) {
-    this->checkpoint_notify_id = checkpoint_notify_id;
-  }
+  explicit RequestCheckpointHandler(int distributed_mode)
+      : RequestHandler(distributed_mode) {}
+
   virtual ~RequestCheckpointHandler() {}
 
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
@@ -126,14 +140,30 @@ class RequestCheckpointHandler final : public RequestHandler {
               const std::string& table_name = "") override;
 
  private:
-  int checkpoint_notify_id;
+  std::unique_ptr<paddle::framework::OperatorBase> BuildCheckpointOp(
+      const std::string& varname, const std::string& file_path) {
+    paddle::framework::proto::OpDesc op_desc;
+    op_desc.set_type("save");
+    BuildVar("X", {varname.data()}, op_desc.add_inputs());
+
+    auto attr = op_desc.mutable_attrs()->Add();
+    attr->set_name("file_path");
+    attr->set_type(paddle::framework::proto::AttrType::STRING);
+    attr->set_s(file_path);
+
+    auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+    return op;
+  }
 };
 
 class RequestNotifyHandler final : public RequestHandler {
  public:
-  explicit RequestNotifyHandler(int distributed_mode, int lr_decay_block_id)
+  explicit RequestNotifyHandler(int distributed_mode, int trainers)
       : RequestHandler(distributed_mode) {
-    this->lr_decay_block_id = lr_decay_block_id;
+    this->trainers = trainers;
+    for (int i = 0; i < trainers; i++) {
+      decay_counters[i] = 0;
+    }
   }
   virtual ~RequestNotifyHandler() {}
 
   bool Handle(const std::string& varname, framework::Scope* scope,
@@ -142,7 +172,8 @@ class RequestNotifyHandler final : public RequestHandler {
               const std::string& table_name = "") override;
 
  private:
-  int lr_decay_block_id;
+  int trainers;
+  std::unordered_map<int, int64_t> decay_counters;
 };
 
 }  // namespace distributed
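
RequestNotifyHandler replaces the single lr_decay_block_id with one step counter per trainer. Handle itself is in a suppressed file, so the aggregation below is an assumption about how decay_counters is meant to be used, not the PR's code:

```cpp
#include <cstdint>
#include <unordered_map>

// Hypothetical aggregation: each trainer reports its step count (see the new
// STEP_COUNTER message); one slot per trainer keeps a chatty worker from
// inflating the global step that drives server-side learning-rate decay.
int64_t AccumulateSteps(std::unordered_map<int, int64_t> *decay_counters,
                        int trainer_id, int64_t reported_steps) {
  (*decay_counters)[trainer_id] += reported_steps;

  int64_t global_step = 0;
  for (const auto &kv : *decay_counters) {
    global_step += kv.second;
  }
  return global_step;
}
```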

@@ -77,8 +77,8 @@ class RPCClient {
       int64_t time_out = FLAGS_rpc_deadline) = 0;
 
   virtual VarHandlePtr AsyncCheckpointNotify(
-      const std::string& ep, const std::string& dir,
-      int64_t time_out = FLAGS_rpc_deadline) = 0;
+      const std::string& ep, const std::string& dirname,
+      const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0;
 
   virtual VarHandlePtr AsyncDistributeNotify(
       const std::string& ep, const platform::DeviceContext& ctx,

Some files were not shown because too many files have changed in this diff.
