diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index c749c97f13..3fe750f47e 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -587,6 +587,9 @@ function(grpc_library TARGET_NAME)
get_filename_component(PROTO_WE ${grpc_library_PROTO} NAME_WE)
get_filename_component(PROTO_PATH ${ABS_PROTO} PATH)
+ # FIXME(putcn): the following line is supposed to generate *.pb.h and *.pb.cc,
+ # but somehow it didn't. Lines 602 to 604 patch this. Leaving this here
+ # for now to enable dist CI.
protobuf_generate_cpp(grpc_proto_srcs grpc_proto_hdrs "${ABS_PROTO}")
set(grpc_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.cc")
set(grpc_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.h")
@@ -597,6 +600,9 @@ function(grpc_library TARGET_NAME)
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" -I "${PROTO_PATH}"
--plugin=protoc-gen-grpc="${GRPC_CPP_PLUGIN}" "${ABS_PROTO}"
+ COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
+ ARGS --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" -I "${PROTO_PATH}"
+ "${ABS_PROTO}"
DEPENDS "${ABS_PROTO}" ${PROTOBUF_PROTOC_EXECUTABLE} extern_grpc)
# FIXME(typhoonzero): grpc generated code do not generate virtual-dtor, mark it
diff --git a/doc/fluid/design/concepts/cpp_data_feeding.md b/doc/fluid/design/concepts/cpp_data_feeding.md
index 8607b40ccb..aabc1ba75a 100644
--- a/doc/fluid/design/concepts/cpp_data_feeding.md
+++ b/doc/fluid/design/concepts/cpp_data_feeding.md
@@ -113,7 +113,7 @@ To solve this problem, we introduce `ReaderHolder` as a wrapper. It acts as an e
To create and invoke readers, some new ops are introduced:
-### CreateReaderOp
+### Operators That Create Readers
Each reader has its creation op. File readers' creation ops have no input and yield the created file reader as its output. Decorated readers' creation ops take the underlying readers as inputs and then yield new decorated readers.
@@ -153,19 +153,52 @@ double_buffer_reader = create_double_buffer_op(batch_reader)
The forwarding ops of the corresponding `main_program` would be like this:
```
-while_op {
+not_completed = true
+pass_count = 0
+while_op(not_completed) {
has_next = has_next_op(double_buffer_reader)
if_else_op(has_next) {
batch_data = read_op(double_buffer_reader)
... (subsequent training ops)
} else {
reset_op(double_buffer_reader)
+ increase_op(pass_count)
+ not_completed = less_than_op(pass_count, required_pass_num)
}
}
```
-Two important considerations for these programs are as follows:
+A few important considerations for these programs are as follows:
-1. The multiple\_reader is the batch\_reader's underlying reader, and the batch\_reader is the double\_buffer\_reader's underlying reader. `read_op`, `has_next_op` and other reader related ops will only invoke the top-most reader. In this case, it's the double\_buffer\_reader.
+1. `not_completed`, `pass_count` and other variables shown above are all Fluid Variables.
-2. All readers exist in both `startup_program` and `main_program`. And they are persistable.
+2. The multiple\_reader is the batch\_reader's underlying reader, and the batch\_reader is the double\_buffer\_reader's underlying reader. `read_op`, `has_next_op` and other reader related ops will only invoke the top-most reader. In this case, it's the double\_buffer\_reader.
+
+3. All readers exist in both `startup_program` and `main_program`. And they are persistable.
+
+### Simplify Configuration by MultiPassReader
+
+The Program configuration mentioned above is complicated. Users need to be very familiar with the concepts of Program and Block to avoid making mistakes in their code. To make the usage of C++ readers friendlier to new users, we introduce `MultiPassReader`.
+
+`MultiPassReader` is a decorated reader. A multi-pass reader is used to yield data continuously for several training passes. It takes the number of passes to run as one of its attributes ('pass_num') and maintains a counter to record how many passes it has completed. Each time its underlying reader reaches EOF, the multi-pass reader checks whether it has completed the given number of passes. If not, the underlying reader is re-initialized and starts a new pass automatically. Until the whole training is completed, `MultiPassReader`'s `HasNext()` always returns `true`.
+
+With `MultiPassReader`, the startup program would be like this:
+
+```
+multiple_reader = open_files_op(...)
+batch_reader = create_batch_reader_op(multiple_reader)
+multi_pass_reader = create_multi_pass_reader_op(batch_reader)
+double_buffer_reader = create_double_buffer_op(multi_pass_reader)
+... (other initializers)
+```
+
+The forwarding part of the corresponding `main_program` would be like this:
+
+```
+not_completed = true
+while_op(not_completed) {
+ batch_data = read_op(double_buffer_reader)
+ ... (subsequent training ops)
+ not_completed = has_next_op(double_buffer_reader)
+}
+```
diff --git a/doc/fluid/design/concurrent/channel.md b/doc/fluid/design/concurrent/channel.md
new file mode 100644
index 0000000000..a00a3325e7
--- /dev/null
+++ b/doc/fluid/design/concurrent/channel.md
@@ -0,0 +1,139 @@
+# Channel Design
+
+## Introduction
+
+A Channel is a data structure that allows for synchronous interprocess
+communication via message passing. It is a fundamental component of CSP
+(communicating sequential processes), and allows users to pass data
+between threads without having to worry about synchronization.
+
+## How to use it
+
+Paddle offers Python APIs to open and close channels, along with sending
+and receiving data to/from a channel.
+
+### Create a channel
+
+Creates a new channel that takes in variables of a specific dtype.
+
+- **fluid.make_channel(dtype, capacity=0)**
+ - **dtype**: The data type of variables being sent/received through channel
+ - **capacity**: The capacity of the channel. A capacity of 0 represents
+ an unbuffered channel; a capacity greater than 0 represents a buffered channel.
+
+```
+ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=10)
+```
+
+### Close a channel
+
+Closes a channel. Any pending senders and receivers will be awoken during
+this time. Receivers can still receive from a closed channel, but senders
+are not allowed to send any additional data to the channel (Paddle will
+raise an exception if users try to send to a closed channel).
+
+- **fluid.channel_close(channel)**
+
+```
+fluid.channel_close(ch)
+```
+
+### Send data to a channel
+
+Sends a variable to a channel. Currently, variables of dtype `LoDTensor`,
+`LoDRankTable`, `LoDTensorArray`, `SelectedRows`, `ReaderHolder`, and
+`ChannelHolder` are supported.
+
+By default, the data of the Variable is moved from the sender to the receiver;
+however, the user can optionally copy the data before performing the send.
+
+- **channel_send(channel, variable, is_copy=False)**
+ - **channel**: The channel to send the variable to
+ - **variable**: The variable to send to the channel
+ - **is_copy**: If set to True, channel_send will perform a variable assign
+ to copy the source variable to a new variable to be sent.
+
+```
+ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
+var = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=100)
+fluid.channel_send(ch, var, is_copy=True)
+```
+
+### Receive data from a channel
+
+Receives a variable from a channel. The data of the variable is moved to the
+receiving variable.
+
+- **channel_recv(channel, return_variable)**
+ - **channel**: The channel to receive the variable from
+ - **return_variable**: The destination variable used to store the data of the
+ variable received from the channel
+
+```
+ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
+var = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=-1)
+fluid.channel_recv(ch, var)
+```
+
+## How it Works
+
+Channels provide a simple interface for different threads to share data.
+To support the synchronization requirements, channels utilize a series of
+internal queues, locks, and condition variables.
+
+### QueueMessage
+
+QueueMessage encapsulates the state of a channel send/receive operation to be
+put in the **sendq/recvq**. It contains a condition variable used to block the
+thread (when there are no available sends/receives). In addition, it contains
+a callback function to notify a thread when the QueueMessage is being
+processed by the channel.
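+
+To make this concrete, the following is a minimal sketch of what such a
+structure might look like. The member names and layout are illustrative
+assumptions, not the exact Paddle definition:
+
+```
+// Illustrative sketch only; member names are assumptions, not Paddle's code.
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+
+struct QueueMessage {
+  void* data = nullptr;                // payload being sent or received
+  std::condition_variable_any cond;    // blocks the waiting thread
+  bool completed = false;              // set once the channel has served us
+  std::function<void(bool)> callback;  // notifies the thread when processed
+
+  void Wait(std::unique_lock<std::recursive_mutex>& lock) {
+    cond.wait(lock, [this] { return completed; });
+  }
+  void Notify(bool success) {
+    completed = true;
+    if (callback) callback(success);
+    cond.notify_all();
+  }
+};
+```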
+
+### Queues
+
+- **buff_**: This queue holds the data buffer in a buffered channel. Its
+capacity is set to the capacity of the channel. This data buffer is not
+used in an unbuffered channel.
+
+- **sendq**: This queue holds the QueueMessage of any pending senders of a
+channel. When a thread performs a channel_send operation on the channel, the
+channel_send operation will put a new QueueMessage on the sendq and block the
+current thread under two conditions (a simplified sketch follows this list):
+ 1. The channel is buffered and is full
+ 2. The channel is unbuffered and does not have a receiver
+
+- **recvq**: This queue holds the QueueMessage of any pending receivers of a
+channel. When a thread performs a channel_recv operation on the channel, the
+channel_recv operation will put a new QueueMessage on the recvq and block the
+current thread under two conditions:
+ 1. The channel is buffered and there is no data in the buff_
+ 2. The channel is unbuffered and does not have a sender
+
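+Putting the pieces together, below is a small, self-contained sketch of how a
+buffered channel can be built from a buffer, a lock, and two condition
+variables. It is an illustration with assumed names, not the actual Paddle
+implementation, which additionally manages the **sendq**/**recvq** of
+QueueMessages, the unbuffered rendezvous, and channel close:
+
+```
+// Minimal sketch of a buffered channel (assumed names); the unbuffered case
+// and channel-close handling are omitted for brevity.
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+
+template <typename T>
+class BufferedChannelSketch {
+ public:
+  explicit BufferedChannelSketch(size_t cap) : cap_(cap) {}
+
+  void Send(T item) {
+    std::unique_lock<std::mutex> lock(mu_);
+    // Condition 1 above: a full buffered channel blocks the sender.
+    send_cv_.wait(lock, [this] { return buff_.size() < cap_; });
+    buff_.push_back(std::move(item));
+    recv_cv_.notify_one();  // wake a pending receiver, if any
+  }
+
+  T Receive() {
+    std::unique_lock<std::mutex> lock(mu_);
+    // Condition 1 above: an empty buffer blocks the receiver.
+    recv_cv_.wait(lock, [this] { return !buff_.empty(); });
+    T item = std::move(buff_.front());
+    buff_.pop_front();
+    send_cv_.notify_one();  // there is room for a pending sender again
+    return item;
+  }
+
+ private:
+  const size_t cap_;
+  std::mutex mu_;
+  std::condition_variable send_cv_, recv_cv_;
+  std::deque<T> buff_;
+};
+```
+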
+### State diagram
+
+#### Channel Send
+
+![Channel Send](./images/channel_send.png)
+
+#### Channel Receive
+
+![Channel Receive](./images/channel_recv.png)
+
+## Limitations and Considerations
+
+### Variable Copy
+
+In Go, variables sent through channels are copied from the sender to the
+receiver. In Paddle, the data of our variables is **moved** from sender to
+receiver. As a result, these variables should not be used after they are sent.
+We provide a flag in the channel_send method that allows users to copy the
+variable before it is sent.
+
+Please note that this is achieved by adding an **assign** operator and creating
+a temporary variable that is sent in place of the original variable. Also note
+that the **assign** operator supports only certain variable data types.
diff --git a/doc/fluid/design/concurrent/images/channel_recv.png b/doc/fluid/design/concurrent/images/channel_recv.png
new file mode 100644
index 0000000000..c06cd15ae7
Binary files /dev/null and b/doc/fluid/design/concurrent/images/channel_recv.png differ
diff --git a/doc/fluid/design/concurrent/images/channel_send.png b/doc/fluid/design/concurrent/images/channel_send.png
new file mode 100644
index 0000000000..006ebb4a5a
Binary files /dev/null and b/doc/fluid/design/concurrent/images/channel_send.png differ
diff --git a/doc/v2/faq/cluster/index_en.rst b/doc/v2/faq/cluster/index_en.rst
index 855b7e8e53..fa942a0962 100644
--- a/doc/v2/faq/cluster/index_en.rst
+++ b/doc/v2/faq/cluster/index_en.rst
@@ -2,4 +2,15 @@
Cluster Training and Prediction
###############################
-TBD
+.. contents::
+
+1. Network connection errors in the log during multi-node cluster training
+----------------------------------------------------------------------------
+During multi-node cluster training, the log may contain errors related to network connection problems, for example, :code:`Connection reset by peer`.
+This kind of error is usually caused by the abnormal exit of a training process on some node, after which the other nodes can no longer connect to it. Steps to troubleshoot the problem are as follows:
+
+* Find the first error in :code:`train.log` and :code:`server.log`, and check whether another fault caused the problem, such as a floating-point exception (FPE) or running out of memory or disk space.
+
+* If the first error in :code:`server.log` is "Address already in use", it may be caused by a port conflict due to non-exclusive execution. Contact the system administrator to check whether the current MPI cluster supports jobs submitted with the parameter :code:`resource=full`. If it does not, change the server port and try again.
+
+* If the current MPI cluster does not support an exclusive mode that allows a job to occupy the whole node, ask the administrator to replace or upgrade the cluster.
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index a4ea74a6d2..8c8def6bf4 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -100,7 +100,7 @@ cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
-cc_test(channel_test SRCS channel_test.cc)
+# cc_test(channel_test SRCS channel_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc
index 3693bc25d8..fbe08349c3 100644
--- a/paddle/fluid/framework/block_desc.cc
+++ b/paddle/fluid/framework/block_desc.cc
@@ -147,15 +147,52 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return;
}
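+  // Helper: collect and deduplicate the input/output variable names of an op.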
+  auto get_vars = [](std::deque<std::unique_ptr<OpDesc>>::iterator &op,
+                     std::vector<std::string> &v) {
+ auto in_names = (*op)->InputArgumentNames();
+ v.insert(v.end(), in_names.begin(), in_names.end());
+ auto out_names = (*op)->OutputArgumentNames();
+ v.insert(v.end(), out_names.begin(), out_names.end());
+ std::sort(v.begin(), v.end());
+ auto last = std::unique(v.begin(), v.end());
+ v.erase(last, v.end());
+ };
need_update_ = true;
- for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) {
- auto names = (*it)->InputArgumentNames();
- for (auto n : names) {
- // TODO(typhoonzero): delete vars if no other op use it.
- VLOG(3) << "deleting var " << n;
+
+ for (size_t i = s; i < e; i++) {
+    // Since ops are removed one by one, each iteration removes the op at offset s.
+ auto op = ops_.begin() + s;
+
+ // collect input and output variables from current delete op
+    std::vector<std::string> cur_vars;
+ get_vars(op, cur_vars);
+
+ // remove current op
+ ops_.erase(ops_.begin() + s);
+
+ // collect input and output variables from other ops
+    std::vector<std::string> other_vars;
+ for (auto it = ops_.begin(); it != ops_.end(); it++) {
+ get_vars(it, other_vars);
+ }
+
+    // variables to be deleted
+    std::vector<std::string> delete_vars;
+    // delete_vars = cur_vars - (cur_vars ∩ other_vars)
+ std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
+ other_vars.end(),
+ std::inserter(delete_vars, delete_vars.end()));
+ // remove variables
+ for (size_t i = 0; i < delete_vars.size(); i++) {
+ auto name = delete_vars[i];
+ auto it = vars_.find(name);
+ PADDLE_ENFORCE(it != vars_.end(),
+ "%s is not in variable list, it should not be deleted",
+ name);
+ vars_.erase(it);
+ VLOG(3) << "deleting variable " << name;
}
}
- ops_.erase(ops_.begin() + s, ops_.begin() + e);
}
std::vector BlockDesc::AllOps() const {
diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h
index 185f018ac1..468423e0e8 100644
--- a/paddle/fluid/framework/block_desc.h
+++ b/paddle/fluid/framework/block_desc.h
@@ -89,6 +89,11 @@ class BlockDesc {
OpDesc *InsertOp(size_t index);
+  /*
+   * Remove Op and its input/output variables.
+   * Note that if an input or output variable is also an input or output
+   * variable of other ops, it should be kept.
+   */
void RemoveOp(size_t s, size_t e);
std::vector AllOps() const;
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 979115eee0..a6d9ce0f04 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -260,6 +260,36 @@ $out = floor(x)$
}
};
+class CosOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ CosOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+ : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "Input of Cosine operator");
+ AddOutput("Out", "Output of Cosine operator");
+ AddComment(R"DOC(
+Cosine Activation Operator.
+
+$out = cos(x)$
+
+)DOC");
+ }
+};
+
+class SinOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ SinOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+ : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "Input of Sine operator");
+ AddOutput("Out", "Output of Sine operator");
+ AddComment(R"DOC(
+Sine Activation Operator.
+
+$out = sin(x)$
+
+)DOC");
+ }
+};
+
class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -561,6 +591,12 @@ REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
REGISTER_OP(floor, ops::ActivationOp, ops::FloorOpMaker, floor_grad,
ops::ActivationOpGrad);
+REGISTER_OP(cos, ops::ActivationOp, ops::CosOpMaker, cos_grad,
+ ops::ActivationOpGrad);
+
+REGISTER_OP(sin, ops::ActivationOp, ops::SinOpMaker, sin_grad,
+ ops::ActivationOpGrad);
+
REGISTER_OP(round, ops::ActivationOp, ops::RoundOpMaker, round_grad,
ops::ActivationOpGrad);
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 4c575b4a7b..7fbe4efc04 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -331,6 +331,54 @@ struct FloorFunctor : public BaseActivationFunctor {
}
};
+template <typename T>
+struct Sine {
+  HOSTDEVICE T operator()(const T& val) const { return sin(val); }
+};
+
+template <typename T>
+struct Cosine {
+  HOSTDEVICE T operator()(const T& val) const { return cos(val); }
+};
+
+// cosine'(x) = -sin(x)
+template <typename T>
+struct CosGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = -dout * x.unaryExpr(Sine<T>());
+  }
+};
+
+// cosine(x) = cos(x)
+template <typename T>
+struct CosFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.unaryExpr(Cosine<T>());
+  }
+};
+
+// sine'(x) = cos(x)
+template <typename T>
+struct SinGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * x.unaryExpr(Cosine<T>());
+  }
+};
+
+// sine(x) = sin(x)
+template <typename T>
+struct SinFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.unaryExpr(Sine<T>());
+  }
+};
+
// round(x) = [x]
template
struct RoundFunctor : public BaseActivationFunctor {
@@ -782,6 +830,8 @@ struct SwishGradFunctor : public BaseActivationFunctor {
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, FloorFunctor, ZeroGradFunctor); \
+ __macro(cos, CosFunctor, CosGradFunctor); \
+ __macro(sin, SinFunctor, SinGradFunctor); \
__macro(round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index e73bbe7537..03b789f326 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -204,7 +204,6 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) {
}
grpc::ChannelArguments args;
- args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits::max());
args.SetMaxReceiveMessageSize(std::numeric_limits::max());
diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto
index 598aaa4c51..2d33f026e4 100644
--- a/paddle/fluid/operators/detail/send_recv.proto
+++ b/paddle/fluid/operators/detail/send_recv.proto
@@ -59,12 +59,12 @@ message VariableMessage {
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
+ // selected_rows height, aka. original dim0
+ int64 slr_height = 7;
// tensor data
- bytes serialized = 7;
+ bytes serialized = 8;
// selected_rows data
- bytes rows = 8;
+ bytes rows = 9;
}
message VoidMessage {}
-
-message TestMessage { int64 test_1 = 1; }
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc
index d7bbf79c50..7e3f015dab 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.cc
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
+ e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
@@ -154,7 +155,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
ProtoEncodeHelper e2((char*)buf, 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size =
- slr->rows().capacity() * framework::SizeOfType(typeid(int64_t));
+ slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast(slices[2].begin()), e2.data(), e2.size());
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h
index 3b87562703..b3b2b8469c 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.h
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include <sys/time.h>
#include
#include
#include
@@ -35,6 +36,12 @@ namespace detail {
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
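+
+// Returns the current wall-clock time in milliseconds (used for timing logs).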
+static int64_t GetTimestamp() {
+ struct timeval tp;
+ gettimeofday(&tp, NULL);
+ return tp.tv_sec * 1000 + tp.tv_usec / 1000;
+}
+
typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
diff --git a/paddle/fluid/operators/detail/test_serde.cc b/paddle/fluid/operators/detail/test_serde.cc
index e646c894d1..ea1670e56f 100644
--- a/paddle/fluid/operators/detail/test_serde.cc
+++ b/paddle/fluid/operators/detail/test_serde.cc
@@ -40,14 +40,14 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable();
+ slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
- tensor->Resize(framework::make_ddim({2, 10}));
+ tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data(place);
- int tensor_numel = 2 * 10;
+ int tensor_numel = 564 * 128;
math::set_constant(ctx, tensor, 32.7);
- rows->push_back(3);
- rows->push_back(10);
+ for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
@@ -64,6 +64,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
+ // deserialize bytebuffer
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);
@@ -74,8 +75,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
}
- EXPECT_EQ(rows_data[0], 3);
- EXPECT_EQ(rows_data[1], 10);
+ for (int i = 0; i < 564; ++i) {
+ EXPECT_EQ(rows_data[i], i);
+ }
+
// deserialize zero-copy
// framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
@@ -104,8 +107,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
- EXPECT_EQ(rows_data2[0], 3);
- EXPECT_EQ(rows_data2[1], 10);
+ for (int i = 0; i < rows2->size(); ++i) {
+ EXPECT_EQ(rows_data2[i], i);
+ }
+ EXPECT_EQ(slr2->height(), 1000);
}
void RunTestLodTensor(platform::Place place, int from_type = 0) {
diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc
index 12e8eb0b4d..f59c9b50bb 100644
--- a/paddle/fluid/operators/detail/variable_response.cc
+++ b/paddle/fluid/operators/detail/variable_response.cc
@@ -48,6 +48,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
void* dest, int size) {
const void* data = NULL;
int size_to_write = 0;
+ int length = size;
+ int total_written = 0;
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
@@ -56,16 +58,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
platform::CPUPlace cpu;
char* p = reinterpret_cast(dest);
- while (size > 0) {
+ while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
-
+      // NOTE: if the raw buffer is large and has two neighboring raw-buffer
+      // fields, GetDirectBufferPointer may return both of them; use length
+      // to truncate it.
+ if (total_written + size_to_write > length) {
+ size_to_write = length - total_written;
+ }
memory::Copy(boost::get(place),
reinterpret_cast(p), cpu, data, size_to_write,
gpu_dev_ctx.stream());
p += size_to_write;
- size -= size_to_write;
+ total_written += size_to_write;
input->Skip(size_to_write);
}
@@ -77,16 +84,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
}
char* p = reinterpret_cast(dest);
- while (size > 0) {
+ while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
+      // NOTE: if the raw buffer is large and has two neighboring raw-buffer
+      // fields, GetDirectBufferPointer may return both of them; use length to
+      // truncate it.
+ if (total_written + size_to_write > length) {
+ size_to_write = length - total_written;
+ }
// TODO(gongwb): can we avoid copy?
platform::CPUPlace cpu;
memory::Copy(cpu, reinterpret_cast(p), cpu, data, size_to_write);
p += size_to_write;
- size -= size_to_write;
+ total_written += size_to_write;
input->Skip(size_to_write);
}
@@ -135,8 +147,13 @@ bool VariableResponse::CopySelectRowsTensorData(
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable();
+ slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
+ PADDLE_ENFORCE_EQ(
+ tensor->numel(),
+ length / framework::SizeOfType(
+ paddle::operators::detail::ToTypeIndex(meta_.data_type())));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type()));
@@ -153,6 +170,8 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable();
+ slr->mutable_rows()->resize(length /
+ framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
@@ -233,7 +252,6 @@ int VariableResponse::Parse(Source* source) {
if (tag != 0) {
return -1;
}
-
return 0;
}
@@ -336,6 +354,14 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
+ case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
+ uint64_t v = 0;
+ if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
+ return tag;
+ }
+      meta_.set_slr_height(static_cast<int64_t>(v));
+ break;
+ }
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index 08b83375dd..9796fabdb6 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -141,6 +141,7 @@ class ListenAndServOp : public framework::OperatorBase {
// and this will still work.
std::vector> fs;
+ double ts = detail::GetTimestamp();
// block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(
@@ -162,6 +163,7 @@ class ListenAndServOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what();
}
}
+ VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index f7a6f2bdf4..5ae42ab973 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -19,8 +19,17 @@ namespace paddle {
namespace operators {
namespace math {
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
template
-class MaxSeqPoolFunctor {
+class MaxSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
@@ -60,7 +69,7 @@ class MaxSeqPoolFunctor {
};
template
-class MaxSeqPoolGradFunctor {
+class MaxSeqPoolGradFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& out_grad,
@@ -93,10 +102,101 @@ class MaxSeqPoolGradFunctor {
}
};
-template class MaxSeqPoolFunctor;
-template class MaxSeqPoolFunctor;
-template class MaxSeqPoolGradFunctor;
-template class MaxSeqPoolGradFunctor;
+template <typename T>
+class SequencePoolFunctor<platform::CPUDeviceContext, T> {
+ public:
+ /* max pool has index output */
+ void operator()(const platform::CPUDeviceContext& context,
+ const std::string pooltype, const framework::LoDTensor& input,
+ framework::Tensor* output,
+ framework::Tensor* index = nullptr) {
+ if (pooltype == "MAX") {
+      math::MaxSeqPoolFunctor<T> max_pool;
+ max_pool(context, input, output, index);
+ return;
+ }
+ auto lod = input.lod()[0];
+ auto& place = *context.eigen_device();
+ for (int i = 0; i < static_cast(lod.size()) - 1; ++i) {
+ Tensor in_t =
+ input.Slice(static_cast(lod[i]), static_cast(lod[i + 1]));
+ Tensor out_t = output->Slice(i, i + 1);
+ int64_t h = static_cast(lod[i + 1] - lod[i]);
+ int64_t w = input.numel() / input.dims()[0];
+ auto in_e = EigenMatrix::From(in_t, framework::make_ddim({h, w}));
+ auto out_e = EigenVector::Flatten(out_t);
+ if (pooltype == "AVERAGE") {
+ out_e.device(place) = in_e.mean(Eigen::array({{0}}));
+ } else if (pooltype == "SUM") {
+ out_e.device(place) = in_e.sum(Eigen::array({{0}}));
+ } else if (pooltype == "SQRT") {
+ out_e.device(place) = in_e.sum(Eigen::array({{0}})) /
+ std::sqrt(static_cast(h));
+ } else if (pooltype == "LAST") {
+ out_e.device(place) = in_e.chip(h - 1, 0);
+ } else if (pooltype == "FIRST") {
+ out_e.device(place) = in_e.chip(0, 0);
+ } else {
+ PADDLE_THROW("unsupported pooling pooltype");
+ }
+ }
+ }
+};
+
+template <typename T>
+class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
+ public:
+ void operator()(const platform::CPUDeviceContext& context,
+ const std::string pooltype, const framework::Tensor& out_grad,
+ framework::LoDTensor* in_grad,
+ /* max pool has index */
+ const framework::Tensor* index = nullptr) {
+ if (pooltype == "MAX") {
+      math::MaxSeqPoolGradFunctor<T> max_pool_grad;
+ max_pool_grad(context, out_grad, *index, in_grad);
+ return;
+ }
+
+ if (pooltype == "LAST" || pooltype == "FIRST") {
+      // set X@Grad to zero first when pooltype is LAST/FIRST
+      math::SetConstant<platform::CPUDeviceContext, T> functor;
+ functor(context, in_grad, 0);
+ }
+ auto lod = in_grad->lod()[0];
+ auto& place = *context.eigen_device();
+ for (int i = 0; i < static_cast(lod.size()) - 1; ++i) {
+ auto in_g_t = in_grad->Slice(static_cast(lod[i]),
+ static_cast(lod[i + 1]));
+ auto out_g_t = out_grad.Slice(i, i + 1);
+ int64_t h = static_cast(lod[i + 1] - lod[i]);
+ int64_t w = in_grad->numel() / in_grad->dims()[0];
+ auto in_g_e = EigenMatrix::From(in_g_t, {h, w});
+ auto out_g_e = EigenMatrix::From(out_g_t, {1, w});
+ auto out_g_e_v = EigenVector::Flatten(out_g_t);
+ Eigen::DSizes bcast(h, 1);
+
+ if (pooltype == "AVERAGE") {
+ in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast);
+ } else if (pooltype == "SUM") {
+ in_g_e.device(place) = (out_g_e).broadcast(bcast);
+ } else if (pooltype == "SQRT") {
+ in_g_e.device(place) =
+ (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast);
+ } else if (pooltype == "LAST") {
+ in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
+ } else if (pooltype == "FIRST") {
+ in_g_e.chip(0, 0).device(place) = out_g_e_v;
+ } else {
+ PADDLE_THROW("unsupported pooling pooltype");
+ }
+ }
+ }
+};
+
+template class SequencePoolFunctor<platform::CPUDeviceContext, float>;
+template class SequencePoolFunctor<platform::CPUDeviceContext, double>;
+template class SequencePoolGradFunctor<platform::CPUDeviceContext, float>;
+template class SequencePoolGradFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu
index d61407c020..1935364da3 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cu
+++ b/paddle/fluid/operators/math/sequence_pooling.cu
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
+#include "paddle/fluid/platform/cuda_helper.h"
namespace paddle {
namespace operators {
@@ -22,113 +23,331 @@ namespace math {
#define FLT_MAX __FLT_MAX__
template
-__global__ void KeMaxSequencePool(const T* input, const size_t* starts,
- T* output, int* index, int64_t num_seq,
- int64_t dim) {
- int dim_idx = threadIdx.x;
- int seq_id = blockIdx.x;
- if (seq_id >= num_seq) return;
- size_t start = starts[seq_id];
- size_t end = starts[seq_id + 1];
-
- for (int64_t i = dim_idx; i < dim; i += blockDim.x) {
- T max_val = static_cast(-FLT_MAX);
- int max_id = -1;
- for (size_t step_id = start; step_id < end; step_id++) {
- if (max_val < input[step_id * dim + i]) {
- max_val = input[step_id * dim + i];
- max_id = step_id;
+struct MaxPoolFunctor {
+ HOSTDEVICE void operator()(const T* input, const size_t start,
+ const size_t end, const size_t item_dim, T* output,
+ int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ T max_val = static_cast(-FLT_MAX);
+ int max_index = -1;
+ for (int i = start; i < end; ++i) {
+ if (max_val < input[item_dim * i + tid]) {
+ max_val = input[item_dim * i + tid];
+ max_index = i;
+ }
}
+ output[tid] = max_val;
+ index[tid] = max_index;
}
- output[seq_id * dim + i] = max_val;
- index[seq_id * dim + i] = max_id;
}
-}
+};
template
-class MaxSeqPoolFunctor {
- public:
- void operator()(const platform::CUDADeviceContext& context,
- const framework::LoDTensor& input, framework::Tensor* output,
- framework::Tensor* index) {
- auto in_dims = input.dims();
- auto out_dims = output->dims();
- auto idx_dims = index->dims();
- PADDLE_ENFORCE_GT(in_dims.size(), static_cast(1));
- PADDLE_ENFORCE_GT(out_dims.size(), 1);
- for (int64_t i = 1; i < in_dims.size(); ++i) {
- PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
+struct AvgPoolFunctor {
+ HOSTDEVICE void operator()(const T* input, const size_t start,
+ const size_t end, const size_t item_dim, T* output,
+ int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ T val = static_cast(0);
+ for (int i = start; i < end; ++i) {
+ val += input[item_dim * i + tid];
+ }
+      // start and end come from the LoD, so end - start != 0
+ output[tid] = val / static_cast(end - start);
}
- PADDLE_ENFORCE_EQ(idx_dims, out_dims);
+ }
+};
- auto starts = input.lod()[0];
- const T* in_data = input.data();
- T* out_data = output->data();
- int* max_index = index->data();
+template
+struct SumPoolFunctor {
+ HOSTDEVICE void operator()(const T* input, const size_t start,
+ const size_t end, const size_t item_dim, T* output,
+ int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ T val = static_cast(0);
+ for (int i = start; i < end; ++i) {
+ val += input[item_dim * i + tid];
+ }
+ output[tid] = val;
+ }
+ }
+};
- int64_t num_seq = out_dims[0];
- int64_t dim = output->numel() / num_seq;
+template
+struct SqrtPoolFunctor {
+ HOSTDEVICE void operator()(const T* input, const size_t start,
+ const size_t end, const size_t item_dim, T* output,
+ int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ T val = static_cast(0);
+ for (int i = start; i < end; ++i) {
+ val += input[item_dim * i + tid];
+ }
+      // start and end come from the LoD, so end - start != 0
+ output[tid] = val / sqrt(end - start);
+ }
+ }
+};
- dim3 threads(256, 1);
- dim3 grid(num_seq, 1);
- auto stream = context.stream();
- KeMaxSequencePool<<>>(
- in_data, starts.CUDAData(context.GetPlace()), out_data, max_index,
- num_seq, dim);
+template
+struct LastPoolFunctor {
+ HOSTDEVICE void operator()(const T* input, const size_t start,
+ const size_t end, const size_t item_dim, T* output,
+ int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ output[tid] = input[item_dim * (end - 1) + tid];
+ }
}
};
template
-__global__ void KeMaxSequencePoolGrad(const T* out_grad, const int* max_index,
- T* in_grad, int64_t num_seq,
- int64_t dim) {
- int idx = threadIdx.x + blockIdx.x * blockDim.x;
- int col_idx = idx % dim;
- if (idx < num_seq * dim) {
- int step_id = max_index[idx];
- in_grad[step_id * dim + col_idx] = out_grad[idx];
+struct FirstPoolFunctor {
+ HOSTDEVICE void operator()(const T* input, const size_t start,
+ const size_t end, const size_t item_dim, T* output,
+ int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ output[tid] = input[item_dim * start + tid];
+ }
}
+};
+
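+// One CUDA block processes one sequence; threads stride over the feature
+// dimension (item_dim) and apply the given pooling functor to [start, end).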
+template <typename T, typename Range_OP>
+__global__ void sequence_pool_kernel(Range_OP op, const T* input,
+ const size_t* lod, const size_t lod_size,
+ const size_t item_dim, T* output,
+ int* index) {
+ int bid = blockIdx.x;
+ if (bid >= lod_size - 1) return;
+ size_t start = lod[bid];
+ size_t end = lod[bid + 1];
+ int* index_offset = nullptr;
+ if (index != nullptr) {
+ index_offset = &index[bid * item_dim];
+ }
+ op(input, start, end, item_dim, &output[bid * item_dim], index_offset);
}
template
-class MaxSeqPoolGradFunctor {
+class SequencePoolFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
- const framework::Tensor& out_grad,
- const framework::Tensor& index,
- framework::LoDTensor* in_grad) {
- auto og_dims = out_grad.dims();
- auto idx_dims = index.dims();
- auto ig_dims = in_grad->dims();
- PADDLE_ENFORCE_GT(og_dims.size(), static_cast(1));
- PADDLE_ENFORCE_GT(ig_dims.size(), static_cast(1));
- for (int64_t i = 1; i < og_dims.size(); ++i) {
- PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
+ const std::string pooltype, const framework::LoDTensor& input,
+ framework::Tensor* output,
+ framework::Tensor* index = nullptr) {
+ auto lod = input.lod()[0];
+ const size_t item_dim = output->numel() / output->dims()[0];
+ dim3 threads(1024, 1);
+ dim3 grid(lod.size(), 1);
+ if (pooltype == "MAX") {
+ sequence_pool_kernel<
+ T, MaxPoolFunctor><<>>(
+ MaxPoolFunctor(), input.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ output->mutable_data(context.GetPlace()), index->data());
+ } else if (pooltype == "AVERAGE") {
+ sequence_pool_kernel<
+ T, AvgPoolFunctor><<>>(
+ AvgPoolFunctor(), input.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ output->mutable_data(context.GetPlace()), nullptr);
+ } else if (pooltype == "SUM") {
+ sequence_pool_kernel<
+ T, SumPoolFunctor><<>>(
+ SumPoolFunctor(), input.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ output->mutable_data(context.GetPlace()), nullptr);
+ } else if (pooltype == "SQRT") {
+ sequence_pool_kernel<
+ T, SqrtPoolFunctor><<>>(
+ SqrtPoolFunctor(), input.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ output->mutable_data(context.GetPlace()), nullptr);
+ } else if (pooltype == "LAST") {
+ sequence_pool_kernel<
+ T, LastPoolFunctor><<>>(
+ LastPoolFunctor(), input.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ output->mutable_data(context.GetPlace()), nullptr);
+ } else if (pooltype == "FIRST") {
+ sequence_pool_kernel<
+ T, FirstPoolFunctor><<>>(
+ FirstPoolFunctor(), input.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ output->mutable_data(context.GetPlace()), nullptr);
+ } else {
+ PADDLE_THROW("unsupported pooling pooltype");
}
- PADDLE_ENFORCE_EQ(idx_dims, og_dims);
+ }
+};
- const T* og_data = out_grad.data();
- const int* max_index = index.data();
- T* ig_data = in_grad->data();
+template
+struct MaxPoolGradFunctor {
+ HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+ const size_t end, const size_t item_dim,
+ T* in_grad, const int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ for (int i = start; i < end; ++i) {
+ if (i == index[tid]) {
+ in_grad[item_dim * i + tid] = out_grad[tid];
+ } else {
+ in_grad[item_dim * i + tid] = static_cast(0);
+ }
+ }
+ }
+ }
+};
- SetConstant set_zero;
- set_zero(context, in_grad, static_cast(0.0));
- int64_t num_seq = og_dims[0];
- int64_t dim = out_grad.numel() / num_seq;
+template
+struct AvgPoolGradFunctor {
+ HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+ const size_t end, const size_t item_dim,
+ T* in_grad, const int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ for (int i = start; i < end; ++i) {
+ in_grad[item_dim * i + tid] = out_grad[tid] / (end - start);
+ }
+ }
+ }
+};
- unsigned int blocks = (num_seq * dim + 128 - 1) / 128;
- dim3 threads(128, 1);
- dim3 grid(blocks, 1);
- auto stream = context.stream();
- KeMaxSequencePoolGrad<<>>(
- og_data, max_index, ig_data, num_seq, dim);
+template
+struct SumPoolGradFunctor {
+ HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+ const size_t end, const size_t item_dim,
+ T* in_grad, const int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ for (int i = start; i < end; ++i) {
+ in_grad[item_dim * i + tid] = out_grad[tid];
+ }
+ }
+ }
+};
+
+template
+struct SqrtPoolGradFunctor {
+ HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+ const size_t end, const size_t item_dim,
+ T* in_grad, const int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ for (int i = start; i < end; ++i) {
+ in_grad[item_dim * i + tid] =
+ out_grad[tid] / (sqrt(static_cast(end - start)));
+ }
+ }
+ }
+};
+
+template
+struct LastPoolGradFunctor {
+ HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+ const size_t end, const size_t item_dim,
+ T* in_grad, const int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ for (int i = start; i < end; ++i) {
+ if (i == end - 1) {
+ in_grad[item_dim * i + tid] = out_grad[tid];
+ } else {
+ in_grad[item_dim * i + tid] = static_cast(0);
+ }
+ }
+ }
+ }
+};
+
+template
+struct FirstPoolGradFunctor {
+ HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+ const size_t end, const size_t item_dim,
+ T* in_grad, const int* index) {
+ for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+ for (int i = start; i < end; ++i) {
+ if (i == start) {
+ in_grad[item_dim * i + tid] = out_grad[tid];
+ } else {
+ in_grad[item_dim * i + tid] = static_cast(0);
+ }
+ }
+ }
+ }
+};
+
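+// Gradient kernel: one CUDA block per sequence; the pooling-specific functor
+// scatters out_grad back over the time steps in [start, end).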
+template <typename T, typename Range_OP>
+__global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
+ const size_t* lod,
+ const size_t lod_size,
+ const size_t item_dim, T* in_grad,
+ const int* index) {
+ int bid = blockIdx.x;
+ if (bid >= lod_size - 1) return;
+ size_t start = lod[bid];
+ size_t end = lod[bid + 1];
+ const int* index_offset = nullptr;
+ if (index != nullptr) {
+ index_offset = &index[bid * item_dim];
+ }
+ op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset);
+}
+
+template <typename T>
+class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+ void operator()(const platform::CUDADeviceContext& context,
+ const std::string pooltype, const framework::Tensor& out_grad,
+ framework::LoDTensor* in_grad,
+ /* max pool has index */
+ const framework::Tensor* index = nullptr) {
+ auto lod = in_grad->lod()[0];
+ const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
+ dim3 threads(1024, 1);
+ dim3 grid(lod.size(), 1);
+ if (pooltype == "MAX") {
+ sequence_pool_grad_kernel<
+ T, MaxPoolGradFunctor><<>>(
+ MaxPoolGradFunctor(), out_grad.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ in_grad->mutable_data(context.GetPlace()), index->data());
+ } else if (pooltype == "AVERAGE") {
+ sequence_pool_grad_kernel<
+ T, AvgPoolGradFunctor><<>>(
+ AvgPoolGradFunctor(), out_grad.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ in_grad->mutable_data(context.GetPlace()), nullptr);
+ } else if (pooltype == "SUM") {
+ sequence_pool_grad_kernel<
+ T, SumPoolGradFunctor><<>>(
+ SumPoolGradFunctor(), out_grad.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ in_grad->mutable_data(context.GetPlace()), nullptr);
+ } else if (pooltype == "SQRT") {
+ sequence_pool_grad_kernel<
+ T, SqrtPoolGradFunctor><<>>(
+ SqrtPoolGradFunctor(), out_grad.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ in_grad->mutable_data(context.GetPlace()), nullptr);
+ } else if (pooltype == "LAST") {
+ sequence_pool_grad_kernel<
+ T, LastPoolGradFunctor><<>>(
+ LastPoolGradFunctor(), out_grad.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ in_grad->mutable_data(context.GetPlace()), nullptr);
+ } else if (pooltype == "FIRST") {
+ sequence_pool_grad_kernel<
+ T, FirstPoolGradFunctor><<>>(
+ FirstPoolGradFunctor(), out_grad.data(),
+ lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+ in_grad->mutable_data(context.GetPlace()), nullptr);
+
+ } else {
+ PADDLE_THROW("unsupported pooling pooltype");
+ }
}
};
-template class MaxSeqPoolFunctor;
-template class MaxSeqPoolFunctor;
-template class MaxSeqPoolGradFunctor;
-template class MaxSeqPoolGradFunctor;
+// sequence pooling
+template class SequencePoolFunctor<platform::CUDADeviceContext, float>;
+template class SequencePoolFunctor<platform::CUDADeviceContext, double>;
+template class SequencePoolGradFunctor<platform::CUDADeviceContext, float>;
+template class SequencePoolGradFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
diff --git a/paddle/fluid/operators/math/sequence_pooling.h b/paddle/fluid/operators/math/sequence_pooling.h
index ecb76884f6..38e7802229 100644
--- a/paddle/fluid/operators/math/sequence_pooling.h
+++ b/paddle/fluid/operators/math/sequence_pooling.h
@@ -21,23 +21,23 @@ namespace paddle {
namespace operators {
namespace math {
-#define FLT_MAX __FLT_MAX__
-
template
-class MaxSeqPoolFunctor {
+class SequencePoolFunctor {
public:
- void operator()(const DeviceContext& context,
+ /* max pool has index output */
+ void operator()(const DeviceContext& context, const std::string pooltype,
const framework::LoDTensor& input, framework::Tensor* output,
- framework::Tensor* index);
+ framework::Tensor* index = nullptr);
};
-template
-class MaxSeqPoolGradFunctor {
+template
+class SequencePoolGradFunctor {
public:
- void operator()(const DeviceContext& context,
+ void operator()(const DeviceContext& context, const std::string pooltype,
const framework::Tensor& out_grad,
- const framework::Tensor& index,
- framework::LoDTensor* in_grad);
+ framework::LoDTensor* in_grad,
+ /* max pool has index */
+ const framework::Tensor* index = nullptr);
};
} // namespace math
diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc
index 4001b9a130..b28c16b13f 100644
--- a/paddle/fluid/operators/parallel_do_op.cc
+++ b/paddle/fluid/operators/parallel_do_op.cc
@@ -144,7 +144,12 @@ class ParallelDoOp : public framework::OperatorBase {
PADDLE_ENFORCE(scope.FindVar(param)->IsType(),
"Only support parameter type as LoDTensor");
auto &src = scope.FindVar(param)->Get();
- for (size_t i = 0; i < sub_scopes.size(); ++i) {
+
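+      // The first sub-scope shares the parameter tensor instead of copying it.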
+ auto *sub_scope0 = sub_scopes[0];
+ auto *dst0 = sub_scope0->Var(param)->GetMutable();
+ dst0->ShareDataWith(src);
+
+ for (size_t i = 1; i < sub_scopes.size(); ++i) {
auto &place = places[i];
auto *sub_scope = sub_scopes[i];
auto *dst = sub_scope->Var(param)->GetMutable();
diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
index 4d4e9fb909..47d9989bc8 100644
--- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
@@ -81,10 +81,10 @@ class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
This operator creates a multi-pass reader. A multi-pass reader
is used to yield data for several pass training continuously.
- It takes the the number of pass to run as one of its attributes
+ It takes the number of passes to run as one of its attributes
('pass_num'), and maintains a pass counter to record how many
- passes it has completed. When the underlying reader reach the EOF,
- the multi-pass reader checks whether it has completed training
+ passes it has completed. When the underlying reader reaches the
+ EOF, the multi-pass reader checks whether it has completed training
of the given number of pass. If not, the underlying reader will
be re-initialized and starts a new pass automatically.
)DOC");
diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc
index fdf3c06ef0..0752bd1bbd 100644
--- a/paddle/fluid/operators/send_op.cc
+++ b/paddle/fluid/operators/send_op.cc
@@ -72,7 +72,7 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
- VLOG(2) << "sending " << ins[i] << " to " << epmap[i];
+ VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
@@ -81,7 +81,7 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) {
- VLOG(2) << "batch barrier, ep: " << ep;
+ VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
diff --git a/paddle/fluid/operators/sequence_pool_op.h b/paddle/fluid/operators/sequence_pool_op.h
index 8706ff14aa..c58d677c92 100644
--- a/paddle/fluid/operators/sequence_pool_op.h
+++ b/paddle/fluid/operators/sequence_pool_op.h
@@ -23,12 +23,6 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
-template
-using EigenVector = framework::EigenVector;
-template
-using EigenMatrix = framework::EigenMatrix;
template
class SequencePoolKernel : public framework::OpKernel {
@@ -37,11 +31,13 @@ class SequencePoolKernel : public framework::OpKernel {
auto* in = context.Input("X");
auto* out = context.Output("Out");
std::string pooltype = context.Attr("pooltype");
+ Tensor* index = nullptr;
+ if (pooltype == "MAX") {
+ index = context.Output("MaxIndex");
+ }
auto dims = in->dims();
auto lod = in->lod();
- int64_t w = in->numel() / dims[0];
-
// InferShape by lod
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_GE(
@@ -50,45 +46,14 @@ class SequencePoolKernel : public framework::OpKernel {
"The first dimension of Input(X) must be large than batch size.");
dims[0] = lod[0].size() - 1;
out->Resize({dims});
-
- auto lod_level_0 = lod[0];
-
out->mutable_data(context.GetPlace());
- auto& dev_ctx = context.template device_context();
if (pooltype == "MAX") {
- math::MaxSeqPoolFunctor max_pool;
- auto* index = context.Output("MaxIndex");
index->Resize({dims});
index->mutable_data(context.GetPlace());
- max_pool(dev_ctx, *in, out, index);
- return;
- }
-
- auto& place =
- *context.template device_context().eigen_device();
- for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) {
- Tensor in_t = in->Slice(static_cast(lod_level_0[i]),
- static_cast(lod_level_0[i + 1]));
- Tensor out_t = out->Slice(i, i + 1);
- int64_t h = static_cast(lod_level_0[i + 1] - lod_level_0[i]);
- auto in_e = EigenMatrix::From(in_t, framework::make_ddim({h, w}));
- auto out_e = EigenVector::Flatten(out_t);
-
- if (pooltype == "AVERAGE") {
- out_e.device(place) = in_e.mean(Eigen::array({{0}}));
- } else if (pooltype == "SUM") {
- out_e.device(place) = in_e.sum(Eigen::array({{0}}));
- } else if (pooltype == "SQRT") {
- out_e.device(place) = in_e.sum(Eigen::array({{0}})) /
- std::sqrt(static_cast(h));
- } else if (pooltype == "LAST") {
- out_e.device(place) = in_e.chip(h - 1, 0);
- } else if (pooltype == "FIRST") {
- out_e.device(place) = in_e.chip(0, 0);
- } else {
- PADDLE_THROW("unsupported pooling pooltype");
- }
}
+    math::SequencePoolFunctor<DeviceContext, T> pool;
+    pool(context.template device_context<DeviceContext>(), pooltype, *in, out,
+         index);
}
};
@@ -96,58 +61,17 @@ template
class SequencePoolGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto* in = context.Input("X");
auto* out_g = context.Input(framework::GradVarName("Out"));
auto* in_g = context.Output(framework::GradVarName("X"));
std::string pooltype = context.Attr("pooltype");
-
- auto dims = in->dims();
- auto lod = in->lod()[0];
- int64_t w = in->numel() / dims[0];
-
- in_g->mutable_data(context.GetPlace());
- auto& dev_ctx = context.template device_context();
-
+ const Tensor* index = nullptr;
if (pooltype == "MAX") {
- math::MaxSeqPoolGradFunctor