Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_cudnn_lstm

add_cudnn_lstm
phlrain 6 years ago
commit 6ce4250172

@ -862,7 +862,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id =
GetVarDeviceID(*result, output->Name(), *sharded_var_device);
PADDLE_ENFORCE_NE(outvar_dev_id, -1);
PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
}
p = places_[outvar_dev_id];
ir::Node *new_node = nullptr;

@ -37,7 +37,13 @@ if (WITH_GPU)
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS})
SET(OP_PREFETCH_DEPS "")
if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cuDNN 7 or above
if (WITH_GPU AND NOT WIN32)

@ -9,36 +9,37 @@ else()
endif()
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
return()
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
else()
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
cc_test(brpc_server_test SRCS rpc_server_test.cc
DEPS ${brpc_test_depends} SERIAL)
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
DEPS ${brpc_test_depends} SERIAL)
endif()

@ -171,11 +171,13 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
@ -186,11 +188,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
s, method, h, this] {
s, method, h, table_name_val, this] {
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
0, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

@ -194,6 +194,7 @@ class GRPCClient : public RPCClient {
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendBatchBarrier(

@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name,
const int trainer_id) {
const int trainer_id,
const std::string& table_name) {
platform::RecordRPCEvent record_event("serial", &ctx);
VarMsg request;
TensorPayload* payload = nullptr;
@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (!table_name.empty()) {
request.set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
payload = new TensorPayload(GetTensorPayload(var, ctx, &request));

@ -40,7 +40,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string(),
const int trainer_id = 0);
const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,

@ -130,7 +130,8 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg;
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg,
"outvar", 0, "table_name");
EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize

@ -183,6 +183,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process...
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
@ -193,7 +194,7 @@ class RequestPrefetch final : public RequestBase {
framework::Variable* outvar = scope->Var(out_var_name);
request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name);
out_var_name, table_name);
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);

@ -301,6 +301,20 @@ int GRPCVariableResponse::Parse(Source* source) {
meta_.set_trainer_id(trainer_id);
break;
}
case sendrecv::VariableMessage::kTableNameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_table_name(temp);
break;
}
default: {
// Unknown tag, return unknown error.
return -1;

File diff suppressed because it is too large

@ -0,0 +1,34 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
namespace distributed {
void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap,
const std::vector<int>& height_sections,
const framework::ExecutionContext& context);
}  // namespace distributed
}  // namespace operators
}  // namespace paddle

@ -191,7 +191,8 @@ class RequestHandler {
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") = 0;
const std::string& out_var_name = "",
const std::string& table_name = "") = 0;
protected:
const bool sync_mode_;

@ -37,7 +37,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestSendHandler:" << varname;
// Sync
@ -77,8 +78,10 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetHandler:" << varname;
if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
@ -113,14 +116,22 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
if (table_name.empty()) {
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
} else {
(*outvar)->GetMutable<framework::LoDTensor>();
auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place;
lookup_table_op->Run(*scope, cpu_place);
}
return true;
}
@ -129,7 +140,8 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");

@ -24,6 +24,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
@ -43,8 +44,8 @@ class RequestSendHandler final : public RequestHandler {
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
bool enable_dc_asgd_;
@ -59,21 +60,44 @@ class RequestGetHandler final : public RequestHandler {
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
bool enable_dc_asgd_;
};
static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
class RequestPrefetchHandler final : public RequestHandler {
public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("lookup_table");
BuildVar("W", {table_name.data()}, op_desc.add_inputs());
BuildVar("Ids", {id_name.data()}, op_desc.add_inputs());
BuildVar("Out", {out_name.data()}, op_desc.add_outputs());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
};
class RequestCheckpointHandler final : public RequestHandler {
@ -85,8 +109,8 @@ class RequestCheckpointHandler final : public RequestHandler {
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
int checkpoint_notify_id;

@ -48,7 +48,7 @@ class RPCClient {
virtual VarHandlePtr AsyncPrefetchVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& in_var_name,
const std::string& out_var_name,
const std::string& out_var_name, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendBatchBarrier(

@ -80,6 +80,7 @@ message VariableMessage {
// when profile switches from 1 to 2.
int64 profile = 11;
int64 trainer_id = 12;
string table_name = 13;
}
message VoidMessage {}

@ -85,6 +85,7 @@ class VariableResponse {
inline framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); }
inline std::string TableName() const { return meta_.table_name(); }
// should call parse first.
framework::Variable* GetVar() {

@ -87,6 +87,25 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"If the grad op reuse the input's variable.")
.SetDefault(false);
// for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int>({}));
AddAttr<std::vector<std::string>>(
"epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"table_names",
"(string vector, the splited table names that will be fetched from "
"parameter server)"
"in the order of input variables for mapping")
.SetDefault({});
AddComment(R"DOC(
Lookup Table Operator.

@ -78,27 +78,47 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
if (padding_idx == -1)
LookupTable<
T, 128, 8, 8,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<
T, 128, 8, 8,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
auto id_name = context.Inputs("Ids").front();
auto out_name = context.Outputs("Out").front();
// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections = context.Attr<std::vector<int>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from the remote
// parameter server
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap,
height_sections, context);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
"parameter prefetch!");
#endif
} else {
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
if (padding_idx == -1)
LookupTable<T, 128, 8, 8, false><<<
grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<T, 128, 8, 8, true><<<
grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
}
}
};
@ -109,6 +129,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {

@ -23,6 +23,10 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle {
namespace operators {
@ -41,44 +45,66 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
auto id_name = context.Inputs("Ids").front();
auto out_name = context.Outputs("Out").front();
// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections = context.Attr<std::vector<int>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from the remote
// parameter server
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap,
height_sections, context);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
"parameter prefetch!");
#endif
} else {
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
}
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
}
}
}
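As a usage illustration (not part of this commit), here is a minimal sketch of exercising the new prefetch attributes on a raw lookup_table op, mirroring the unit test added later in this diff. The endpoint, table name, and shapes are assumptions, and a parameter server must already be serving a 10x8 table named 'table' at that endpoint for the run to succeed:

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.op import Operator

# Hypothetical setup: a local W placeholder (the actual rows come from the
# parameter server when remote_prefetch is set) and a few int64 ids.
scope = fluid.core.Scope()
place = fluid.core.CPUPlace()
scope.var('W').get_tensor().set(np.ones((10, 8)).astype('float32'), place)
scope.var('Ids').get_tensor().set(np.array([[1], [2], [5]]).astype('int64'), place)
out = scope.var('Out').get_tensor()

op = Operator(
    "lookup_table",
    W='W',
    Ids='Ids',
    Out='Out',
    remote_prefetch=True,       # fetch rows from parameter servers instead of local W
    epmap=['127.0.0.1:6174'],   # assumed pserver endpoint, one per table shard
    table_names=['table'],      # name of the table variable on each pserver
    height_sections=[10])       # row count of each shard of the split table
op.run(scope, place)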

@ -327,6 +327,11 @@ def embedding(input,
"""
helper = LayerHelper('embedding', **locals())
remote_prefetch = False
if os.environ.get('PADDLE_ENABLE_REMOTE_PREFETCH'):
remote_prefetch = True
if remote_prefetch:
assert is_sparse is True and is_distributed is False
w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
tmp = helper.create_variable_for_type_inference(dtype)
@ -340,6 +345,7 @@ def embedding(input,
attrs={
'is_sparse': is_sparse,
'is_distributed': is_distributed,
'remote_prefetch': remote_prefetch,
'padding_idx': padding_idx
})
return tmp
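For reference, a minimal usage sketch of the new switch (sizes are illustrative): the environment variable must be set before the layer is built, and remote prefetch requires is_sparse=True together with is_distributed=False:

import os
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"  # opt in before building the layer

import paddle.fluid as fluid

ids = fluid.layers.data(name='ids', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.embedding(
    input=ids,
    size=[100000, 64],     # illustrative vocabulary size and embedding width
    is_sparse=True,        # required when remote prefetch is enabled
    is_distributed=False)  # required when remote prefetch is enabled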

@ -16,11 +16,13 @@ from __future__ import print_function
import paddle
import paddle.fluid as fluid
import os
import dist_ctr_reader
from test_dist_base import TestDistRunnerBase, runtime_main
IS_SPARSE = True
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
# Fix seed for test
fluid.default_startup_program().random_seed = 1

@ -782,5 +782,46 @@ class TestNCCL2Transpile(TranspilerTest):
pass
# test for remote prefetch
class TestRemoteLookupTable(TestDistLookupTableBase):
def net_conf(self):
import os
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
self.network_with_table(is_sparse=True, is_distributed=False)
def transpiler_test_impl(self):
pserver1, startup1 = self.get_pserver(self.pserver1_ep)
self.assertEqual(len(pserver1.blocks), 4)
# 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
["sum", "scale", "adam", "scale", "scale"])
# 2 optimize for table adam
# NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
["sum", "scale", "adam", "scale", "scale"])
# 3 optimize for table 2 adam
# NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["sum", "scale", "adam", "scale", "scale"])
trainer, _ = self.get_trainer()
self.assertEqual(len(trainer.blocks), 1)
ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
'recv', 'fetch_barrier'
]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
if __name__ == "__main__":
unittest.main()
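For context, a sketch of producing the trainer and pserver programs whose op lists the test above inspects; the endpoints, sizes, and the tiny network are illustrative, and the DistributeTranspiler calls are the standard Fluid API of this era rather than anything added by this commit:

import os
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"

import paddle.fluid as fluid

# A small network with a sparse, non-distributed embedding so the transpiled
# trainer keeps lookup_table with the remote prefetch attributes.
ids = fluid.layers.data(name='ids', shape=[1], dtype='int64', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='float32')
emb = fluid.layers.embedding(input=ids, size=[100, 8], is_sparse=True)
pool = fluid.layers.sequence_pool(input=emb, pool_type='sum')
pred = fluid.layers.fc(input=pool, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=label))
fluid.optimizer.Adam(learning_rate=0.001).minimize(loss)

t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers='127.0.0.1:6174,127.0.0.1:6175', trainers=1)
pserver_prog = t.get_pserver_program('127.0.0.1:6174')  # listen_and_serv + optimize blocks
trainer_prog = t.get_trainer_program()                  # forward/backward plus send/recv ops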

@ -0,0 +1,203 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import signal
import time
import unittest
from multiprocessing import Process
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard
def run_pserver(pserver_id, use_cuda, sync_mode):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
# create table parameter in scope
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
# create and initialize Param Variable
param = scope.var('table').get_tensor()
param_array = np.ones((10, 8)).astype("float32")
for i in range(len(param_array)):
param_array[i] *= param_array[i] * i + pserver_id * 10
param.set(param_array, place)
optimize_block = program._create_block(program.global_block().idx)
program.global_block().append_op(
type="listen_and_serv",
inputs={'X': []},
outputs={},
attrs={
"optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0',
"Fanin": 1,
"sync_mode": True,
"grad_to_block_id": []
})
exe = fluid.Executor(place)
exe.run(program)
class TestListenAndServOp(unittest.TestCase):
def setUp(self):
self.ps_timeout = 5
def _start_pserver(self, pserver_id, use_cuda, sync_mode, pserver_func):
p = Process(target=pserver_func, args=(pserver_id, use_cuda, sync_mode))
p.daemon = True
p.start()
return p
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
sleep_time = 0.5
while True:
assert start_left_time >= 0, "wait ps ready failed"
time.sleep(sleep_time)
try:
# the listen_and_serv_op writes a file containing its listen port under /tmp
# once it is ready to process RPC calls.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
start_left_time -= sleep_time
def _get_pserver_port(self, pid):
with open("/tmp/paddle.%d.port" % pid, 'r') as f:
port = int(f.read().strip())
return port
def _run_lookup_table_op_one_pserver(self, place, port):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
# create and initialize Param Variable
param = scope.var('W').get_tensor()
param_array = np.full((10, 8), 1.0).astype("float32")
param.set(param_array, place)
ids = scope.var('Ids').get_tensor()
ids_array = np.array([[1], [2], [5]]).astype("int64")
ids.set(ids_array, place)
ids_lod = [[0, 1, 2, 3]]
ids.set_lod(ids_lod)
out = scope.var('Out').get_tensor()
emaps = ['127.0.0.1:' + str(port)]
table_names = ['table']
height_sections = [10]
# create and run the lookup_table operator
lookup_table_op = Operator(
"lookup_table",
W='W',
Ids='Ids',
Out='Out',
remote_prefetch=True,
epmap=emaps,
table_names=table_names,
height_sections=height_sections)
lookup_table_op.run(scope, place)
# get and compare result
result_array = np.array(out)
self.assertEqual(out.lod(), ids_lod)
self.assertEqual(list(result_array.shape), [len(ids_array), 8])
for i in range(len(ids_array)):
id = ids_array[i][0]
self.assertTrue((result_array[i] == id).all())
def _run_lookup_table_op_two_pserver(self, place, port0, port1):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
# create and initialize Param Variable
param = scope.var('W').get_tensor()
param_array = np.full((10, 8), 1.0).astype("float32")
param.set(param_array, place)
ids = scope.var('Ids').get_tensor()
ids_array = np.array([[1], [2], [11], [13]]).astype("int64")
ids.set(ids_array, place)
ids_lod = [[0, 2, 3, 4]]
ids.set_lod(ids_lod)
out = scope.var('Out').get_tensor()
emaps = ['127.0.0.1:' + str(port0), '127.0.0.1:' + str(port1)]
table_names = ['table', 'table']
height_sections = [10, 20]
# create and run the lookup_table operator
lookup_table_op = Operator(
"lookup_table",
W='W',
Ids='Ids',
Out='Out',
remote_prefetch=True,
epmap=emaps,
table_names=table_names,
height_sections=height_sections)
lookup_table_op.run(scope, place)
# get and compare result
result_array = np.array(out)
self.assertEqual(out.lod(), ids_lod)
self.assertEqual(list(result_array.shape), [len(ids_array), 8])
for i in range(len(ids_array)):
id = ids_array[i][0]
self.assertTrue((result_array[i] == id).all())
def test_lookup_remote_table(self):
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
# run pserver on CPU in sync mode
p0 = self._start_pserver(0, False, True, run_pserver)
self._wait_ps_ready(p0.pid)
port0 = self._get_pserver_port(p0.pid)
p1 = self._start_pserver(1, False, True, run_pserver)
self._wait_ps_ready(p1.pid)
port1 = self._get_pserver_port(p1.pid)
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self._run_lookup_table_op_one_pserver(place, port0)
self._run_lookup_table_op_two_pserver(place, port0, port1)
# send SIGINT to the pservers
os.kill(p0.pid, signal.SIGINT)
p0.join()
os.kill(p1.pid, signal.SIGINT)
p1.join()
if __name__ == '__main__':
unittest.main()

Some files were not shown because too many files have changed in this diff
