Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-7195

add_depthwiseConv_op_gpu
yangyaming 7 years ago
commit a7c2bfb4a1

@ -0,0 +1,9 @@
# Advbox
Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle.
## How to use
1. train a model and save it's parameters. (like fluid_mnist.py)
2. load the parameters which is trained in step1, then reconstruct the model.(like mnist_tutorial_fgsm.py)
3. use advbox to generate the adversarial sample.

@ -0,0 +1,16 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A set of tools for generating adversarial example on paddle platform
"""

@ -0,0 +1,39 @@
"""
The base model of the model.
"""
from abc import ABCMeta, abstractmethod
class Attack(object):
"""
Abstract base class for adversarial attacks. `Attack` represent an adversarial attack
which search an adversarial example. subclass should implement the _apply() method.
Args:
model(Model): an instance of the class advbox.base.Model.
"""
__metaclass__ = ABCMeta
def __init__(self, model):
self.model = model
def __call__(self, image_label):
"""
Generate the adversarial sample.
Args:
image_label(list): The image and label tuple list with one element.
"""
adv_img = self._apply(image_label)
return adv_img
@abstractmethod
def _apply(self, image_label):
"""
Search an adversarial example.
Args:
image_batch(list): The image and label tuple list with one element.
"""
raise NotImplementedError

@ -0,0 +1,38 @@
"""
This module provide the attack method for FGSM's implement.
"""
from __future__ import division
import numpy as np
from collections import Iterable
from .base import Attack
class GradientSignAttack(Attack):
"""
This attack was originally implemented by Goodfellow et al. (2015) with the
infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called
the Fast Gradient Method.
Paper link: https://arxiv.org/abs/1412.6572
"""
def _apply(self, image_label, epsilons=1000):
assert len(image_label) == 1
pre_label = np.argmax(self.model.predict(image_label))
min_, max_ = self.model.bounds()
gradient = self.model.gradient(image_label)
gradient_sign = np.sign(gradient) * (max_ - min_)
if not isinstance(epsilons, Iterable):
epsilons = np.linspace(0, 1, num=epsilons + 1)
for epsilon in epsilons:
adv_img = image_label[0][0].reshape(
gradient_sign.shape) + epsilon * gradient_sign
adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
if pre_label != adv_label:
return adv_img
FGSM = GradientSignAttack

@ -0,0 +1,16 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Paddle model for target of attack
"""

@ -0,0 +1,90 @@
"""
The base model of the model.
"""
from abc import ABCMeta
import abc
abstractmethod = abc.abstractmethod
class Model(object):
"""
Base class of model to provide attack.
Args:
bounds(tuple): The lower and upper bound for the image pixel.
channel_axis(int): The index of the axis that represents the color channel.
preprocess(tuple): Two element tuple used to preprocess the input. First
substract the first element, then divide the second element.
"""
__metaclass__ = ABCMeta
def __init__(self, bounds, channel_axis, preprocess=None):
assert len(bounds) == 2
assert channel_axis in [0, 1, 2, 3]
if preprocess is None:
preprocess = (0, 1)
self._bounds = bounds
self._channel_axis = channel_axis
self._preprocess = preprocess
def bounds(self):
"""
Return the upper and lower bounds of the model.
"""
return self._bounds
def channel_axis(self):
"""
Return the channel axis of the model.
"""
return self._channel_axis
def _process_input(self, input_):
res = input_
sub, div = self._preprocess
if sub != 0:
res = input_ - sub
assert div != 0
if div != 1:
res /= div
return res
@abstractmethod
def predict(self, image_batch):
"""
Calculate the prediction of the image batch.
Args:
image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels).
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
"""
raise NotImplementedError
@abstractmethod
def num_classes(self):
"""
Determine the number of the classes
Return:
int: the number of the classes
"""
raise NotImplementedError
@abstractmethod
def gradient(self, image_batch):
"""
Calculate the gradient of the cross-entropy loss w.r.t the image.
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with
the shape (height, width, channel).
"""
raise NotImplementedError

