Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-7195
commit
a7c2bfb4a1
@ -0,0 +1,9 @@
|
||||
# Advbox
|
||||
|
||||
Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle.
|
||||
|
||||
## How to use
|
||||
|
||||
1. train a model and save it's parameters. (like fluid_mnist.py)
|
||||
2. load the parameters which is trained in step1, then reconstruct the model.(like mnist_tutorial_fgsm.py)
|
||||
3. use advbox to generate the adversarial sample.
|
@ -0,0 +1,16 @@
|
||||
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
A set of tools for generating adversarial example on paddle platform
|
||||
"""
|
@ -0,0 +1,39 @@
|
||||
"""
|
||||
The base model of the model.
|
||||
"""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class Attack(object):
|
||||
"""
|
||||
Abstract base class for adversarial attacks. `Attack` represent an adversarial attack
|
||||
which search an adversarial example. subclass should implement the _apply() method.
|
||||
|
||||
Args:
|
||||
model(Model): an instance of the class advbox.base.Model.
|
||||
|
||||
"""
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def __call__(self, image_label):
|
||||
"""
|
||||
Generate the adversarial sample.
|
||||
|
||||
Args:
|
||||
image_label(list): The image and label tuple list with one element.
|
||||
"""
|
||||
adv_img = self._apply(image_label)
|
||||
return adv_img
|
||||
|
||||
@abstractmethod
|
||||
def _apply(self, image_label):
|
||||
"""
|
||||
Search an adversarial example.
|
||||
|
||||
Args:
|
||||
image_batch(list): The image and label tuple list with one element.
|
||||
"""
|
||||
raise NotImplementedError
|
@ -0,0 +1,38 @@
|
||||
"""
|
||||
This module provide the attack method for FGSM's implement.
|
||||
"""
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
from collections import Iterable
|
||||
from .base import Attack
|
||||
|
||||
|
||||
class GradientSignAttack(Attack):
|
||||
"""
|
||||
This attack was originally implemented by Goodfellow et al. (2015) with the
|
||||
infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called
|
||||
the Fast Gradient Method.
|
||||
Paper link: https://arxiv.org/abs/1412.6572
|
||||
"""
|
||||
|
||||
def _apply(self, image_label, epsilons=1000):
|
||||
assert len(image_label) == 1
|
||||
pre_label = np.argmax(self.model.predict(image_label))
|
||||
|
||||
min_, max_ = self.model.bounds()
|
||||
gradient = self.model.gradient(image_label)
|
||||
gradient_sign = np.sign(gradient) * (max_ - min_)
|
||||
|
||||
if not isinstance(epsilons, Iterable):
|
||||
epsilons = np.linspace(0, 1, num=epsilons + 1)
|
||||
|
||||
for epsilon in epsilons:
|
||||
adv_img = image_label[0][0].reshape(
|
||||
gradient_sign.shape) + epsilon * gradient_sign
|
||||
adv_img = np.clip(adv_img, min_, max_)
|
||||
adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
|
||||
if pre_label != adv_label:
|
||||
return adv_img
|
||||
|
||||
|
||||
FGSM = GradientSignAttack
|
@ -0,0 +1,16 @@
|
||||
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Paddle model for target of attack
|
||||
"""
|
@ -0,0 +1,90 @@
|
||||
"""
|
||||
The base model of the model.
|
||||
"""
|
||||
from abc import ABCMeta
|
||||
import abc
|
||||
|
||||
abstractmethod = abc.abstractmethod
|
||||
|
||||
|
||||
class Model(object):
|
||||
"""
|
||||
Base class of model to provide attack.
|
||||
|
||||
|
||||
Args:
|
||||
bounds(tuple): The lower and upper bound for the image pixel.
|
||||
channel_axis(int): The index of the axis that represents the color channel.
|
||||
preprocess(tuple): Two element tuple used to preprocess the input. First
|
||||
substract the first element, then divide the second element.
|
||||
"""
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, bounds, channel_axis, preprocess=None):
|
||||
assert len(bounds) == 2
|
||||
assert channel_axis in [0, 1, 2, 3]
|
||||
|
||||
if preprocess is None:
|
||||
preprocess = (0, 1)
|
||||
self._bounds = bounds
|
||||
self._channel_axis = channel_axis
|
||||
self._preprocess = preprocess
|
||||
|
||||
def bounds(self):
|
||||
"""
|
||||
Return the upper and lower bounds of the model.
|
||||
"""
|
||||
return self._bounds
|
||||
|
||||
def channel_axis(self):
|
||||
"""
|
||||
Return the channel axis of the model.
|
||||
"""
|
||||
return self._channel_axis
|
||||
|
||||
def _process_input(self, input_):
|
||||
res = input_
|
||||
sub, div = self._preprocess
|
||||
if sub != 0:
|
||||
res = input_ - sub
|
||||
assert div != 0
|
||||
if div != 1:
|
||||
res /= div
|
||||
return res
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, image_batch):
|
||||
"""
|
||||
Calculate the prediction of the image batch.
|
||||
|
||||
Args:
|
||||
image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels).
|
||||
|
||||
Return:
|
||||
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_classes(self):
|
||||
"""
|
||||
Determine the number of the classes
|
||||
|
||||
Return:
|
||||
int: the number of the classes
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def gradient(self, image_batch):
|
||||
"""
|
||||
Calculate the gradient of the cross-entropy loss w.r.t the image.
|
||||
|
||||
Args:
|
||||
image_batch(list): The image and label tuple list.
|
||||
|
||||
Return:
|
||||
numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with
|
||||
the shape (height, width, channel).
|
||||
"""
|
||||
raise NotImplementedError
|
@ -0,0 +1,101 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import numpy as np
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
from paddle.v2.fluid.framework import program_guard
|
||||
|
||||
from .base import Model
|
||||
|
||||
|
||||
class PaddleModel(Model):
|
||||
"""
|
||||
Create a PaddleModel instance.
|
||||
When you need to generate a adversarial sample, you should construct an instance of PaddleModel.
|
||||
|
||||
Args:
|
||||
program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample.
|
||||
input_name(string): The name of the input.
|
||||
logits_name(string): The name of the logits.
|
||||
predict_name(string): The name of the predict.
|
||||
cost_name(string): The name of the loss in the program.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
program,
|
||||
input_name,
|
||||
logits_name,
|
||||
predict_name,
|
||||
cost_name,
|
||||
bounds,
|
||||
channel_axis=3,
|
||||
preprocess=None):
|
||||
super(PaddleModel, self).__init__(
|
||||
bounds=bounds, channel_axis=channel_axis, preprocess=preprocess)
|
||||
|
||||
if preprocess is None:
|
||||
preprocess = (0, 1)
|
||||
|
||||
self._program = program
|
||||
self._place = fluid.CPUPlace()
|
||||
self._exe = fluid.Executor(self._place)
|
||||
|
||||
self._input_name = input_name
|
||||
self._logits_name = logits_name
|
||||
self._predict_name = predict_name
|
||||
self._cost_name = cost_name
|
||||
|
||||
# gradient
|
||||
loss = self._program.block(0).var(self._cost_name)
|
||||
param_grads = fluid.backward.append_backward(
|
||||
loss, parameter_list=[self._input_name])
|
||||
self._gradient = dict(param_grads)[self._input_name]
|
||||
|
||||
def predict(self, image_batch):
|
||||
"""
|
||||
Predict the label of the image_batch.
|
||||
|
||||
Args:
|
||||
image_batch(list): The image and label tuple list.
|
||||
Return:
|
||||
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
|
||||
"""
|
||||
feeder = fluid.DataFeeder(
|
||||
feed_list=[self._input_name, self._logits_name],
|
||||
place=self._place,
|
||||
program=self._program)
|
||||
predict_var = self._program.block(0).var(self._predict_name)
|
||||
predict = self._exe.run(self._program,
|
||||
feed=feeder.feed(image_batch),
|
||||
fetch_list=[predict_var])
|
||||
return predict
|
||||
|
||||
def num_classes(self):
|
||||
"""
|
||||
Calculate the number of classes of the output label.
|
||||
|
||||
Return:
|
||||
int: the number of classes
|
||||
"""
|
||||
predict_var = self._program.block(0).var(self._predict_name)
|
||||
assert len(predict_var.shape) == 2
|
||||
return predict_var.shape[1]
|
||||
|
||||
def gradient(self, image_batch):
|
||||
"""
|
||||
Calculate the gradient of the loss w.r.t the input.
|
||||
|
||||
Args:
|
||||
image_batch(list): The image and label tuple list.
|
||||
Return:
|
||||
list: The list of the gradient of the image.
|
||||
"""
|
||||
feeder = fluid.DataFeeder(
|
||||
feed_list=[self._input_name, self._logits_name],
|
||||
place=self._place,
|
||||
program=self._program)
|
||||
|
||||
grad, = self._exe.run(self._program,
|
||||
feed=feeder.feed(image_batch),
|
||||
fetch_list=[self._gradient])
|
||||
return grad
|
@ -0,0 +1,86 @@
|
||||
"""
|
||||
CNN on mnist data using fluid api of paddlepaddle
|
||||
"""
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
|
||||
|
||||
def mnist_cnn_model(img):
|
||||
"""
|
||||
Mnist cnn model
|
||||
|
||||
Args:
|
||||
img(Varaible): the input image to be recognized
|
||||
|
||||
Returns:
|
||||
Variable: the label prediction
|
||||
"""
|
||||
conv_pool_1 = fluid.nets.simple_img_conv_pool(
|
||||
input=img,
|
||||
num_filters=20,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
conv_pool_2 = fluid.nets.simple_img_conv_pool(
|
||||
input=conv_pool_1,
|
||||
num_filters=50,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Train the cnn model on mnist datasets
|
||||
"""
|
||||
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
|
||||
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
||||
logits = mnist_cnn_model(img)
|
||||
cost = fluid.layers.cross_entropy(input=logits, label=label)
|
||||
avg_cost = fluid.layers.mean(x=cost)
|
||||
optimizer = fluid.optimizer.Adam(learning_rate=0.01)
|
||||
optimizer.minimize(avg_cost)
|
||||
|
||||
accuracy = fluid.evaluator.Accuracy(input=logits, label=label)
|
||||
|
||||
BATCH_SIZE = 50
|
||||
PASS_NUM = 3
|
||||
ACC_THRESHOLD = 0.98
|
||||
LOSS_THRESHOLD = 10.0
|
||||
train_reader = paddle.batch(
|
||||
paddle.reader.shuffle(
|
||||
paddle.dataset.mnist.train(), buf_size=500),
|
||||
batch_size=BATCH_SIZE)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
|
||||
exe.run(fluid.default_startup_program())
|
||||
|
||||
for pass_id in range(PASS_NUM):
|
||||
accuracy.reset(exe)
|
||||
for data in train_reader():
|
||||
loss, acc = exe.run(fluid.default_main_program(),
|
||||
feed=feeder.feed(data),
|
||||
fetch_list=[avg_cost] + accuracy.metrics)
|
||||
pass_acc = accuracy.eval(exe)
|
||||
print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc="
|
||||
+ str(pass_acc))
|
||||
if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD:
|
||||
break
|
||||
|
||||
pass_acc = accuracy.eval(exe)
|
||||
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc))
|
||||
fluid.io.save_params(
|
||||
exe, dirname='./mnist', main_program=fluid.default_main_program())
|
||||
print('train mnist done')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,87 @@
|
||||
"""
|
||||
FGSM demos on mnist using advbox tool.
|
||||
"""
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from advbox.models.paddle import PaddleModel
|
||||
from advbox.attacks.gradientsign import GradientSignAttack
|
||||
|
||||
|
||||
def cnn_model(img):
|
||||
"""
|
||||
Mnist cnn model
|
||||
Args:
|
||||
img(Varaible): the input image to be recognized
|
||||
Returns:
|
||||
Variable: the label prediction
|
||||
"""
|
||||
#conv1 = fluid.nets.conv2d()
|
||||
conv_pool_1 = fluid.nets.simple_img_conv_pool(
|
||||
input=img,
|
||||
num_filters=20,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
conv_pool_2 = fluid.nets.simple_img_conv_pool(
|
||||
input=conv_pool_1,
|
||||
num_filters=50,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Advbox demo which demonstrate how to use advbox.
|
||||
"""
|
||||
IMG_NAME = 'img'
|
||||
LABEL_NAME = 'label'
|
||||
|
||||
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
|
||||
# gradient should flow
|
||||
img.stop_gradient = False
|
||||
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
|
||||
logits = cnn_model(img)
|
||||
cost = fluid.layers.cross_entropy(input=logits, label=label)
|
||||
avg_cost = fluid.layers.mean(x=cost)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
train_reader = paddle.batch(
|
||||
paddle.reader.shuffle(
|
||||
paddle.dataset.mnist.train(), buf_size=500),
|
||||
batch_size=BATCH_SIZE)
|
||||
feeder = fluid.DataFeeder(
|
||||
feed_list=[IMG_NAME, LABEL_NAME],
|
||||
place=place,
|
||||
program=fluid.default_main_program())
|
||||
|
||||
fluid.io.load_params(
|
||||
exe, "./mnist/", main_program=fluid.default_main_program())
|
||||
|
||||
# advbox demo
|
||||
m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME,
|
||||
logits.name, avg_cost.name, (-1, 1))
|
||||
att = GradientSignAttack(m)
|
||||
for data in train_reader():
|
||||
# fgsm attack
|
||||
adv_img = att(data)
|
||||
plt.imshow(n[0][0], cmap='Greys_r')
|
||||
plt.show()
|
||||
#np.save('adv_img', adv_img)
|
||||
break
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1 +1 @@
|
||||
grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||
|
@ -0,0 +1,147 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "grpc_client.h"
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
bool RPCClient::AsyncSendVariable(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
int64_t time_out) {
|
||||
sendrecv::VariableMessage req;
|
||||
auto* var = scope.FindVar(var_name);
|
||||
SerializeToMessage(var_name, var, ctx, &req);
|
||||
|
||||
// varhandle
|
||||
VarHandle var_h;
|
||||
var_h.ep = ep;
|
||||
var_h.scope = &scope;
|
||||
var_h.name = var_name;
|
||||
var_h.ctx = &ctx;
|
||||
|
||||
// stub context
|
||||
auto ch = GetChannel(ep);
|
||||
SendProcessor* s = new SendProcessor(ch);
|
||||
s->Prepare(var_h, time_out);
|
||||
s->response_call_back_ = NULL;
|
||||
|
||||
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
|
||||
rpc->Finish(&s->reply_, &s->status_, (void*)s);
|
||||
|
||||
req_count_++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ProcGetResponse(const VarHandle& var_h,
|
||||
const sendrecv::VariableMessage& ret_msg) {
|
||||
auto* outvar = var_h.scope->FindVar(var_h.name);
|
||||
|
||||
std::istringstream iss(ret_msg.serialized());
|
||||
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
|
||||
}
|
||||
|
||||
bool RPCClient::AsyncGetVariable(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
int64_t time_out) {
|
||||
sendrecv::VariableMessage req;
|
||||
req.set_varname(var_name);
|
||||
|
||||
auto* var = scope.FindVar(var_name);
|
||||
SerializeToMessage(var_name, var, ctx, &req);
|
||||
|
||||
// varhandle
|
||||
VarHandle var_h;
|
||||
var_h.ep = ep;
|
||||
var_h.scope = &scope;
|
||||
var_h.name = var_name;
|
||||
var_h.ctx = &ctx;
|
||||
|
||||
// stub context
|
||||
auto ch = GetChannel(ep);
|
||||
GetProcessor* s = new GetProcessor(ch);
|
||||
s->Prepare(var_h, time_out);
|
||||
s->response_call_back_ = ProcGetResponse;
|
||||
|
||||
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
|
||||
rpc->Finish(&s->reply_, &s->status_, (void*)s);
|
||||
|
||||
req_count_++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RPCClient::wait() {
|
||||
bool ok = true;
|
||||
|
||||
while (true) {
|
||||
if (req_count_ <= 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (!Proceed()) {
|
||||
LOG(ERROR) << "Get meets CompletionQueue error";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return ok;
|
||||
}
|
||||
|
||||
bool RPCClient::Proceed() {
|
||||
void* tag = NULL;
|
||||
bool ok = false;
|
||||
|
||||
// request counts.
|
||||
if (!cq_.Next(&tag, &ok)) {
|
||||
return false;
|
||||
}
|
||||
req_count_--;
|
||||
|
||||
GPR_ASSERT(ok);
|
||||
PADDLE_ENFORCE(tag);
|
||||
|
||||
// TODO(gongwb): add more retries.
|
||||
ClientBase* c = static_cast<ClientBase*>(tag);
|
||||
if (!c->status_.ok()) {
|
||||
delete c;
|
||||
return true;
|
||||
}
|
||||
|
||||
c->Process();
|
||||
delete c;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
|
||||
auto it = channels_.find(ep);
|
||||
if (it != channels_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto ch = std::shared_ptr<grpc::Channel>(
|
||||
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
|
||||
|
||||
channels_[ep] = ch;
|
||||
return ch;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,147 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
#include <grpc/support/log.h>
|
||||
#include <time.h>
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/framework/data_type.h"
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||
#include "paddle/operators/detail/simple_block_queue.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
struct VarHandle {
|
||||
std::string ep;
|
||||
const platform::DeviceContext* ctx;
|
||||
const framework::Scope* scope;
|
||||
std::string name;
|
||||
|
||||
std::string String() const {
|
||||
std::ostringstream s;
|
||||
s << "name:[" << name << "] ep:[" << ep << "]";
|
||||
return s.str();
|
||||
}
|
||||
};
|
||||
|
||||
void ProcGetResponse(const VarHandle& var_h,
|
||||
const sendrecv::VariableMessage& msg);
|
||||
|
||||
class ClientBase {
|
||||
public:
|
||||
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) {
|
||||
stub_ = sendrecv::SendRecvService::NewStub(ch);
|
||||
context_ = NULL;
|
||||
}
|
||||
|
||||
virtual ~ClientBase() {}
|
||||
|
||||
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
|
||||
context_.reset(new grpc::ClientContext());
|
||||
var_h_ = var_info;
|
||||
|
||||
std::chrono::system_clock::time_point deadline =
|
||||
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
|
||||
|
||||
context_->set_deadline(deadline);
|
||||
}
|
||||
|
||||
virtual void Process() = 0;
|
||||
|
||||
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
|
||||
std::unique_ptr<grpc::ClientContext> context_;
|
||||
grpc::Status status_;
|
||||
VarHandle var_h_;
|
||||
};
|
||||
|
||||
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
|
||||
RequestSendCallBack;
|
||||
|
||||
class SendProcessor : public ClientBase {
|
||||
public:
|
||||
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
|
||||
|
||||
virtual ~SendProcessor() {}
|
||||
|
||||
virtual void Process() {
|
||||
if (response_call_back_) {
|
||||
response_call_back_(var_h_, reply_);
|
||||
}
|
||||
}
|
||||
|
||||
sendrecv::VoidMessage reply_;
|
||||
RequestSendCallBack response_call_back_ = NULL;
|
||||
};
|
||||
|
||||
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
|
||||
RequestGetCallBack;
|
||||
|
||||
class GetProcessor : public ClientBase {
|
||||
public:
|
||||
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
|
||||
|
||||
virtual ~GetProcessor() {}
|
||||
|
||||
virtual void Process() {
|
||||
if (response_call_back_) {
|
||||
response_call_back_(var_h_, reply_);
|
||||
}
|
||||
}
|
||||
|
||||
sendrecv::VariableMessage reply_;
|
||||
RequestGetCallBack response_call_back_ = ProcGetResponse;
|
||||
};
|
||||
|
||||
class RPCClient {
|
||||
public:
|
||||
bool AsyncSendVariable(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
int64_t time_out = 600 * 1000);
|
||||
|
||||
bool AsyncGetVariable(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
int64_t time_out = 600 * 1000);
|
||||
bool wait();
|
||||
|
||||
private:
|
||||
bool Proceed();
|
||||
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
|
||||
|
||||
private:
|
||||
grpc::CompletionQueue cq_;
|
||||
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
|
||||
int64_t req_count_ = 0;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,237 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/detail/grpc_server.h"
|
||||
|
||||
using grpc::ServerAsyncResponseWriter;
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
enum CallStatus { PROCESS = 0, FINISH };
|
||||
|
||||
// reference:
|
||||
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
|
||||
class RequestBase {
|
||||
public:
|
||||
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
|
||||
grpc::ServerCompletionQueue* cq)
|
||||
: service_(service), cq_(cq), status_(PROCESS) {}
|
||||
virtual ~RequestBase() {}
|
||||
virtual void Process() { assert(false); }
|
||||
|
||||
CallStatus Status() { return status_; }
|
||||
void SetStatus(CallStatus status) { status_ = status; }
|
||||
|
||||
protected:
|
||||
grpc::ServerContext ctx_;
|
||||
sendrecv::SendRecvService::AsyncService* service_;
|
||||
grpc::ServerCompletionQueue* cq_;
|
||||
CallStatus status_;
|
||||
};
|
||||
|
||||
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||
|
||||
class RequestSend final : public RequestBase {
|
||||
public:
|
||||
explicit RequestSend(sendrecv::SendRecvService::AsyncService* service,
|
||||
grpc::ServerCompletionQueue* cq,
|
||||
SimpleBlockQueue<MessageWithName>* queue)
|
||||
: RequestBase(service, cq), queue_(queue), responder_(&ctx_) {
|
||||
service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_,
|
||||
this);
|
||||
}
|
||||
|
||||
virtual ~RequestSend() {}
|
||||
|
||||
virtual void Process() {
|
||||
MessageWithName msg_with_name =
|
||||
std::make_pair(request_.varname(), std::move(request_));
|
||||
queue_->Push(std::move(msg_with_name));
|
||||
// TODO(gongwb): check var's info.
|
||||
responder_.Finish(reply_, grpc::Status::OK, this);
|
||||
}
|
||||
|
||||
protected:
|
||||
sendrecv::VariableMessage request_;
|
||||
sendrecv::VoidMessage reply_;
|
||||
SimpleBlockQueue<MessageWithName>* queue_;
|
||||
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
|
||||
};
|
||||
|
||||
class RequestGet final : public RequestBase {
|
||||
public:
|
||||
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
|
||||
grpc::ServerCompletionQueue* cq, framework::Scope* scope)
|
||||
: RequestBase(service, cq), responder_(&ctx_), scope_(scope) {
|
||||
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
|
||||
}
|
||||
|
||||
virtual ~RequestGet() {}
|
||||
|
||||
virtual void Process() {
|
||||
// proc request.
|
||||
std::string var_name = request_.varname();
|
||||
auto* var = scope_->FindVar(var_name);
|
||||
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
|
||||
// TODO(gongwb): check var's info.
|
||||
responder_.Finish(reply_, grpc::Status::OK, this);
|
||||
}
|
||||
|
||||
protected:
|
||||
sendrecv::VariableMessage request_;
|
||||
sendrecv::VariableMessage reply_;
|
||||
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
|
||||
framework::Scope* scope_;
|
||||
};
|
||||
|
||||
void AsyncGRPCServer::RunSyncUpdate() {
|
||||
grpc::ServerBuilder builder;
|
||||
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service_);
|
||||
|
||||
cq_send_ = builder.AddCompletionQueue();
|
||||
cq_get_ = builder.AddCompletionQueue();
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(INFO) << "Server listening on " << address_ << std::endl;
|
||||
|
||||
std::function<void()> send_register =
|
||||
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
|
||||
std::function<void()> get_register =
|
||||
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
|
||||
|
||||
t_send_.reset(
|
||||
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
|
||||
cq_send_.get(), "cq_send", send_register)));
|
||||
|
||||
t_get_.reset(
|
||||
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
|
||||
cq_get_.get(), "cq_get", get_register)));
|
||||
|
||||
// wait server
|
||||
server_->Wait();
|
||||
t_send_->join();
|
||||
t_get_->join();
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::ShutdownQueue() {
|
||||
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||
cq_send_->Shutdown();
|
||||
cq_get_->Shutdown();
|
||||
is_shut_down_ = true;
|
||||
}
|
||||
|
||||
// This URL explains why shutdown is complicate:
|
||||
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
|
||||
void AsyncGRPCServer::ShutDown() {
|
||||
server_->Shutdown();
|
||||
ShutdownQueue();
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::TryToRegisterNewSendOne() {
|
||||
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||
if (is_shut_down_) {
|
||||
return;
|
||||
}
|
||||
RequestSend* send =
|
||||
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
|
||||
VLOG(4) << "create RequestSend status:" << send->Status();
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::TryToRegisterNewGetOne() {
|
||||
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||
if (is_shut_down_) {
|
||||
return;
|
||||
}
|
||||
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_);
|
||||
VLOG(4) << "create Requestget status:" << get->Status();
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
|
||||
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||
if (is_shut_down_) {
|
||||
delete last;
|
||||
last = NULL;
|
||||
return;
|
||||
}
|
||||
|
||||
last->SetStatus(FINISH);
|
||||
return;
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
|
||||
std::string cq_name,
|
||||
std::function<void()> TryToRegisterNewOne) {
|
||||
TryToRegisterNewOne();
|
||||
|
||||
void* tag = NULL;
|
||||
bool ok = false;
|
||||
while (true) {
|
||||
if (!cq->Next(&tag, &ok)) {
|
||||
LOG(INFO) << cq_name << " get CompletionQueue shutdown!";
|
||||
break;
|
||||
}
|
||||
|
||||
if (wait && !done_) {
|
||||
Wait();
|
||||
}
|
||||
|
||||
RequestBase* base = (RequestBase*)tag;
|
||||
if (!ok) {
|
||||
VLOG(4) << cq_name << " recv no regular event";
|
||||
TryToRegisterNewOne();
|
||||
delete base;
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (base->Status()) {
|
||||
case PROCESS: {
|
||||
VLOG(4) << cq_name << " status:" << base->Status();
|
||||
TryToRegisterNewOne();
|
||||
base->Process();
|
||||
SetFinishOrDelete(base);
|
||||
break;
|
||||
}
|
||||
case FINISH: {
|
||||
VLOG(4) << cq_name << " status:" << base->Status();
|
||||
delete base;
|
||||
break;
|
||||
}
|
||||
default: { assert(false); }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::Wait() {
|
||||
std::unique_lock<std::mutex> lock(this->mutex_);
|
||||
condition_.wait(lock, [=] { return this->done_ == true; });
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::Reset() {
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
done_ = false;
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::Done() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
done_ = true;
|
||||
}
|
||||
condition_.notify_all();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,91 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/framework/var_type.h"
|
||||
#include "paddle/operators/detail/simple_block_queue.h"
|
||||
|
||||
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||
#include "paddle/operators/detail/send_recv.pb.h"
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
#include <grpc/support/log.h>
|
||||
#include <thread>
|
||||
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||
class RequestBase;
|
||||
|
||||
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
|
||||
public:
|
||||
explicit AsyncGRPCServer(std::string address) { address_ = address; }
|
||||
|
||||
void RunSyncUpdate();
|
||||
|
||||
void Reset();
|
||||
|
||||
void Done();
|
||||
|
||||
void SetScope(framework::Scope *scope) { scope_ = scope; }
|
||||
|
||||
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
|
||||
|
||||
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
|
||||
|
||||
void ShutDown();
|
||||
|
||||
protected:
|
||||
void Wait();
|
||||
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
|
||||
std::string cq_name,
|
||||
std::function<void()> TryToRegisterNewOne);
|
||||
void TryToRegisterNewSendOne();
|
||||
void TryToRegisterNewGetOne();
|
||||
void SetFinishOrDelete(RequestBase *&last);
|
||||
void ShutdownQueue();
|
||||
|
||||
private:
|
||||
std::mutex cq_mutex_;
|
||||
volatile bool is_shut_down_ = false;
|
||||
std::unique_ptr<grpc::ServerCompletionQueue> cq_send_;
|
||||
std::unique_ptr<grpc::ServerCompletionQueue> cq_get_;
|
||||
|
||||
sendrecv::SendRecvService::AsyncService service_;
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
|
||||
std::string address_;
|
||||
framework::Scope *scope_;
|
||||
// received variable from RPC, operators fetch variable from this queue.
|
||||
SimpleBlockQueue<MessageWithName> var_recv_queue_;
|
||||
|
||||
// condition of the sub program
|
||||
std::mutex mutex_;
|
||||
volatile mutable bool done_;
|
||||
std::condition_variable condition_;
|
||||
|
||||
std::unique_ptr<std::thread> t_send_;
|
||||
std::unique_ptr<std::thread> t_get_;
|
||||
};
|
||||
|
||||
}; // namespace detail
|
||||
}; // namespace operators
|
||||
}; // namespace paddle
|
@ -1,65 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "send_recv_impl.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
Status SendRecvServerImpl::SendVariable(ServerContext *context,
|
||||
const VariableMessage *in_var,
|
||||
VoidMessage *out_var) {
|
||||
MessageWithName msg_with_name =
|
||||
std::make_pair(in_var->varname(), std::move(*in_var));
|
||||
var_recv_queue_.Push(std::move(msg_with_name));
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SendRecvServerImpl::GetVariable(ServerContext *context,
|
||||
const VariableMessage *in_var,
|
||||
VariableMessage *out_var) {
|
||||
std::string get_var_name = in_var->varname();
|
||||
auto *var = scope_->FindVar(get_var_name);
|
||||
|
||||
SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var);
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SendRecvServerImpl::Wait(ServerContext *context,
|
||||
const VoidMessage *in_var,
|
||||
VoidMessage *out_var) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->mutex_);
|
||||
condition_.wait(lock, [=] { return this->done_ == true; });
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
void SendRecvServerImpl::Reset() {
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
done_ = false;
|
||||
}
|
||||
|
||||
void SendRecvServerImpl::Done() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
done_ = true;
|
||||
}
|
||||
condition_.notify_all();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -1,67 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "send_recv_impl.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
bool RPCClient::SendVariable(const framework::Scope& scope,
|
||||
const std::string& inname) {
|
||||
ClientContext context;
|
||||
VariableMessage msg;
|
||||
VoidMessage out_msg;
|
||||
// FIXME(typhoonzero): pass device context to here.
|
||||
auto ctx = platform::CPUDeviceContext();
|
||||
auto* var = scope.FindVar(inname);
|
||||
PADDLE_ENFORCE(var);
|
||||
SerializeToMessage(inname, var, ctx, &msg);
|
||||
|
||||
Status status = stub_->SendVariable(&context, msg, &out_msg);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "gRPC error: " << status.error_message();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RPCClient::GetVariable(const framework::Scope& scope,
|
||||
const std::string& outname) {
|
||||
ClientContext context;
|
||||
VariableMessage call_msg, ret_msg;
|
||||
call_msg.set_varname(outname);
|
||||
auto ctx = platform::CPUDeviceContext();
|
||||
Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
|
||||
auto* outvar = scope.FindVar(outname);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "gRPC error: " << status.error_message();
|
||||
return false;
|
||||
}
|
||||
|
||||
std::istringstream iss(ret_msg.serialized());
|
||||
DeserializeFromMessage(ret_msg, ctx, outvar);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void RPCClient::Wait() {
|
||||
ClientContext context;
|
||||
VoidMessage call_msg, ret_msg;
|
||||
stub_->Wait(&context, call_msg, &ret_msg);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -1,141 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/framework/var_type.h"
|
||||
#include "paddle/operators/detail/simple_block_queue.h"
|
||||
|
||||
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||
#include "paddle/operators/detail/send_recv.pb.h"
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
|
||||
using grpc::Channel;
|
||||
using grpc::Server;
|
||||
using grpc::ServerContext;
|
||||
using grpc::ServerReader;
|
||||
using grpc::ServerBuilder;
|
||||
|
||||
using grpc::ClientContext;
|
||||
using grpc::ClientReader;
|
||||
using grpc::ClientReaderWriter;
|
||||
using grpc::ClientWriter;
|
||||
using grpc::Status;
|
||||
using sendrecv::SendRecvService;
|
||||
using sendrecv::VariableMessage;
|
||||
using sendrecv::VoidMessage;
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||
|
||||
class SendRecvServerImpl final : public SendRecvService::Service {
|
||||
public:
|
||||
explicit SendRecvServerImpl() {}
|
||||
|
||||
Status SendVariable(ServerContext *context, const VariableMessage *in_var,
|
||||
VoidMessage *out_var) override;
|
||||
Status GetVariable(ServerContext *context, const VariableMessage *in_var,
|
||||
VariableMessage *out_var) override;
|
||||
Status Wait(ServerContext *context, const VoidMessage *in_var,
|
||||
VoidMessage *out_var) override;
|
||||
void Reset();
|
||||
void Done();
|
||||
void SetScope(framework::Scope *scope) { scope_ = scope; };
|
||||
|
||||
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
|
||||
|
||||
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
|
||||
|
||||
private:
|
||||
// received variable from RPC, operators fetch variable from this queue.
|
||||
SimpleBlockQueue<MessageWithName> var_recv_queue_;
|
||||
framework::Scope *scope_;
|
||||
// condition of the sub program
|
||||
std::mutex mutex_;
|
||||
bool done_;
|
||||
std::condition_variable condition_;
|
||||
};
|
||||
|
||||
// RPCClient is a class to send tensors to pserver sub-network
|
||||
// using different hashing methods.
|
||||
class RPCClient {
|
||||
public:
|
||||
RPCClient(std::shared_ptr<Channel> channel)
|
||||
: stub_(SendRecvService::NewStub(channel)) {}
|
||||
|
||||
bool SendVariable(const framework::Scope &scope, const std::string &inname);
|
||||
bool GetVariable(const framework::Scope &scope, const std::string &outname);
|
||||
void Wait();
|
||||
|
||||
private:
|
||||
std::unique_ptr<SendRecvService::Stub> stub_;
|
||||
};
|
||||
|
||||
inline void SerializeToMessage(const std::string &name,
|
||||
const framework::Variable *var,
|
||||
const platform::DeviceContext &ctx,
|
||||
VariableMessage *msg) {
|
||||
msg->set_varname(name);
|
||||
std::ostringstream oss;
|
||||
switch (framework::ToVarType(var->Type())) {
|
||||
case framework::proto::VarDesc_VarType_LOD_TENSOR:
|
||||
msg->set_type(sendrecv::VarType::LOD_TENSOR);
|
||||
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
|
||||
break;
|
||||
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
|
||||
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
|
||||
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
|
||||
ctx);
|
||||
break;
|
||||
default: {
|
||||
PADDLE_THROW("Serialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg->set_serialized(oss.str());
|
||||
}
|
||||
|
||||
inline void DeserializeFromMessage(const VariableMessage &msg,
|
||||
const platform::DeviceContext &ctx,
|
||||
framework::Variable *var) {
|
||||
using namespace paddle::framework::proto;
|
||||
std::istringstream iss(msg.serialized());
|
||||
switch (msg.type()) {
|
||||
case sendrecv::VarType::LOD_TENSOR:
|
||||
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
|
||||
break;
|
||||
case sendrecv::VarType::SELECTED_ROWS: {
|
||||
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
|
||||
ctx);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PADDLE_THROW("Deserialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,68 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
void SerializeToMessage(const std::string& name, const framework::Variable* var,
|
||||
const platform::DeviceContext& ctx,
|
||||
sendrecv::VariableMessage* msg) {
|
||||
msg->set_varname(name);
|
||||
std::ostringstream oss;
|
||||
switch (framework::ToVarType(var->Type())) {
|
||||
case framework::proto::VarDesc_VarType_LOD_TENSOR:
|
||||
msg->set_type(sendrecv::VarType::LOD_TENSOR);
|
||||
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
|
||||
break;
|
||||
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
|
||||
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
|
||||
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
|
||||
ctx);
|
||||
break;
|
||||
default: {
|
||||
PADDLE_THROW("Serialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg->set_serialized(oss.str());
|
||||
}
|
||||
|
||||
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
|
||||
const platform::DeviceContext& ctx,
|
||||
framework::Variable* var) {
|
||||
std::istringstream iss(msg.serialized());
|
||||
switch (msg.type()) {
|
||||
case sendrecv::VarType::LOD_TENSOR:
|
||||
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
|
||||
break;
|
||||
case sendrecv::VarType::SELECTED_ROWS: {
|
||||
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
|
||||
ctx);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PADDLE_THROW("Deserialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,42 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/framework/data_type.h"
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/framework/var_type.h"
|
||||
|
||||
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||
#include "paddle/operators/detail/send_recv.pb.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
void SerializeToMessage(const std::string& name, const framework::Variable* var,
|
||||
const platform::DeviceContext& ctx,
|
||||
sendrecv::VariableMessage* msg);
|
||||
|
||||
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
|
||||
const platform::DeviceContext& ctx,
|
||||
framework::Variable* var);
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue