Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_paraexe_bcast

port
yi.wu 7 years ago
commit 88cb47bd86

@ -114,7 +114,12 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
FILE(WRITE ${dummyfile} "const char *dummy_cblas = \"${dummyfile}\";")
ADD_LIBRARY(cblas STATIC ${dummyfile})
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
IF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
TARGET_LINK_LIBRARIES(cblas dynload_mklml)
ELSE()
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
ENDIF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
IF(NOT ${CBLAS_FOUND})
ADD_DEPENDENCIES(cblas extern_openblas)

@ -195,6 +195,15 @@ function(cc_library TARGET_NAME)
list(REMOVE_ITEM cc_library_DEPS warpctc)
add_dependencies(${TARGET_NAME} warpctc)
endif()
# Only add a dependency on libmklml.so; do not link it directly.
if("${cc_library_DEPS};" MATCHES "mklml;")
list(REMOVE_ITEM cc_library_DEPS mklml)
if(NOT "${TARGET_NAME}" MATCHES "dynload_mklml")
list(APPEND cc_library_DEPS dynload_mklml)
endif()
add_dependencies(${TARGET_NAME} mklml)
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
endif()
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
endif()

@ -1,10 +1,10 @@
# Inference High-level APIs
This document describes the high-level inference APIs one can use to easily deploy a Paddle model for an application.
This document describes the high-level inference APIs; one can use them to quickly deploy a Paddle model for an application.
The APIs are described in `paddle_inference_api.h`, just one header file, and two libaries `libpaddle_fluid.so` and `libpaddle_fluid_api.so` are needed.
The APIs are described in `paddle_inference_api.h`, just one header file; two libraries, `libpaddle_fluid.so` and `libpaddle_fluid_api.so`, are needed for a deployment.
## PaddleTensor
We provide the `PaddleTensor` data structure is to give a general tensor interface.
We provide the `PaddleTensor` data structure to give a general tensor interface.
The definition is
@ -17,18 +17,19 @@ struct PaddleTensor {
};
```
The data is stored in a continuous memory `PaddleBuf`, and tensor's data type is specified by a `PaddleDType`.
The `name` field is used to specify the name of input variable,
that is important when there are multiple inputs and need to distiuish which variable to set.
The data is stored in a contiguous memory block, `PaddleBuf`, and a `PaddleDType` specifies the tensor's data type.
The `name` field is used to specify the name of an input variable,
which is important when there are multiple inputs and one needs to distinguish which variable to set.
## engine
The inference APIs has two different underlying implementation, currently there are two valid engines:
The inference APIs have two different underlying engines:
- the native engine, which consists of the native operators and framework,
- the Anakin engine, which is a Anakin library embeded.
- the Anakin engine, which has an Anakin library embedded.
The native engine takes a native Paddle model as input, and supports any model trained by Paddle;
but the Anakin engine can only take the Anakin model as input(user need to manully transform the format first) and currently not all Paddle models are supported.
the Anakin engine is faster for some models,
but it can only take an Anakin model as input (the user needs to transform the format first manually), and currently not all Paddle models are supported.
```c++
enum class PaddleEngineKind {
@ -38,10 +39,10 @@ enum class PaddleEngineKind {
```
## PaddlePredictor and how to create one
The main interface is `PaddlePredictor`, there are following methods
The main interface is `PaddlePredictor`; it has the following methods
- `bool Run(const std::vector<PaddleTensor>& inputs, std::vector<PaddleTensor>* output_data)`
- take inputs and output `output_data`
- takes the inputs and writes the results to `output_data`.
- `Clone` to clone a predictor from an existing one, with the model parameters shared.
There is a factory method to help create a predictor, and the user takes ownership of this object.
@ -51,9 +52,9 @@ template <typename ConfigT, PaddleEngineKind engine = PaddleEngineKind::kNative>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
```
By specifying the engine kind and config, one can get an specific implementation.
By specifying the engine kind and config, one can get a specific implementation.
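For illustration, here is a minimal end-to-end sketch of the API described above; the `NativeConfig` type, its `model_dir` field, and the file layout are assumptions based on the header, not verbatim from this document.
```c++
#include <vector>
#include "paddle_inference_api.h"

int main() {
  // NativeConfig and its fields are assumed; check paddle_inference_api.h
  // for the actual config type of the engine you target.
  paddle::NativeConfig config;
  config.model_dir = "./my_model";  // hypothetical model directory

  auto predictor = paddle::CreatePaddlePredictor<
      paddle::NativeConfig, paddle::PaddleEngineKind::kNative>(config);

  std::vector<paddle::PaddleTensor> inputs, outputs;
  // ... fill `inputs`, setting each tensor's `name` field ...
  if (!predictor->Run(inputs, &outputs)) {
    return 1;  // inference failed
  }
  // ... consume `outputs` ...
  return 0;
}
```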
## Reference
- [paddle_inference_api.h](./paddle_inference_api.h)
- [demos](./demo)
- [some demos](./demo)

@ -207,53 +207,56 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices_.emplace(var_name, op_dev_id);
}
}
if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(op->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) {
try {
auto backward_vars =
boost::get<std::vector<std::string>>(op->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
InsertAllReduceOp(&result, g_name);
}
break;
default:
LOG(FATAL) << "Unknown reduce strategy ";
break;
} else {
// This op runs on all devices, and its output may contain parameters'
// gradients.
CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(op->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) {
try {
auto backward_vars =
boost::get<std::vector<std::string>>(op->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
InsertAllReduceOp(&result, g_name);
}
break;
default:
LOG(FATAL) << "Unknown reduce strategy ";
break;
}
}
} catch (boost::bad_get e) {
}
} catch (boost::bad_get e) {
}
}
}
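For context on the loop above: the `OpRoleVarAttrName` attribute is a flat string vector of interleaved (parameter, gradient) name pairs, which is why `PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0)` holds and the loop steps by two. A hedged sketch with hypothetical variable names:
```c++
#include <string>
#include <vector>

// Hedged sketch: walk the flattened (parameter, gradient) pairs the way
// MultiDevSSAGraphBuilder does; a hypothetical input would be
// {"fc_0.w", "fc_0.w@GRAD", "fc_0.b", "fc_0.b@GRAD"}.
void WalkParamGradPairs(const std::vector<std::string> &backward_vars) {
  for (size_t i = 0; i < backward_vars.size(); i += 2) {
    const std::string &p_name = backward_vars[i];      // parameter name
    const std::string &g_name = backward_vars[i + 1];  // its gradient name
    // decide a reduce/broadcast placement for g_name, keyed by p_name ...
    (void)p_name;
    (void)g_name;
  }
}
```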

@ -19,8 +19,8 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_MKLML
#include <mkl_service.h>
#include <omp.h>
#endif
@ -164,7 +164,7 @@ TEST(inference, nlp) {
// only use 1 thread per std::thread
omp_set_dynamic(0);
omp_set_num_threads(1);
mkl_set_num_threads(1);
paddle::operators::math::SetNumThreads(1);
#endif
double start_ms = 0, stop_ms = 0;

@ -195,7 +195,7 @@ if(WITH_DISTRIBUTE)
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
foreach(dist_op "prefetch_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
foreach(dist_op "prefetch_op" "checkpoint_notify_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endforeach()
@ -216,7 +216,7 @@ if(WITH_DISTRIBUTE)
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
else()
set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()
op_library(cross_entropy_op DEPS cross_entropy)

@ -0,0 +1,88 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
class CheckpointNotifyOp : public framework::OperatorBase {
public:
CheckpointNotifyOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) {
auto lookup_table_save_dir =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir);
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i];
}
rpc_client->Wait();
}
};
class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>(
"dir", "(string, default '') indicate the folder checkpoint will use");
AddAttr<std::string>("lookup_table",
"(string, default '') the lookup table name");
AddComment(R"DOC(
CheckpointNotify operator
This operator will send the lookup table and its checkpoint directory to the listen_and_serv op at
the parameter server.
)DOC");
}
};
class CheckpointNotifyOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp,
paddle::framework::EmptyGradOpMaker,
ops::CheckpointNotifyOpMaker,
ops::CheckpointNotifyOpShapeInference);
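To illustrate the directory layout produced by the `string::Sprintf` call in `RunImpl` above, here is a hedged sketch with hypothetical endpoints and table name; each parameter-server shard gets its own numbered sub-directory:
```c++
#include <string>
#include <vector>
#include "paddle/fluid/string/printf.h"

// Hedged sketch of the per-endpoint save paths; all values are hypothetical.
std::vector<std::string> CheckpointDirs() {
  std::vector<std::string> epmap = {"127.0.0.1:6164", "127.0.0.1:6165"};
  std::string dir = "/tmp/ckpt";
  std::string lookup_table_name = "embedding_table";
  std::vector<std::string> dirs;
  for (size_t i = 0; i < epmap.size(); i++) {
    // yields "/tmp/ckpt/embedding_table_0", "/tmp/ckpt/embedding_table_1"
    dirs.push_back(
        paddle::string::Sprintf("%s/%s_%d", dir, lookup_table_name, i));
  }
  return dirs;
}
```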

@ -55,26 +55,24 @@ class BRPCClient : public RPCClient {
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override;

@ -239,6 +239,23 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
req_count_++;
}
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
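Note the wire contract above: the notify RPC reuses `VariableMessage`, with `varname` carrying the `CHECKPOINT_SAVE_MESSAGE` sentinel and `out_varname` carrying the checkpoint directory; the server side reads them back via `Varname()` and `OutVarname()`. A hedged sketch of the caller side (endpoint, directory, and include paths are assumptions):
```c++
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"

// Hedged sketch: issue a checkpoint notify from trainer-side code.
void NotifyCheckpoint() {
  auto *client =
      paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>();
  client->AsyncCheckpointNotify("127.0.0.1:6164",
                                "/tmp/ckpt/embedding_table_0");
  client->Wait();  // blocks until req_count_ drains back to zero
}
```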

@ -171,6 +171,20 @@ class FetchBarrierProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class CheckpointNotifyProcessor : public BaseProcessor {
public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~CheckpointNotifyProcessor() {}
virtual void Process() {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class GRPCClient : public RPCClient {
public:
GRPCClient() {}
@ -178,24 +192,27 @@ class GRPCClient : public RPCClient {
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_grpc_deadline) override;
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_grpc_deadline) override;
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_grpc_deadline) override;
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_grpc_deadline) override;
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_grpc_deadline) override;
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override;
@ -211,7 +228,7 @@ class GRPCClient : public RPCClient {
void Proceed();
void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_grpc_deadline);
int64_t time_out = FLAGS_rpc_deadline);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);

@ -200,6 +200,45 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* local_scope_;
};
class RequestCheckpointNotify final : public RequestBase {
public:
explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx()));
int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestCheckpointNotify() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
auto scope = request_->GetMutableLocalScope();
std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->OutVarname();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
checkpoint_dir);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<VariableResponse> request_;
sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
@ -237,6 +276,7 @@ void AsyncGRPCServer::StartServer() {
reqs.reserve(kRequestBufSize);
for (int i = 0; i < kRequestBufSize; i++) {
VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
TryToRegisterNewOne(rpc_name, i);
}
@ -289,8 +329,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
return;
}
VLOG(4) << "register send rpc_name:" << rpc_name
<< ", handler:" << rpc_call_map_[kRequestSend];
VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
<< " REQ ID: " << req_id;
auto& reqs = rpc_reqs_[rpc_name];
auto& handler = rpc_call_map_[rpc_name];
@ -303,6 +343,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not supported rpc");
}
@ -321,7 +363,7 @@ void AsyncGRPCServer::HandleRequest(
while (true) {
VLOG(4) << "HandleRequest " << rpc_name << " wait next";
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!";
VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!";
break;
}

@ -80,10 +80,11 @@ enum class GrpcMethod {
kSendVariable,
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
static_cast<int>(GrpcMethod::kCheckpointNotify) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
@ -93,6 +94,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify";
}
// Shouldn't be reached.

@ -36,12 +36,16 @@ namespace distributed {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
class RPCServer;
class RequestHandler {
@ -69,6 +73,11 @@ class RequestHandler {
prefetch_var_name_to_prepared_ctx_ = g;
}
void SetCheckpointNotifyPreparedCtx(
std::shared_ptr<framework::ExecutorPrepareContext> g) {
checkpoint_prepared_ctx_ = g;
}
// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
@ -115,6 +124,8 @@ class RequestHandler {
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// used for checkpoint notify
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
// Used for async.
std::unordered_map<std::string,

@ -22,11 +22,16 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
namespace distributed {
// Define LOOKUP_TABLE_PATH for checkpoint notify: it is used to save lookup
// table variables to the specified directory.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
@ -119,6 +124,24 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true;
}
bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const std::string& out_var_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
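The handoff above goes through a plain scope variable: the handler stores the target path in the `LOOKUP_TABLE_PATH` variable (`"kLookupTablePath"`), and ops inside the prepared checkpoint block are expected to read it back. A hedged sketch of the reading side, which is not shown in this diff:
```c++
#include <string>
#include "paddle/fluid/framework/scope.h"

// Hedged sketch: how an op in the checkpoint block could fetch the
// directory the handler stored; the variable name matches the
// LOOKUP_TABLE_PATH constant defined above.
std::string ReadLookupTablePath(const paddle::framework::Scope &scope) {
  return scope.FindVar("kLookupTablePath")->Get<std::string>();
}
```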

@ -66,6 +66,21 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override;
};
class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id)
: RequestHandler(sync_mode) {
this->checkpoint_notify_id = checkpoint_notify_id;
}
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
private:
int checkpoint_notify_id;
};
} // namespace distributed
} // namespace operators
} // namespace paddle

@ -16,7 +16,7 @@
#include "gflags/gflags.h"
// default to 3 min to avoid temporary network failures.
DEFINE_int32(grpc_deadline, 180000, "deadline timeouts for grpc");
DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
namespace paddle {
namespace operators {
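Since the deadline is now an ordinary gflag, it can be overridden per process without recompiling; a minimal hedged sketch (the helper below is hypothetical):
```c++
#include "gflags/gflags.h"

DECLARE_int32(rpc_deadline);

// Hedged sketch: tighten the RPC deadline before issuing any requests,
// e.g. in a test; 5000 ms instead of the 3 min default.
void UseShortRpcDeadline() { FLAGS_rpc_deadline = 5000; }
```
The same value can also be set on the command line, e.g. `--rpc_deadline=5000`.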

@ -21,7 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
DECLARE_int32(grpc_deadline);
DECLARE_int32(rpc_deadline);
namespace paddle {
namespace operators {
@ -35,26 +35,30 @@ class RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_grpc_deadline) = 0;
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_grpc_deadline) = 0;
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_grpc_deadline) = 0;
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
// SendComplete tells all the servers that the current trainer has no more
// data to train, so that the pserver can reduce its barrier count, and continue

@ -25,6 +25,8 @@ service SendRecvService {
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
}
// VariableMessage is serialized paddle variable message.

@ -99,7 +99,8 @@ static int64_t GetTimestamp() {
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope,
const std::vector<int> &prefetch_block_id_list) const {
const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id) const {
size_t num_blocks = program->Size();
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
@ -163,7 +164,8 @@ void ListenAndServOp::RunSyncLoop(
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program) const {
framework::ProgramDesc *program,
framework::Scope *recv_scope) const {
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
@ -190,6 +192,10 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
block_list.push_back(blkid);
}
auto optimize_prepared = executor->Prepare(*program, block_list);
// execute global block if needed
if (block_list[0] == 1 && id_to_grad.count(1) == 0) {
executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
}
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx;
@ -203,7 +209,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
while (true) {
if (rpc_service_->IsExit()) {
LOG(INFO) << "get exit!rpc_processor break!";
VLOG(4) << "get exit!rpc_processor break!";
break;
}
@ -218,6 +224,7 @@ static void FillRequestCtx(
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
distributed::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
@ -225,6 +232,7 @@ static void FillRequestCtx(
h->SetProgram(program);
h->SetPrefetchPreparedCtx(prefetch_ctx);
h->SetRPCServer(rpc_server);
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
@ -240,9 +248,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint;
VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint
<< ", checkpoint_block_id: " << checkpoint_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
@ -250,6 +260,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get());
@ -257,6 +269,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get());
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
@ -265,6 +279,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
if (checkpoint_block_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_block_id);
// see: https://stackoverflow.com/a/14856553
ckpt_pre_context = std::move(ctx);
}
// prepare for prefetch
std::vector<int> prefetch_block_id_list;
std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
@ -295,13 +316,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program,
&prefetch_var_name_to_prepared_ctx, rpc_service_.get());
auto f =
std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
&executor, program, &prefetch_var_name_to_prepared_ctx,
ckpt_pre_context, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get());
// start the server listening after all members are initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
@ -315,9 +338,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use.
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
checkpoint_block_id);
} else {
RunAsyncLoop(&executor, program);
RunAsyncLoop(&executor, program, &recv_scope);
}
}
@ -347,6 +371,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
AddAttr<int>(kCheckpointBlockId,
"BolckID to run save checkpoint on pserer.")
.SetDefault(-1);
}
};

@ -32,6 +32,7 @@ namespace operators {
constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
constexpr char kCheckpointBlockId[] = "checkpint_block_id";
void RunServer(std::shared_ptr<distributed::RPCServer> service);
@ -47,10 +48,12 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
const std::vector<int>& prefetch_block_id_list) const;
const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program) const;
framework::ProgramDesc* program,
framework::Scope* recv_scope) const;
void SavePort() const;
@ -67,6 +70,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_checkpoint_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
};

@ -34,6 +34,8 @@ class LoadOp : public framework::OperatorBase {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
@ -44,9 +46,25 @@ class LoadOp : public framework::OperatorBase {
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_name);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
out_var_name);
}
}
DeserializeFromStream(fin, tensor, *dev_ctx);
void LoadLodTensor(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = framework::ToDataType(tensor->type());
@ -63,18 +81,27 @@ class LoadOp : public framework::OperatorBase {
&fp16_tensor);
// reset output tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
}
};
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "The tensor need to be loaded");
AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded");
AddAttr<bool>(
"load_as_fp16",
"If true, the tensor will be first loaded and then "
@ -85,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
R"(Variable will be loaded from "file_path")")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddComment("Load operator will load a tensor variable from disk file.");
AddComment(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file.");
}
};
} // namespace operators
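Mirroring the FP16 load test at the end of this diff, a hedged sketch of driving the new `SelectedRows` path (variable and file names are hypothetical):
```c++
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/place.h"

// Hedged sketch: running a load op on a SelectedRows output variable,
// which dispatches to LoadSelectedRows above.
void LoadSelectedRowsExample() {
  paddle::framework::Scope scope;
  paddle::platform::CPUPlace place;
  scope.Var("out_var")->GetMutable<paddle::framework::SelectedRows>();

  paddle::framework::AttributeMap attrs;
  attrs.insert({"file_path", std::string("selected_rows.model")});

  auto load_op = paddle::framework::OpRegistry::CreateOp(
      "load", {}, {{"Out", {"out_var"}}}, attrs);
  load_op->Run(scope, place);
}
```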

@ -18,10 +18,7 @@
#include "paddle/fluid/framework/tensor.h"
#ifdef PADDLE_WITH_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_service.h>
#include <mkl_vml_functions.h>
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef PADDLE_USE_OPENBLAS
@ -55,7 +52,7 @@ static void SetNumThreads(int num_threads) {
openblas_set_num_threads(real_num_threads);
#elif defined(PADDLE_WITH_MKLML)
int real_num_threads = num_threads > 1 ? num_threads : 1;
mkl_set_num_threads(real_num_threads);
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
#else
PADDLE_ENFORCE(false, "To be implemented.");
#endif
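With this change callers no longer touch MKL directly; the portable entry point dispatches to OpenBLAS or the dynamically loaded MKL symbol depending on the build, as the NLP inference test earlier in this diff now does. A hedged usage sketch:
```c++
#include "paddle/fluid/operators/math/blas.h"

// Hedged sketch: pin BLAS to one thread through the portable setter,
// which routes to openblas_set_num_threads or
// platform::dynload::MKL_Set_Num_Threads per build flavor.
void PinSingleThreadedBlas() { paddle::operators::math::SetNumThreads(1); }
```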

@ -22,61 +22,109 @@ namespace math {
template <typename T>
struct CBlas;
#ifdef PADDLE_WITH_MKLML
template <>
struct CBlas<float> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
cblas_sgemm(args...);
platform::dynload::cblas_sgemm(args...);
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
cblas_saxpy(args...);
platform::dynload::cblas_saxpy(args...);
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
platform::dynload::cblas_scopy(args...);
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
platform::dynload::cblas_sgemv(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_sgemm_batch(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
static void VADD(ARGS... args) {
vsAdd(args...);
platform::dynload::vsAdd(args...);
}
};
template <>
struct CBlas<double> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
platform::dynload::cblas_dgemm(args...);
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
platform::dynload::cblas_daxpy(args...);
}
#endif
template <typename... ARGS>
static void VCOPY(ARGS... args) {
cblas_scopy(args...);
platform::dynload::cblas_dcopy(args...);
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
cblas_sgemv(args...);
platform::dynload::cblas_dgemv(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
cblas_sgemm_batch(args...);
platform::dynload::cblas_dgemm_batch(args...);
}
template <typename... ARGS>
static void VADD(ARGS... args) {
platform::dynload::vdAdd(args...);
}
#endif
};
#else
template <>
struct CBlas<double> {
struct CBlas<float> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
cblas_dgemm(args...);
cblas_sgemm(args...);
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
cblas_daxpy(args...);
cblas_saxpy(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
static void VADD(ARGS... args) {
vdAdd(args...);
static void VCOPY(ARGS... args) {
cblas_scopy(args...);
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
cblas_sgemv(args...);
}
};
template <>
struct CBlas<double> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
cblas_dgemm(args...);
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
cblas_daxpy(args...);
}
#endif
template <typename... ARGS>
static void VCOPY(ARGS... args) {
@ -87,15 +135,8 @@ struct CBlas<double> {
static void GEMV(ARGS... args) {
cblas_dgemv(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
cblas_dgemm_batch(args...);
}
#endif
};
#endif
template <>
struct CBlas<platform::float16> {
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }

@ -14,9 +14,7 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef PADDLE_USE_OPENBLAS

@ -139,6 +139,7 @@ TEST(LoadFP16Op, CPU) {
save_op->Run(scope, place);
auto load_var = scope.Var("out_var");
load_var->GetMutable<paddle::framework::LoDTensor>();
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load", {}, {{"Out", {"out_var"}}}, attrs);
load_op->Run(scope, place);