@ -0,0 +1,101 @@
from __future__ import absolute_import
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid.framework import program_guard
from .base import Model
class PaddleModel(Model):
"""
Create a PaddleModel instance.
When you need to generate a adversarial sample, you should construct an instance of PaddleModel.
Args:
program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample.
input_name(string): The name of the input.
logits_name(string): The name of the logits.
predict_name(string): The name of the predict.
cost_name(string): The name of the loss in the program.
"""
def __init__(self,
program,
input_name,
logits_name,
predict_name,
cost_name,
bounds,
channel_axis=3,
preprocess=None):
super(PaddleModel, self).__init__(
bounds=bounds, channel_axis=channel_axis, preprocess=preprocess)
if preprocess is None:
preprocess = (0, 1)
self._program = program
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)
self._input_name = input_name
self._logits_name = logits_name
self._predict_name = predict_name
self._cost_name = cost_name
# gradient
loss = self._program.block(0).var(self._cost_name)
param_grads = fluid.backward.append_backward(
loss, parameter_list=[self._input_name])
self._gradient = dict(param_grads)[self._input_name]
def predict(self, image_batch):
"""
Predict the label of the image_batch.
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
"""
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
place=self._place,
program=self._program)
predict_var = self._program.block(0).var(self._predict_name)
predict = self._exe.run(self._program,
feed=feeder.feed(image_batch),
fetch_list=[predict_var])
return predict
def num_classes(self):
"""
Calculate the number of classes of the output label.
Return:
int: the number of classes
"""
predict_var = self._program.block(0).var(self._predict_name)
assert len(predict_var.shape) == 2
return predict_var.shape[1]
def gradient(self, image_batch):
"""
Calculate the gradient of the loss w.r.t the input.
Args:
image_batch(list): The image and label tuple list.
Return:
list: The list of the gradient of the image.
"""
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
place=self._place,
program=self._program)
grad, = self._exe.run(self._program,
feed=feeder.feed(image_batch),
fetch_list=[self._gradient])
return grad

@ -0,0 +1,86 @@
"""
CNN on mnist data using fluid api of paddlepaddle
"""
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
def mnist_cnn_model(img):
"""
Mnist cnn model
Args:
img(Varaible): the input image to be recognized
Returns:
Variable: the label prediction
"""
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
num_filters=20,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
num_filters=50,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
return logits
def main():
"""
Train the cnn model on mnist datasets
"""
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Adam(learning_rate=0.01)
optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=logits, label=label)
BATCH_SIZE = 50
PASS_NUM = 3
ACC_THRESHOLD = 0.98
LOSS_THRESHOLD = 10.0
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
accuracy.reset(exe)
for data in train_reader():
loss, acc = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics)
pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc="
+ str(pass_acc))
if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD:
break
pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc))
fluid.io.save_params(
exe, dirname='./mnist', main_program=fluid.default_main_program())
print('train mnist done')
if __name__ == '__main__':
main()

@ -0,0 +1,87 @@
"""
FGSM demos on mnist using advbox tool.
"""
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import matplotlib.pyplot as plt
import numpy as np
from advbox.models.paddle import PaddleModel
from advbox.attacks.gradientsign import GradientSignAttack
def cnn_model(img):
"""
Mnist cnn model
Args:
img(Varaible): the input image to be recognized
Returns:
Variable: the label prediction
"""
#conv1 = fluid.nets.conv2d()
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
num_filters=20,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
num_filters=50,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
return logits
def main():
"""
Advbox demo which demonstrate how to use advbox.
"""
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder(
feed_list=[IMG_NAME, LABEL_NAME],
place=place,
program=fluid.default_main_program())
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME,
logits.name, avg_cost.name, (-1, 1))
att = GradientSignAttack(m)
for data in train_reader():
# fgsm attack
adv_img = att(data)
plt.imshow(n[0][0], cmap='Greys_r')
plt.show()
#np.save('adv_img', adv_img)
break
if __name__ == '__main__':
main()

