fix testing

fix_gru_py
typhoonzero 7 years ago
parent 0598a4b366
commit 82c61dbde3

@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
// stub context // stub context
SendProcessor* s = new SendProcessor(ch); SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out); s->Prepare(var_h, time_out);
s->response_call_back_ = NULL; s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);

@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor { class BaseProcessor {
public: public:
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { context_ = NULL; } explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
context_ = nullptr;
}
virtual ~BaseProcessor() {} virtual ~BaseProcessor() {}
@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
::grpc::GenericStub stub_g_; ::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_; ::grpc::ByteBuffer reply_;
RequestSendCallBack response_call_back_ = NULL; RequestSendCallBack response_call_back_ = nullptr;
}; };
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>

@ -261,8 +261,8 @@ void AsyncGRPCServer::ShutdownQueue() {
// This URL explains why shutdown is complicate: // This URL explains why shutdown is complicate:
void AsyncGRPCServer::ShutDown() { void AsyncGRPCServer::ShutDown() {
is_shut_down_ = true; is_shut_down_ = true;
ShutdownQueue();
server_->Shutdown(); server_->Shutdown();
ShutdownQueue();
} }
void AsyncGRPCServer::TryToRegisterNewSendOne() { void AsyncGRPCServer::TryToRegisterNewSendOne() {

@ -47,6 +47,8 @@ class AsyncGRPCServer final {
explicit AsyncGRPCServer(const std::string &address, bool sync_mode) explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode) {} : address_(address), sync_mode_(sync_mode) {}
~AsyncGRPCServer() {}
void RunSyncUpdate(); void RunSyncUpdate();
// functions to sync server barrier status. // functions to sync server barrier status.

@ -53,14 +53,15 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 1); e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
} else if (var->IsType<ncclUniqueId>()) { } else if (var->IsType<ncclUniqueId>()) {
// NOTE: sendrecv only support RAW type for NCCL_ID // NOTE: sendrecv only support RAW type for NCCL_ID
VLOG(3) << "serilizing: setting var type nccl id";
e.WriteUint64(VarMsg::kTypeFieldNumber, 2); e.WriteUint64(VarMsg::kTypeFieldNumber, 2);
} }
if (!out_name.empty()) { if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name); e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
} }
switch (framework::ToVarType(var->Type())) { if (var->IsType<framework::LoDTensor>()) {
case framework::proto::VarType_Type_LOD_TENSOR: { // ===========================Tensor==================================
auto tensor = var->Get<framework::LoDTensor>(); auto tensor = var->Get<framework::LoDTensor>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber, e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(tensor.type())); framework::ToDataType(tensor.type()));
@ -86,8 +87,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place())); PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto& gpu_dev_ctx = auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
payload = memory::Alloc(cpu, copy_size); payload = memory::Alloc(cpu, copy_size);
@ -107,8 +107,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} }
payload_size = tensor.numel() * framework::SizeOfType(tensor.type()); payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break; } else if (var->IsType<framework::SelectedRows>()) {
case framework::proto::VarType_Type_SELECTED_ROWS: { // ===========================SELECTED
// ROWS==================================
// TODO(typhoonzero): selectedrows implement should not use unique_ptr // TODO(typhoonzero): selectedrows implement should not use unique_ptr
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber, e.WriteUint64(VarMsg::kDataTypeFieldNumber,
@ -122,10 +123,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto& gpu_dev_ctx = auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
static_cast<const platform::CUDADeviceContext&>(ctx); auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
auto copy_size =
tensor->numel() * framework::SizeOfType(tensor->type());
payload = memory::Alloc(cpu, copy_size); payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload, memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor->place()), boost::get<platform::CUDAPlace>(tensor->place()),
@ -142,20 +141,18 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} }
payload_size = tensor->numel() * framework::SizeOfType(tensor->type()); payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break; } else if (var->IsType<ncclUniqueId>()) {
case framework::proto::VarType_Type_RAW: { // ===========================NCCL ID==================================
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES); NCCL_UNIQUE_ID_BYTES);
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>(); ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES)); e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
} break; } else {
default:
PADDLE_THROW("Serialize does not support type: %s", PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name()); typeid(var->Type()).name());
break;
} }
if (framework::ToVarType(var->Type()) == framework::proto::VarType_Type_RAW) { if (var->IsType<ncclUniqueId>()) {
// for serialize NCCL_ID // for serialize NCCL_ID
::grpc::Slice slices(e.size()); ::grpc::Slice slices(e.size());
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size()); memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());

@ -371,19 +371,26 @@ int VariableResponse::Parse(Source* source) {
meta_.type() == sendrecv::NCCL_ID) && meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "", meta_.varname() != "",
"meta info should be got first!"); "meta info should be got first!");
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
if (meta_.type() == sendrecv::NCCL_ID) { if (meta_.type() == sendrecv::NCCL_ID) {
VLOG(3) << "parse nccl id request";
auto* var = scope_->FindVar(meta_.varname()); auto* var = scope_->FindVar(meta_.varname());
if (var != nullptr) { if (var != nullptr) {
VLOG(3) << "parse nccl id: length " << length;
ncclUniqueId* id = var->GetMutable<ncclUniqueId>(); ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
memcpy(id->internal, meta_.serialized().c_str(), if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
meta_.serialized().size()); length)) {
return tag;
} }
// memcpy(id->internal, meta_.serialized().c_str(),
// meta_.serialized().size());
} }
break;
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
} }
framework::DDim dims = GetDims(meta_.dims()); framework::DDim dims = GetDims(meta_.dims());

@ -37,7 +37,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(dev_place); // put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace());
int trainer_id = Attr<int>("trainer_id"); int trainer_id = Attr<int>("trainer_id");
framework::Scope& local_scope = scope.NewScope(); framework::Scope& local_scope = scope.NewScope();
@ -60,9 +61,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("endpoint_list"); Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient client; detail::RPCClient client;
for (auto& ep : endpoint_list) { for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID"); client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID");
} }
client.Wait(); client.Wait();
VLOG(3) << "sending completed...";
} }
void GetIdByServer(framework::Scope* scope, void GetIdByServer(framework::Scope* scope,
@ -78,9 +81,14 @@ class GenNCCLIdOp : public framework::OperatorBase {
server_thread_.reset(new std::thread(std::bind( server_thread_.reset(new std::thread(std::bind(
&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_.get()))); &detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_.get())));
rpc_service_->SetCond(0);
VLOG(3) << "start getting nccl id from trainer 0...";
auto recv = rpc_service_->Get(); auto recv = rpc_service_->Get();
rpc_service_->ShutDown(); VLOG(3) << "got nccl id and stop server...";
// rpc_service_->SetCond(1);
// rpc_service_->ShutDown();
rpc_service->Push(LISTEN_TERMINATE_MESSAGE);
VLOG(3) << "rpc server stopped";
// TODO(wuyi): reinit nccl communicators // TODO(wuyi): reinit nccl communicators
} }

Loading…
Cancel
Save