parent
020630b7a3
commit
da3087ada1
@ -1 +1 @@
|
|||||||
grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||||
|
@ -0,0 +1,147 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "grpc_client.h"
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
bool RPCClient::AsyncSendVariable(const std::string& ep,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
const framework::Scope& scope,
|
||||||
|
const std::string& var_name,
|
||||||
|
int64_t time_out) {
|
||||||
|
sendrecv::VariableMessage req;
|
||||||
|
auto* var = scope.FindVar(var_name);
|
||||||
|
SerializeToMessage(var_name, var, ctx, &req);
|
||||||
|
|
||||||
|
// varhandle
|
||||||
|
VarHandle var_h;
|
||||||
|
var_h.ep = ep;
|
||||||
|
var_h.scope = &scope;
|
||||||
|
var_h.name = var_name;
|
||||||
|
var_h.ctx = &ctx;
|
||||||
|
|
||||||
|
// stub context
|
||||||
|
auto ch = GetChannel(ep);
|
||||||
|
SendProcessor* s = new SendProcessor(ch);
|
||||||
|
s->Prepare(var_h, time_out);
|
||||||
|
s->response_call_back_ = NULL;
|
||||||
|
|
||||||
|
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
|
||||||
|
rpc->Finish(&s->reply_, &s->status_, (void*)s);
|
||||||
|
|
||||||
|
req_count_++;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ProcGetResponse(const VarHandle& var_h,
|
||||||
|
const sendrecv::VariableMessage& ret_msg) {
|
||||||
|
auto* outvar = var_h.scope->FindVar(var_h.name);
|
||||||
|
|
||||||
|
std::istringstream iss(ret_msg.serialized());
|
||||||
|
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RPCClient::AsyncGetVariable(const std::string& ep,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
const framework::Scope& scope,
|
||||||
|
const std::string& var_name,
|
||||||
|
int64_t time_out) {
|
||||||
|
sendrecv::VariableMessage req;
|
||||||
|
req.set_varname(var_name);
|
||||||
|
|
||||||
|
auto* var = scope.FindVar(var_name);
|
||||||
|
SerializeToMessage(var_name, var, ctx, &req);
|
||||||
|
|
||||||
|
// varhandle
|
||||||
|
VarHandle var_h;
|
||||||
|
var_h.ep = ep;
|
||||||
|
var_h.scope = &scope;
|
||||||
|
var_h.name = var_name;
|
||||||
|
var_h.ctx = &ctx;
|
||||||
|
|
||||||
|
// stub context
|
||||||
|
auto ch = GetChannel(ep);
|
||||||
|
GetProcessor* s = new GetProcessor(ch);
|
||||||
|
s->Prepare(var_h, time_out);
|
||||||
|
s->response_call_back_ = ProcGetResponse;
|
||||||
|
|
||||||
|
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
|
||||||
|
rpc->Finish(&s->reply_, &s->status_, (void*)s);
|
||||||
|
|
||||||
|
req_count_++;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RPCClient::wait() {
|
||||||
|
bool ok = true;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (req_count_ <= 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Proceed()) {
|
||||||
|
LOG(ERROR) << "Get meets CompletionQueue error";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RPCClient::Proceed() {
|
||||||
|
void* tag = NULL;
|
||||||
|
bool ok = false;
|
||||||
|
|
||||||
|
// request counts.
|
||||||
|
if (!cq_.Next(&tag, &ok)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
req_count_--;
|
||||||
|
|
||||||
|
GPR_ASSERT(ok);
|
||||||
|
PADDLE_ENFORCE(tag);
|
||||||
|
|
||||||
|
// TODO(gongwb): add more retries.
|
||||||
|
ClientBase* c = static_cast<ClientBase*>(tag);
|
||||||
|
if (!c->status_.ok()) {
|
||||||
|
delete c;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
c->Process();
|
||||||
|
delete c;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
|
||||||
|
auto it = channels_.find(ep);
|
||||||
|
if (it != channels_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ch = std::shared_ptr<grpc::Channel>(
|
||||||
|
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
|
||||||
|
|
||||||
|
channels_[ep] = ch;
|
||||||
|
return ch;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,147 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <grpc++/grpc++.h>
|
||||||
|
#include <grpc/support/log.h>
|
||||||
|
#include <time.h>
|
||||||
|
#include <chrono>
|
||||||
|
#include <ctime>
|
||||||
|
#include <functional>
|
||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "paddle/framework/data_type.h"
|
||||||
|
#include "paddle/framework/lod_tensor.h"
|
||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
#include "paddle/framework/selected_rows.h"
|
||||||
|
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||||
|
#include "paddle/operators/detail/simple_block_queue.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
struct VarHandle {
|
||||||
|
std::string ep;
|
||||||
|
const platform::DeviceContext* ctx;
|
||||||
|
const framework::Scope* scope;
|
||||||
|
std::string name;
|
||||||
|
|
||||||
|
std::string String() const {
|
||||||
|
std::ostringstream s;
|
||||||
|
s << "name:[" << name << "] ep:[" << ep << "]";
|
||||||
|
return s.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void ProcGetResponse(const VarHandle& var_h,
|
||||||
|
const sendrecv::VariableMessage& msg);
|
||||||
|
|
||||||
|
class ClientBase {
|
||||||
|
public:
|
||||||
|
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) {
|
||||||
|
stub_ = sendrecv::SendRecvService::NewStub(ch);
|
||||||
|
context_ = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~ClientBase() {}
|
||||||
|
|
||||||
|
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
|
||||||
|
context_.reset(new grpc::ClientContext());
|
||||||
|
var_h_ = var_info;
|
||||||
|
|
||||||
|
std::chrono::system_clock::time_point deadline =
|
||||||
|
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
|
||||||
|
|
||||||
|
context_->set_deadline(deadline);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void Process() = 0;
|
||||||
|
|
||||||
|
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
|
||||||
|
std::unique_ptr<grpc::ClientContext> context_;
|
||||||
|
grpc::Status status_;
|
||||||
|
VarHandle var_h_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
|
||||||
|
RequestSendCallBack;
|
||||||
|
|
||||||
|
class SendProcessor : public ClientBase {
|
||||||
|
public:
|
||||||
|
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
|
||||||
|
|
||||||
|
virtual ~SendProcessor() {}
|
||||||
|
|
||||||
|
virtual void Process() {
|
||||||
|
if (response_call_back_) {
|
||||||
|
response_call_back_(var_h_, reply_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendrecv::VoidMessage reply_;
|
||||||
|
RequestSendCallBack response_call_back_ = NULL;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
|
||||||
|
RequestGetCallBack;
|
||||||
|
|
||||||
|
class GetProcessor : public ClientBase {
|
||||||
|
public:
|
||||||
|
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
|
||||||
|
|
||||||
|
virtual ~GetProcessor() {}
|
||||||
|
|
||||||
|
virtual void Process() {
|
||||||
|
if (response_call_back_) {
|
||||||
|
response_call_back_(var_h_, reply_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendrecv::VariableMessage reply_;
|
||||||
|
RequestGetCallBack response_call_back_ = ProcGetResponse;
|
||||||
|
};
|
||||||
|
|
||||||
|
class RPCClient {
|
||||||
|
public:
|
||||||
|
bool AsyncSendVariable(const std::string& ep,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
const framework::Scope& scope,
|
||||||
|
const std::string& var_name,
|
||||||
|
int64_t time_out = 600 * 1000);
|
||||||
|
|
||||||
|
bool AsyncGetVariable(const std::string& ep,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
const framework::Scope& scope,
|
||||||
|
const std::string& var_name,
|
||||||
|
int64_t time_out = 600 * 1000);
|
||||||
|
bool wait();
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool Proceed();
|
||||||
|
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
|
||||||
|
|
||||||
|
private:
|
||||||
|
grpc::CompletionQueue cq_;
|
||||||
|
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
|
||||||
|
int64_t req_count_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,237 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/operators/detail/grpc_server.h"
|
||||||
|
|
||||||
|
using grpc::ServerAsyncResponseWriter;
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
enum CallStatus { PROCESS = 0, FINISH };
|
||||||
|
|
||||||
|
// reference:
|
||||||
|
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
|
||||||
|
class RequestBase {
|
||||||
|
public:
|
||||||
|
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
|
||||||
|
grpc::ServerCompletionQueue* cq)
|
||||||
|
: service_(service), cq_(cq), status_(PROCESS) {}
|
||||||
|
virtual ~RequestBase() {}
|
||||||
|
virtual void Process() { assert(false); }
|
||||||
|
|
||||||
|
CallStatus Status() { return status_; }
|
||||||
|
void SetStatus(CallStatus status) { status_ = status; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
grpc::ServerContext ctx_;
|
||||||
|
sendrecv::SendRecvService::AsyncService* service_;
|
||||||
|
grpc::ServerCompletionQueue* cq_;
|
||||||
|
CallStatus status_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||||
|
|
||||||
|
class RequestSend final : public RequestBase {
|
||||||
|
public:
|
||||||
|
explicit RequestSend(sendrecv::SendRecvService::AsyncService* service,
|
||||||
|
grpc::ServerCompletionQueue* cq,
|
||||||
|
SimpleBlockQueue<MessageWithName>* queue)
|
||||||
|
: RequestBase(service, cq), queue_(queue), responder_(&ctx_) {
|
||||||
|
service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_,
|
||||||
|
this);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~RequestSend() {}
|
||||||
|
|
||||||
|
virtual void Process() {
|
||||||
|
MessageWithName msg_with_name =
|
||||||
|
std::make_pair(request_.varname(), std::move(request_));
|
||||||
|
queue_->Push(std::move(msg_with_name));
|
||||||
|
// TODO(gongwb): check var's info.
|
||||||
|
responder_.Finish(reply_, grpc::Status::OK, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
sendrecv::VariableMessage request_;
|
||||||
|
sendrecv::VoidMessage reply_;
|
||||||
|
SimpleBlockQueue<MessageWithName>* queue_;
|
||||||
|
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class RequestGet final : public RequestBase {
|
||||||
|
public:
|
||||||
|
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
|
||||||
|
grpc::ServerCompletionQueue* cq, framework::Scope* scope)
|
||||||
|
: RequestBase(service, cq), responder_(&ctx_), scope_(scope) {
|
||||||
|
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~RequestGet() {}
|
||||||
|
|
||||||
|
virtual void Process() {
|
||||||
|
// proc request.
|
||||||
|
std::string var_name = request_.varname();
|
||||||
|
auto* var = scope_->FindVar(var_name);
|
||||||
|
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
|
||||||
|
// TODO(gongwb): check var's info.
|
||||||
|
responder_.Finish(reply_, grpc::Status::OK, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
sendrecv::VariableMessage request_;
|
||||||
|
sendrecv::VariableMessage reply_;
|
||||||
|
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
|
||||||
|
framework::Scope* scope_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void AsyncGRPCServer::RunSyncUpdate() {
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
|
||||||
|
builder.RegisterService(&service_);
|
||||||
|
|
||||||
|
cq_send_ = builder.AddCompletionQueue();
|
||||||
|
cq_get_ = builder.AddCompletionQueue();
|
||||||
|
server_ = builder.BuildAndStart();
|
||||||
|
LOG(INFO) << "Server listening on " << address_ << std::endl;
|
||||||
|
|
||||||
|
std::function<void()> send_register =
|
||||||
|
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
|
||||||
|
std::function<void()> get_register =
|
||||||
|
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
|
||||||
|
|
||||||
|
t_send_.reset(
|
||||||
|
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
|
||||||
|
cq_send_.get(), "cq_send", send_register)));
|
||||||
|
|
||||||
|
t_get_.reset(
|
||||||
|
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
|
||||||
|
cq_get_.get(), "cq_get", get_register)));
|
||||||
|
|
||||||
|
// wait server
|
||||||
|
server_->Wait();
|
||||||
|
t_send_->join();
|
||||||
|
t_get_->join();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::ShutdownQueue() {
|
||||||
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||||
|
cq_send_->Shutdown();
|
||||||
|
cq_get_->Shutdown();
|
||||||
|
is_shut_down_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This URL explains why shutdown is complicate:
|
||||||
|
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
|
||||||
|
void AsyncGRPCServer::ShutDown() {
|
||||||
|
server_->Shutdown();
|
||||||
|
ShutdownQueue();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::TryToRegisterNewSendOne() {
|
||||||
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||||
|
if (is_shut_down_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
RequestSend* send =
|
||||||
|
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
|
||||||
|
VLOG(4) << "create RequestSend status:" << send->Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::TryToRegisterNewGetOne() {
|
||||||
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||||
|
if (is_shut_down_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_);
|
||||||
|
VLOG(4) << "create Requestget status:" << get->Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
|
||||||
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||||
|
if (is_shut_down_) {
|
||||||
|
delete last;
|
||||||
|
last = NULL;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
last->SetStatus(FINISH);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
|
||||||
|
std::string cq_name,
|
||||||
|
std::function<void()> TryToRegisterNewOne) {
|
||||||
|
TryToRegisterNewOne();
|
||||||
|
|
||||||
|
void* tag = NULL;
|
||||||
|
bool ok = false;
|
||||||
|
while (true) {
|
||||||
|
if (!cq->Next(&tag, &ok)) {
|
||||||
|
LOG(INFO) << cq_name << " get CompletionQueue shutdown!";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wait && !done_) {
|
||||||
|
Wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
RequestBase* base = (RequestBase*)tag;
|
||||||
|
if (!ok) {
|
||||||
|
VLOG(4) << cq_name << " recv no regular event";
|
||||||
|
TryToRegisterNewOne();
|
||||||
|
delete base;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (base->Status()) {
|
||||||
|
case PROCESS: {
|
||||||
|
VLOG(4) << cq_name << " status:" << base->Status();
|
||||||
|
TryToRegisterNewOne();
|
||||||
|
base->Process();
|
||||||
|
SetFinishOrDelete(base);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case FINISH: {
|
||||||
|
VLOG(4) << cq_name << " status:" << base->Status();
|
||||||
|
delete base;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: { assert(false); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::Wait() {
|
||||||
|
std::unique_lock<std::mutex> lock(this->mutex_);
|
||||||
|
condition_.wait(lock, [=] { return this->done_ == true; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::Reset() {
|
||||||
|
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||||
|
done_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::Done() {
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||||
|
done_ = true;
|
||||||
|
}
|
||||||
|
condition_.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,91 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/framework/lod_tensor.h"
|
||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
#include "paddle/framework/selected_rows.h"
|
||||||
|
#include "paddle/framework/var_type.h"
|
||||||
|
#include "paddle/operators/detail/simple_block_queue.h"
|
||||||
|
|
||||||
|
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||||
|
#include "paddle/operators/detail/send_recv.pb.h"
|
||||||
|
|
||||||
|
#include <grpc++/grpc++.h>
|
||||||
|
#include <grpc/support/log.h>
|
||||||
|
#include <thread>
|
||||||
|
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||||
|
class RequestBase;
|
||||||
|
|
||||||
|
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
|
||||||
|
public:
|
||||||
|
explicit AsyncGRPCServer(std::string address) { address_ = address; }
|
||||||
|
|
||||||
|
void RunSyncUpdate();
|
||||||
|
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
void Done();
|
||||||
|
|
||||||
|
void SetScope(framework::Scope *scope) { scope_ = scope; }
|
||||||
|
|
||||||
|
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
|
||||||
|
|
||||||
|
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
|
||||||
|
|
||||||
|
void ShutDown();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void Wait();
|
||||||
|
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
|
||||||
|
std::string cq_name,
|
||||||
|
std::function<void()> TryToRegisterNewOne);
|
||||||
|
void TryToRegisterNewSendOne();
|
||||||
|
void TryToRegisterNewGetOne();
|
||||||
|
void SetFinishOrDelete(RequestBase *&last);
|
||||||
|
void ShutdownQueue();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex cq_mutex_;
|
||||||
|
volatile bool is_shut_down_ = false;
|
||||||
|
std::unique_ptr<grpc::ServerCompletionQueue> cq_send_;
|
||||||
|
std::unique_ptr<grpc::ServerCompletionQueue> cq_get_;
|
||||||
|
|
||||||
|
sendrecv::SendRecvService::AsyncService service_;
|
||||||
|
std::unique_ptr<grpc::Server> server_;
|
||||||
|
|
||||||
|
std::string address_;
|
||||||
|
framework::Scope *scope_;
|
||||||
|
// received variable from RPC, operators fetch variable from this queue.
|
||||||
|
SimpleBlockQueue<MessageWithName> var_recv_queue_;
|
||||||
|
|
||||||
|
// condition of the sub program
|
||||||
|
std::mutex mutex_;
|
||||||
|
volatile mutable bool done_;
|
||||||
|
std::condition_variable condition_;
|
||||||
|
|
||||||
|
std::unique_ptr<std::thread> t_send_;
|
||||||
|
std::unique_ptr<std::thread> t_get_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}; // namespace detail
|
||||||
|
}; // namespace operators
|
||||||
|
}; // namespace paddle
|
@ -1,65 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#include "send_recv_impl.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace operators {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
Status SendRecvServerImpl::SendVariable(ServerContext *context,
|
|
||||||
const VariableMessage *in_var,
|
|
||||||
VoidMessage *out_var) {
|
|
||||||
MessageWithName msg_with_name =
|
|
||||||
std::make_pair(in_var->varname(), std::move(*in_var));
|
|
||||||
var_recv_queue_.Push(std::move(msg_with_name));
|
|
||||||
return Status::OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status SendRecvServerImpl::GetVariable(ServerContext *context,
|
|
||||||
const VariableMessage *in_var,
|
|
||||||
VariableMessage *out_var) {
|
|
||||||
std::string get_var_name = in_var->varname();
|
|
||||||
auto *var = scope_->FindVar(get_var_name);
|
|
||||||
|
|
||||||
SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var);
|
|
||||||
return Status::OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status SendRecvServerImpl::Wait(ServerContext *context,
|
|
||||||
const VoidMessage *in_var,
|
|
||||||
VoidMessage *out_var) {
|
|
||||||
{
|
|
||||||
std::unique_lock<std::mutex> lock(this->mutex_);
|
|
||||||
condition_.wait(lock, [=] { return this->done_ == true; });
|
|
||||||
}
|
|
||||||
return Status::OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SendRecvServerImpl::Reset() {
|
|
||||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
|
||||||
done_ = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SendRecvServerImpl::Done() {
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
|
||||||
done_ = true;
|
|
||||||
}
|
|
||||||
condition_.notify_all();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace operators
|
|
||||||
} // namespace paddle
|
|
@ -1,67 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#include "send_recv_impl.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace operators {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
bool RPCClient::SendVariable(const framework::Scope& scope,
|
|
||||||
const std::string& inname) {
|
|
||||||
ClientContext context;
|
|
||||||
VariableMessage msg;
|
|
||||||
VoidMessage out_msg;
|
|
||||||
// FIXME(typhoonzero): pass device context to here.
|
|
||||||
auto ctx = platform::CPUDeviceContext();
|
|
||||||
auto* var = scope.FindVar(inname);
|
|
||||||
PADDLE_ENFORCE(var);
|
|
||||||
SerializeToMessage(inname, var, ctx, &msg);
|
|
||||||
|
|
||||||
Status status = stub_->SendVariable(&context, msg, &out_msg);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << "gRPC error: " << status.error_message();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool RPCClient::GetVariable(const framework::Scope& scope,
|
|
||||||
const std::string& outname) {
|
|
||||||
ClientContext context;
|
|
||||||
VariableMessage call_msg, ret_msg;
|
|
||||||
call_msg.set_varname(outname);
|
|
||||||
auto ctx = platform::CPUDeviceContext();
|
|
||||||
Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
|
|
||||||
auto* outvar = scope.FindVar(outname);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << "gRPC error: " << status.error_message();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::istringstream iss(ret_msg.serialized());
|
|
||||||
DeserializeFromMessage(ret_msg, ctx, outvar);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void RPCClient::Wait() {
|
|
||||||
ClientContext context;
|
|
||||||
VoidMessage call_msg, ret_msg;
|
|
||||||
stub_->Wait(&context, call_msg, &ret_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace operators
|
|
||||||
} // namespace paddle
|
|
@ -1,141 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "paddle/framework/lod_tensor.h"
|
|
||||||
#include "paddle/framework/scope.h"
|
|
||||||
#include "paddle/framework/selected_rows.h"
|
|
||||||
#include "paddle/framework/var_type.h"
|
|
||||||
#include "paddle/operators/detail/simple_block_queue.h"
|
|
||||||
|
|
||||||
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
|
||||||
#include "paddle/operators/detail/send_recv.pb.h"
|
|
||||||
|
|
||||||
#include <grpc++/grpc++.h>
|
|
||||||
|
|
||||||
using grpc::Channel;
|
|
||||||
using grpc::Server;
|
|
||||||
using grpc::ServerContext;
|
|
||||||
using grpc::ServerReader;
|
|
||||||
using grpc::ServerBuilder;
|
|
||||||
|
|
||||||
using grpc::ClientContext;
|
|
||||||
using grpc::ClientReader;
|
|
||||||
using grpc::ClientReaderWriter;
|
|
||||||
using grpc::ClientWriter;
|
|
||||||
using grpc::Status;
|
|
||||||
using sendrecv::SendRecvService;
|
|
||||||
using sendrecv::VariableMessage;
|
|
||||||
using sendrecv::VoidMessage;
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace operators {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
|
||||||
|
|
||||||
class SendRecvServerImpl final : public SendRecvService::Service {
|
|
||||||
public:
|
|
||||||
explicit SendRecvServerImpl() {}
|
|
||||||
|
|
||||||
Status SendVariable(ServerContext *context, const VariableMessage *in_var,
|
|
||||||
VoidMessage *out_var) override;
|
|
||||||
Status GetVariable(ServerContext *context, const VariableMessage *in_var,
|
|
||||||
VariableMessage *out_var) override;
|
|
||||||
Status Wait(ServerContext *context, const VoidMessage *in_var,
|
|
||||||
VoidMessage *out_var) override;
|
|
||||||
void Reset();
|
|
||||||
void Done();
|
|
||||||
void SetScope(framework::Scope *scope) { scope_ = scope; };
|
|
||||||
|
|
||||||
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
|
|
||||||
|
|
||||||
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
|
|
||||||
|
|
||||||
private:
|
|
||||||
// received variable from RPC, operators fetch variable from this queue.
|
|
||||||
SimpleBlockQueue<MessageWithName> var_recv_queue_;
|
|
||||||
framework::Scope *scope_;
|
|
||||||
// condition of the sub program
|
|
||||||
std::mutex mutex_;
|
|
||||||
bool done_;
|
|
||||||
std::condition_variable condition_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// RPCClient is a class to send tensors to pserver sub-network
|
|
||||||
// using different hashing methods.
|
|
||||||
class RPCClient {
|
|
||||||
public:
|
|
||||||
RPCClient(std::shared_ptr<Channel> channel)
|
|
||||||
: stub_(SendRecvService::NewStub(channel)) {}
|
|
||||||
|
|
||||||
bool SendVariable(const framework::Scope &scope, const std::string &inname);
|
|
||||||
bool GetVariable(const framework::Scope &scope, const std::string &outname);
|
|
||||||
void Wait();
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unique_ptr<SendRecvService::Stub> stub_;
|
|
||||||
};
|
|
||||||
|
|
||||||
inline void SerializeToMessage(const std::string &name,
|
|
||||||
const framework::Variable *var,
|
|
||||||
const platform::DeviceContext &ctx,
|
|
||||||
VariableMessage *msg) {
|
|
||||||
msg->set_varname(name);
|
|
||||||
std::ostringstream oss;
|
|
||||||
switch (framework::ToVarType(var->Type())) {
|
|
||||||
case framework::proto::VarDesc_VarType_LOD_TENSOR:
|
|
||||||
msg->set_type(sendrecv::VarType::LOD_TENSOR);
|
|
||||||
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
|
|
||||||
break;
|
|
||||||
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
|
|
||||||
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
|
|
||||||
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
|
|
||||||
ctx);
|
|
||||||
break;
|
|
||||||
default: {
|
|
||||||
PADDLE_THROW("Serialize does not support type: %s",
|
|
||||||
typeid(var->Type()).name());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg->set_serialized(oss.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void DeserializeFromMessage(const VariableMessage &msg,
|
|
||||||
const platform::DeviceContext &ctx,
|
|
||||||
framework::Variable *var) {
|
|
||||||
using namespace paddle::framework::proto;
|
|
||||||
std::istringstream iss(msg.serialized());
|
|
||||||
switch (msg.type()) {
|
|
||||||
case sendrecv::VarType::LOD_TENSOR:
|
|
||||||
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
|
|
||||||
break;
|
|
||||||
case sendrecv::VarType::SELECTED_ROWS: {
|
|
||||||
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
|
|
||||||
ctx);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
PADDLE_THROW("Deserialize does not support type: %s",
|
|
||||||
typeid(var->Type()).name());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace operators
|
|
||||||
} // namespace paddle
|
|
@ -0,0 +1,68 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
void SerializeToMessage(const std::string& name, const framework::Variable* var,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
sendrecv::VariableMessage* msg) {
|
||||||
|
msg->set_varname(name);
|
||||||
|
std::ostringstream oss;
|
||||||
|
switch (framework::ToVarType(var->Type())) {
|
||||||
|
case framework::proto::VarDesc_VarType_LOD_TENSOR:
|
||||||
|
msg->set_type(sendrecv::VarType::LOD_TENSOR);
|
||||||
|
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
|
||||||
|
break;
|
||||||
|
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
|
||||||
|
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
|
||||||
|
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
|
||||||
|
ctx);
|
||||||
|
break;
|
||||||
|
default: {
|
||||||
|
PADDLE_THROW("Serialize does not support type: %s",
|
||||||
|
typeid(var->Type()).name());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg->set_serialized(oss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
framework::Variable* var) {
|
||||||
|
std::istringstream iss(msg.serialized());
|
||||||
|
switch (msg.type()) {
|
||||||
|
case sendrecv::VarType::LOD_TENSOR:
|
||||||
|
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
|
||||||
|
break;
|
||||||
|
case sendrecv::VarType::SELECTED_ROWS: {
|
||||||
|
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
|
||||||
|
ctx);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
PADDLE_THROW("Deserialize does not support type: %s",
|
||||||
|
typeid(var->Type()).name());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "paddle/framework/data_type.h"
|
||||||
|
#include "paddle/framework/lod_tensor.h"
|
||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
#include "paddle/framework/selected_rows.h"
|
||||||
|
#include "paddle/framework/var_type.h"
|
||||||
|
|
||||||
|
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||||
|
#include "paddle/operators/detail/send_recv.pb.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
void SerializeToMessage(const std::string& name, const framework::Variable* var,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
sendrecv::VariableMessage* msg);
|
||||||
|
|
||||||
|
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
framework::Variable* var);
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
Loading…
Reference in new issue