@ -5,28 +5,28 @@
In a lecture from Andrew Ng, he attributes the recent sucess of AI due to a combination of these:
- availability of Big Data
- supercomputing power to process this Big Data over very large neural networks
- modern algorithms
- Availability of Big Data
- Supercomputing power to process this Big Data over very large neural networks
- Modern algorithms
Following graph shows the details:
![](images/deep_learning.png)
Larger model usually brings better performance. However, GPU memory is certain limited. For example, the memory size of a GTX TITAN X is only 12GB. To train complex and large model, we have to take care of memory using. Besides, memory optimization is also necessary in both online/mobile inference.
Larger model usually bring better performance. However, GPU memory is limited. For example, the memory size of a GTX TITAN X is only 12GB. To train complex and large models, we have to take care of memory usage. Besides, memory optimization is also necessary in both online/mobile inference.
## Solution
### Basic Strategy
There are some basic strategies to make memory optimization, including in-place operation and memory sharing.
There are some basic strategies to improve memory usage, including in-place operations and memory sharing.
#### In-place Operation
In a relu activation operator
$y = \max(x, 0)$
If the variable x is not used in any other operator, we can make an in-place operation. In other words, the memory block of variable y and variable x are the same. In-place operation will save 50% memory occupancy immediately.
If the variable x is not used in any other operator, we can make an in-place operation. In other words, the memory block of variable y and variable x will be the same. In-place operations will save 50% memory occupancy immediately.
#### Memory Sharing
@ -40,18 +40,18 @@ d = op2(a)
e = op3(d, f)
```
In this case, variable a is no longer used, and op2 does not support in-place operation. After op2 finished, we can put the memory of variable a to a memory pool. Then, variable e can share the memory of variable a from the pool.
In this case, variable a is no longer used, and op2 does not support in-place operation. After op2 finishes, we can put the memory of variable a to a memory pool. Then, variable e can share the memory of variable a from the pool.
### Live Variable Analysis
It's not enough to only have some basic strategies. The prerequisite of memory optimization is to know if a variable is still "live" after an operation.
It's not enough to only have some basic strategies. The pre-requisite of memory optimization is to know if a variable is still "live" after an operation.
In our design, the neural network topology is defined as a program. Luckily, [live variable analysis](https://en.wikipedia.org/wiki/Live_variable_analysis) is a classic problem in compilers which can be used in many stages, such as register allocation.
In compilers, the front end of the compilers translates programs into an intermediate language with an unbounded number of temporaries. This program must run on a machine with a bounded number of registers. Two temporaries a and b can fit into the same register, if a and b are never "in use" at the same time. Thus, many temporaries can fit in few registers; if they don't all fit, the excess temporaries can be kept in memory.
In compilers, the front end of the compiler translates programs into an intermediate language with an unbounded number of temporary variables. This program must run on a machine with a bounded number of registers. Two temporary variables a and b can fit into the same register, if a and b are never "in use" at the same time. Thus, many temporary variables can fit in few registers; if they don't all fit, the excess tempory variables can be kept in memory.
Therefore, the compiler needs to analyze the intermediate-representation program to determine which temporaries are in use at the same time. We say a variable is "live" if it holds a value that may be needed in the future, so this analysis is called liveness analysis.
Therefore, the compiler needs to analyze the intermediate-representation program to determine which temporary variables are in use at the same time. We say a variable is "live" if it holds a value that may be needed in the future, so this analysis is called liveness analysis.
We can leran these techniques from compilers. There are mainly two stages to make live variable analysis:
@ -60,7 +60,7 @@ We can leran these techniques from compilers. There are mainly two stages to mak
#### Control Flow Graph
To preform analyses on a program, it is often useful to make a control flow graph. A [control flow graph](https://en.wikipedia.org/wiki/Control_flow_graph) (CFG) in computer science is a representation, using graph notation, of all paths that might be traversed through a program during its execution. Each statement in the program is a node in the flow graph; if statemment x can be followed by statement y, there is an egde from x to y.
To perform analysis on a program, it is often useful to make a control flow graph. A [control flow graph](https://en.wikipedia.org/wiki/Control_flow_graph) (CFG) in computer science is a representation, using graph notation, of all paths that might be traversed through a program during its execution. Each statement in the program is a node in the flow graph; if statemment x can be followed by statement y, there is an egde from x to y.
Following is the flow graph for a simple loop.
@ -68,18 +68,18 @@ Following is the flow graph for a simple loop.
#### Dataflow Analysis
liveness of variable "flows" around the edges of the control flow graph; determining the live range of each variable is an example of a dataflow problem. [Dataflow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis) is a technique for gathering information about the possible set of values calculated at various points in a computer program.
Liveness of variable "flows" around the edges of the control flow graph; determining the live range of each variable is an example of a dataflow problem. [Dataflow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis) is a technique for gathering information about the possible set of values calculated at various points in a computer program.
A simple way to perform data-flow analysis of programs is to set up dataflow equations for each node of the control flow graph and solve them by repeatedly calculating the output from the input locally at each node until the whole system stabilizes.
- Flow Graph Terminology
A flow graph node has out-edges that lead to sucessor nodes, and in-edges that come from presucessor nodes. The set *pred[n]* is all the predecessors of node n, and *succ[n]* is the set of sucessors.
A flow graph node has out-edges that lead to sucessor nodes, and in-edges that come from predecessor nodes. The set *pred[n]* is all the predecessors of node n, and *succ[n]* is the set of sucessors.
In former control flow graph, the out-edges of node 5 are 5 --> 6 and 5 --> 2, and *succ[5]* = {2, 6}. The in-edges of 2 are 5 --> 2 and 1 --> 2, and *pred[2]* = {1, 5}.
- Uses and Defs
An assignmemt to a variable or temporary defines that variable. An occurence of a variable on the right-hand side of an assginment(or in other expressions) uses the variable. We can speak the *def* of a variable as the set of graph nodes that define it; or the *def* of a graph node as the set of variables that it defines; and the similarly for the *use* of a variable or graph node. In former control flow graph, *def(3)* = {c}, *use(3)* = {b, c}.
An assignmemt to a variable or temporary defines that variable. An occurence of a variable on the right-hand side of an assginment(or in other expressions) uses the variable. We can define the *def* of a variable as the set of graph nodes that define it; or the *def* of a graph node as the set of variables that it defines; and the similarly for the *use* of a variable or graph node. In former control flow graph, *def(3)* = {c}, *use(3)* = {b, c}.
- Liveness
@ -168,9 +168,9 @@ class ControlFlowGraph(object):
return self._program
```
#### make dataflow analysis
#### Make dataflow analysis
We follow guide from compilers and try to solve the dataflow equation to get liveness of every variable. If the live-in of an operator node is different from the live-out, then we can make memory sharing.
We follow the guide from compilers and try to solve the dataflow equation to get liveness of every variable. If the live-in of an operator node is different from the live-out, then we can make memory sharing.
For example:

@ -279,6 +279,26 @@ class LayerHelper(object):
return tmp
```
### Return value of layer functions
The layer will return a Variable, which is also the output of an operator. However, outputs of a layer function have more attributes than an operator. There are parameter variables, and their gradient variables need to return. To return them is useful. For example,
1. Users can debug the network by printing parameter gradients.
2. Users can append attributes to a parameter, such as, `param.stop_gradient=True` will make a parameter stop generate the gradient. We can fix the parameter value during training by using this attribute.
However, it is good to return a Variable for layers, since all layers and operators use Variables as their parameters. We can just append a `param` field and a `grad` field for layer function since the Python is dynamic typing.
The sample usage is
```python
data = fluid.layers.data(...)
hidden = fluid.layers.fc(data, ...)
...
executor.run(fetch_list=[hidden.param, hidden.param.grad], ...)
```
## Optimizer
[Optimizer Design Doc](./optimizer.md)

@ -26,16 +26,16 @@ sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
```
- Variables: `x`, `y`, `y_predict`, `cost` and `avg_cost`. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/framework.py#L93)
- Layers: `fluid.layers.data`, `fluid.layers.fc` and `fluid.layers.mean` are layers. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/layers.py)
- Variables: `x`, `y`, `y_predict`, `cost` and `avg_cost`. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/framework.py#)
- Layers: `fluid.layers.data`, `fluid.layers.fc` and `fluid.layers.mean` are layers. [Python](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/v2/fluid/layers)
- Every Layer has one or more operators and variables/parameters
- All the operators are defined at [`paddle/operators/`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators). Other worth-looking files:
- Base class: [`paddle/framework/operator.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h)
- Operator Registration: [`paddle/framework/op_registry.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_registry.h)
- Operator Lookup: [`paddle/framework/op_info.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_info.h)
- Optimizer: `fluid.optimizer.SGD`. It does the following
- Add backward operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/backward.py), [C++](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/backward.cc)]
- Add optimizer operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/optimizer.py), [C++](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/optimizer)]
- Add backward operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/backward.py)]
- Add optimizer operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/optimizer.py)]
# Run Time

@ -132,6 +132,8 @@ void MKLDNNLayer::reshapeInput(int& batchsize,
if (w != 0) {
width = w;
}
height = height != 0 ? height : 1;
width = width != 0 ? width : 1;
}
void MKLDNNLayer::reshapeOutput(size_t height, size_t width) {

@ -98,6 +98,8 @@ protected:
public:
explicit MKLDNNLayer(const LayerConfig& config)
: Layer(config),
ih_(0),
iw_(0),
condition_(0),
needResetBwd_(true),
outputOnlyMKLDNN_(false),

@ -1 +1 @@
grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)

@ -0,0 +1,147 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
namespace paddle {
namespace operators {
namespace detail {
bool RPCClient::AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
sendrecv::VariableMessage req;
auto* var = scope.FindVar(var_name);
SerializeToMessage(var_name, var, ctx, &req);
// varhandle
VarHandle var_h;
var_h.ep = ep;
var_h.scope = &scope;
var_h.name = var_name;
var_h.ctx = &ctx;
// stub context
auto ch = GetChannel(ep);
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
return true;
}
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& ret_msg) {
auto* outvar = var_h.scope->FindVar(var_h.name);
std::istringstream iss(ret_msg.serialized());
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
}
bool RPCClient::AsyncGetVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(var_name);
auto* var = scope.FindVar(var_name);
SerializeToMessage(var_name, var, ctx, &req);
// varhandle
VarHandle var_h;
var_h.ep = ep;
var_h.scope = &scope;
var_h.name = var_name;
var_h.ctx = &ctx;
// stub context
auto ch = GetChannel(ep);
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse;
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
return true;
}
bool RPCClient::wait() {
bool ok = true;
while (true) {
if (req_count_ <= 0) {
break;
}
if (!Proceed()) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
}
return ok;
}
bool RPCClient::Proceed() {
void* tag = NULL;
bool ok = false;
// request counts.
if (!cq_.Next(&tag, &ok)) {
return false;
}
req_count_--;
GPR_ASSERT(ok);
PADDLE_ENFORCE(tag);
// TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag);
if (!c->status_.ok()) {
delete c;
return true;
}
c->Process();
delete c;
return true;
}
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto it = channels_.find(ep);
if (it != channels_.end()) {
return it->second;
}
auto ch = std::shared_ptr<grpc::Channel>(
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
channels_[ep] = ch;
return ch;
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -0,0 +1,147 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <grpc/support/log.h>
#include <time.h>
#include <chrono>
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "paddle/framework/data_type.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
namespace paddle {
namespace operators {
namespace detail {
struct VarHandle {
std::string ep;
const platform::DeviceContext* ctx;
const framework::Scope* scope;
std::string name;
std::string String() const {
std::ostringstream s;
s << "name:[" << name << "] ep:[" << ep << "]";
return s.str();
}
};
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& msg);
class ClientBase {
public:
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
context_ = NULL;
}
virtual ~ClientBase() {}
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
context_.reset(new grpc::ClientContext());
var_h_ = var_info;
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
}
virtual void Process() = 0;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
std::unique_ptr<grpc::ClientContext> context_;
grpc::Status status_;
VarHandle var_h_;
};
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
RequestSendCallBack;
class SendProcessor : public ClientBase {
public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
virtual ~SendProcessor() {}
virtual void Process() {
if (response_call_back_) {
response_call_back_(var_h_, reply_);
}
}
sendrecv::VoidMessage reply_;
RequestSendCallBack response_call_back_ = NULL;
};
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
RequestGetCallBack;
class GetProcessor : public ClientBase {
public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
virtual ~GetProcessor() {}
virtual void Process() {
if (response_call_back_) {
response_call_back_(var_h_, reply_);
}
}
sendrecv::VariableMessage reply_;
RequestGetCallBack response_call_back_ = ProcGetResponse;
};
class RPCClient {
public:
bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool AsyncGetVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool wait();
private:
bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
private:
grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
int64_t req_count_ = 0;
};
} // namespace detail
} // namespace operators
} // namespace paddle

@ -0,0 +1,237 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/detail/grpc_server.h"
using grpc::ServerAsyncResponseWriter;
namespace paddle {
namespace operators {
namespace detail {
enum CallStatus { PROCESS = 0, FINISH };
// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
public:
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq)
: service_(service), cq_(cq), status_(PROCESS) {}
virtual ~RequestBase() {}
virtual void Process() { assert(false); }
CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
protected:
grpc::ServerContext ctx_;
sendrecv::SendRecvService::AsyncService* service_;
grpc::ServerCompletionQueue* cq_;
CallStatus status_;
};
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestSend final : public RequestBase {
public:
explicit RequestSend(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq,
SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq), queue_(queue), responder_(&ctx_) {
service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_,
this);
}
virtual ~RequestSend() {}
virtual void Process() {
MessageWithName msg_with_name =
std::make_pair(request_.varname(), std::move(request_));
queue_->Push(std::move(msg_with_name));
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
}
protected:
sendrecv::VariableMessage request_;
sendrecv::VoidMessage reply_;
SimpleBlockQueue<MessageWithName>* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
class RequestGet final : public RequestBase {
public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope)
: RequestBase(service, cq), responder_(&ctx_), scope_(scope) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
}
virtual ~RequestGet() {}
virtual void Process() {
// proc request.
std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name);
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
}
protected:
sendrecv::VariableMessage request_;
sendrecv::VariableMessage reply_;
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_;
};
void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
builder.RegisterService(&service_);
cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl;
std::function<void()> send_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
std::function<void()> get_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
cq_send_.get(), "cq_send", send_register)));
t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
cq_get_.get(), "cq_get", get_register)));
// wait server
server_->Wait();
t_send_->join();
t_get_->join();
}
void AsyncGRPCServer::ShutdownQueue() {
std::unique_lock<std::mutex> lock(cq_mutex_);
cq_send_->Shutdown();
cq_get_->Shutdown();
is_shut_down_ = true;
}
// This URL explains why shutdown is complicate:
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
void AsyncGRPCServer::ShutDown() {
server_->Shutdown();
ShutdownQueue();
}
void AsyncGRPCServer::TryToRegisterNewSendOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
return;
}
RequestSend* send =
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
VLOG(4) << "create RequestSend status:" << send->Status();
}
void AsyncGRPCServer::TryToRegisterNewGetOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_);
VLOG(4) << "create Requestget status:" << get->Status();
}
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
delete last;
last = NULL;
return;
}
last->SetStatus(FINISH);
return;
}
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne();
void* tag = NULL;
bool ok = false;
while (true) {
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " get CompletionQueue shutdown!";
break;
}
if (wait && !done_) {
Wait();
}
RequestBase* base = (RequestBase*)tag;
if (!ok) {
VLOG(4) << cq_name << " recv no regular event";
TryToRegisterNewOne();
delete base;
continue;
}
switch (base->Status()) {
case PROCESS: {
VLOG(4) << cq_name << " status:" << base->Status();
TryToRegisterNewOne();
base->Process();
SetFinishOrDelete(base);
break;
}
case FINISH: {
VLOG(4) << cq_name << " status:" << base->Status();
delete base;
break;
}
default: { assert(false); }
}
}
}
void AsyncGRPCServer::Wait() {
std::unique_lock<std::mutex> lock(this->mutex_);
condition_.wait(lock, [=] { return this->done_ == true; });
}
void AsyncGRPCServer::Reset() {
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = false;
}
void AsyncGRPCServer::Done() {
{
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = true;
}
condition_.notify_all();
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -0,0 +1,91 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/var_type.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
#include <grpc++/grpc++.h>
#include <grpc/support/log.h>
#include <thread>
#include "paddle/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase;
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
public:
explicit AsyncGRPCServer(std::string address) { address_ = address; }
void RunSyncUpdate();
void Reset();
void Done();
void SetScope(framework::Scope *scope) { scope_ = scope; }
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
void ShutDown();
protected:
void Wait();
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
void SetFinishOrDelete(RequestBase *&last);
void ShutdownQueue();
private:
std::mutex cq_mutex_;
volatile bool is_shut_down_ = false;
std::unique_ptr<grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_get_;
sendrecv::SendRecvService::AsyncService service_;
std::unique_ptr<grpc::Server> server_;
std::string address_;
framework::Scope *scope_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
// condition of the sub program
std::mutex mutex_;
volatile mutable bool done_;
std::condition_variable condition_;
std::unique_ptr<std::thread> t_send_;
std::unique_ptr<std::thread> t_get_;
};
}; // namespace detail
}; // namespace operators
}; // namespace paddle

@ -1,65 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "send_recv_impl.h"
namespace paddle {
namespace operators {
namespace detail {
Status SendRecvServerImpl::SendVariable(ServerContext *context,
const VariableMessage *in_var,
VoidMessage *out_var) {
MessageWithName msg_with_name =
std::make_pair(in_var->varname(), std::move(*in_var));
var_recv_queue_.Push(std::move(msg_with_name));
return Status::OK;
}
Status SendRecvServerImpl::GetVariable(ServerContext *context,
const VariableMessage *in_var,
VariableMessage *out_var) {
std::string get_var_name = in_var->varname();
auto *var = scope_->FindVar(get_var_name);
SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var);
return Status::OK;
}
Status SendRecvServerImpl::Wait(ServerContext *context,
const VoidMessage *in_var,
VoidMessage *out_var) {
{
std::unique_lock<std::mutex> lock(this->mutex_);
condition_.wait(lock, [=] { return this->done_ == true; });
}
return Status::OK;
}
void SendRecvServerImpl::Reset() {
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = false;
}
void SendRecvServerImpl::Done() {
{
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = true;
}
condition_.notify_all();
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -1,67 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "send_recv_impl.h"
namespace paddle {
namespace operators {
namespace detail {
bool RPCClient::SendVariable(const framework::Scope& scope,
const std::string& inname) {
ClientContext context;
VariableMessage msg;
VoidMessage out_msg;
// FIXME(typhoonzero): pass device context to here.
auto ctx = platform::CPUDeviceContext();
auto* var = scope.FindVar(inname);
PADDLE_ENFORCE(var);
SerializeToMessage(inname, var, ctx, &msg);
Status status = stub_->SendVariable(&context, msg, &out_msg);
if (!status.ok()) {
LOG(ERROR) << "gRPC error: " << status.error_message();
return false;
}
return true;
}
bool RPCClient::GetVariable(const framework::Scope& scope,
const std::string& outname) {
ClientContext context;
VariableMessage call_msg, ret_msg;
call_msg.set_varname(outname);
auto ctx = platform::CPUDeviceContext();
Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
auto* outvar = scope.FindVar(outname);
if (!status.ok()) {
LOG(ERROR) << "gRPC error: " << status.error_message();
return false;
}
std::istringstream iss(ret_msg.serialized());
DeserializeFromMessage(ret_msg, ctx, outvar);
return true;
}
void RPCClient::Wait() {
ClientContext context;
VoidMessage call_msg, ret_msg;
stub_->Wait(&context, call_msg, &ret_msg);
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -21,8 +21,6 @@ service SendRecvService {
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// wait for one execution of the program
rpc Wait(VoidMessage) returns (VoidMessage) {}
}
// VariableMessage is serialized paddle variable message.

@ -1,141 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/var_type.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
#include <grpc++/grpc++.h>
using grpc::Channel;
using grpc::Server;
using grpc::ServerContext;
using grpc::ServerReader;
using grpc::ServerBuilder;
using grpc::ClientContext;
using grpc::ClientReader;
using grpc::ClientReaderWriter;
using grpc::ClientWriter;
using grpc::Status;
using sendrecv::SendRecvService;
using sendrecv::VariableMessage;
using sendrecv::VoidMessage;
namespace paddle {
namespace operators {
namespace detail {
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class SendRecvServerImpl final : public SendRecvService::Service {
public:
explicit SendRecvServerImpl() {}
Status SendVariable(ServerContext *context, const VariableMessage *in_var,
VoidMessage *out_var) override;
Status GetVariable(ServerContext *context, const VariableMessage *in_var,
VariableMessage *out_var) override;
Status Wait(ServerContext *context, const VoidMessage *in_var,
VoidMessage *out_var) override;
void Reset();
void Done();
void SetScope(framework::Scope *scope) { scope_ = scope; };
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
private:
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
framework::Scope *scope_;
// condition of the sub program
std::mutex mutex_;
bool done_;
std::condition_variable condition_;
};
// RPCClient is a class to send tensors to pserver sub-network
// using different hashing methods.
class RPCClient {
public:
RPCClient(std::shared_ptr<Channel> channel)
: stub_(SendRecvService::NewStub(channel)) {}
bool SendVariable(const framework::Scope &scope, const std::string &inname);
bool GetVariable(const framework::Scope &scope, const std::string &outname);
void Wait();
private:
std::unique_ptr<SendRecvService::Stub> stub_;
};
inline void SerializeToMessage(const std::string &name,
const framework::Variable *var,
const platform::DeviceContext &ctx,
VariableMessage *msg) {
msg->set_varname(name);
std::ostringstream oss;
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarDesc_VarType_LOD_TENSOR:
msg->set_type(sendrecv::VarType::LOD_TENSOR);
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
break;
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
ctx);
break;
default: {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
}
msg->set_serialized(oss.str());
}
inline void DeserializeFromMessage(const VariableMessage &msg,
const platform::DeviceContext &ctx,
framework::Variable *var) {
using namespace paddle::framework::proto;
std::istringstream iss(msg.serialized());
switch (msg.type()) {
case sendrecv::VarType::LOD_TENSOR:
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
break;
case sendrecv::VarType::SELECTED_ROWS: {
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
ctx);
break;
}
default: {
PADDLE_THROW("Deserialize does not support type: %s",
typeid(var->Type()).name());
break;
}
}
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -0,0 +1,68 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg) {
msg->set_varname(name);
std::ostringstream oss;
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarDesc_VarType_LOD_TENSOR:
msg->set_type(sendrecv::VarType::LOD_TENSOR);
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
break;
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
ctx);
break;
default: {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
}
msg->set_serialized(oss.str());
}
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
const platform::DeviceContext& ctx,
framework::Variable* var) {
std::istringstream iss(msg.serialized());
switch (msg.type()) {
case sendrecv::VarType::LOD_TENSOR:
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
break;
case sendrecv::VarType::SELECTED_ROWS: {
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
ctx);
break;
}
default: {
PADDLE_THROW("Deserialize does not support type: %s",
typeid(var->Type()).name());
break;
}
}
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -0,0 +1,42 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/framework/data_type.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/var_type.h"
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace detail {
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
const platform::DeviceContext& ctx,
framework::Variable* var);
} // namespace detail
} // namespace operators
} // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save