(lower_bound));
+ return *this;
+ }
+
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 8aa6728a95..c762811dfc 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -2,20 +2,20 @@
## Motivation
-In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
-
+In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the gradient operators/expressions together with the chain rule. Every forward network needs a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
+
## Backward Operator Registry
-A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs and output gradients and then calculate its input gradients.
+A backward network is built up with several backward operators. Backward operators take forward operators' inputs outputs, and output gradients and then calculate its input gradients.
| | forward operator | backward operator
| ---------------------- | ---------------- |------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
- In most cases, there is a one-to-one correspondence between forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
+ In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
-For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro:
+For example, we have got a `mul_op`, and we can register its information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
@@ -27,17 +27,17 @@ REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
## Backward Opeartor Creating
-Given a certain forward operator, we can get its corresponding backward opeartor by calling:
+Given a certain forward operator, we can get its corresponding backward operator by calling:
```cpp
OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op);
-```
+```
The function `BuildGradOp` will sequentially execute following processes:
1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these are not necessary for gradient computing.
+2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing.
3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`.
@@ -49,31 +49,31 @@ A backward network is a series of backward operators. The main idea of building
In our design, the network itself is also a kind of operator. So the operators contained by a big network may be some small network.
-given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
+given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`, `InputGradients`.
1. Op
- when the input forward network is a Op, return its gradient Operator Immediately.
+ when the input forward network is an Op, return its gradient Operator Immediately.
2. NetOp
- when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to forward NetOp.
+ when the input forward network is a NetOp, it needs to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to the forward NetOp.
- **shared variable**. As illustrated in the pictures, two operator's `Output` `Gradient` will overwirte their shared input variable.
+ **shared variable**. As illustrated in the pictures, two operator's `Output` `Gradient` will overwrite their shared input variable.
- ![](./images/duplicate_op.png)
+ ![](./images/duplicate_op.png)
- 1. shared variable in two operators.
+ 1. Shared variable in operators.
- Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator replace the overwirte links.
+ Share variable between operators or same input variable used in multiple operators leads to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively and add a generic add operator replace the overwrite links.
- ![](images/duplicate_op2.png)
+ ![](images/duplicate_op2.png)
- 2. replace shared variable gradient with `Add` Operator
+ 2. Replace shared variable's gradient with `Add` operator.
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 85b7de7974..fc3d508553 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -283,5 +283,14 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list init_list) {
*this = make_ddim(init_list);
}
+
+DDim flatten_to_2d(const DDim& src, int num_col_dims) {
+ int rank = src.size();
+ return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
+ product(slice_ddim(src, num_col_dims, rank))});
+}
+
+DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index db30c52394..ca29e7e8c7 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -115,6 +115,12 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
+// Reshape a tensor to a matrix. The matrix's first dimension(column length)
+// will be the product of tensor's first `num_col_dims` dimensions.
+DDim flatten_to_2d(const DDim& src, int num_col_dims);
+
+DDim flatten_to_1d(const DDim& src);
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index 2d8d9ae10c..54bbeafcab 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -63,20 +63,35 @@ struct EigenTensor {
template
-struct EigenMatrix : public EigenTensor {};
+struct EigenMatrix : public EigenTensor {
+ static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+
+ static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+ int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+};
template
struct EigenVector : public EigenTensor {
// Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
};
diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc
index dc1957691b..bc4a2db32c 100644
--- a/paddle/framework/eigen_test.cc
+++ b/paddle/framework/eigen_test.cc
@@ -108,5 +108,24 @@ TEST(Eigen, Matrix) {
}
}
+TEST(Eigen, MatrixReshape) {
+ Tensor t;
+ float* p = t.mutable_data({2, 3, 6, 4}, platform::CPUPlace());
+ for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
+ p[i] = static_cast(i);
+ }
+
+ EigenMatrix::Type em = EigenMatrix::Reshape(t, 2);
+
+ ASSERT_EQ(2 * 3, em.dimension(0));
+ ASSERT_EQ(6 * 4, em.dimension(1));
+
+ for (int i = 0; i < 2 * 3; i++) {
+ for (int j = 0; j < 6 * 4; j++) {
+ ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
+ }
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/images/duplicate_op2.graffle b/paddle/framework/images/duplicate_op2.graffle
index 2b658085d6..ede3bca30a 100644
Binary files a/paddle/framework/images/duplicate_op2.graffle and b/paddle/framework/images/duplicate_op2.graffle differ
diff --git a/paddle/framework/images/duplicate_op2.png b/paddle/framework/images/duplicate_op2.png
index c5588015d1..4e872dc2ca 100644
Binary files a/paddle/framework/images/duplicate_op2.png and b/paddle/framework/images/duplicate_op2.png differ
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 790cfc4746..e1e122091f 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
CheckAllInputOutputSet();
}
+std::vector OperatorBase::InputVars() const {
+ std::vector ret_val;
+ for (auto& o : outputs_) {
+ ret_val.reserve(ret_val.size() + o.second.size());
+ ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+ }
+ return ret_val;
+}
+
std::vector OperatorBase::OutputVars(bool has_intermediate) const {
std::vector ret_val;
if (has_intermediate) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 9a98d4d3be..4600b06009 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -94,11 +94,14 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
+
//! Get a input with argument's name described in `op_proto`
std::string Input(const std::string& name) const;
//! Get a input which has multiple variables.
const std::vector& Inputs(const std::string& name) const;
+ std::vector InputVars() const;
+
//! Get a output with argument's name described in `op_proto`
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
@@ -311,9 +314,9 @@ class InferShapeContext {
}
template
- std::vector MultiOutput(const std::string& name) const {
+ std::vector MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
- std::vector res;
+ std::vector res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 643f875491..ce938b2143 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -43,6 +43,9 @@ class Tensor {
template
friend struct EigenTensor;
+ template
+ friend struct EigenMatrix;
+
template
friend struct EigenVector;
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 94f436294f..637f04ae00 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -148,5 +148,13 @@ inline Tensor& Tensor::Resize(const DDim& dims) {
inline const DDim& Tensor::dims() const { return dims_; }
+template
+inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+ Tensor res;
+ res.ShareDataWith(src);
+ res.Resize(flatten_to_2d(src.dims(), num_col_dims));
+ return res;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 7db38d5cae..55302ea471 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
}
#endif
}
+
+TEST(Tensor, ReshapeToMatrix) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ int* src_ptr = src.mutable_data({2, 3, 4, 9}, CPUPlace());
+ for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
+ src_ptr[i] = i;
+ }
+ Tensor res = ReshapeToMatrix(src, 2);
+ ASSERT_EQ(res.dims()[0], 2 * 3);
+ ASSERT_EQ(res.dims()[1], 4 * 9);
+}
\ No newline at end of file
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp
index 1ceaaaa206..f7a80e23e1 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.cpp
+++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp
@@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() {
const ImageConfig& conf = config_.inputs(0).image_conf();
imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
+ imageD_ = inputLayers_[0]->getOutput().getFrameDepth();
+
+ if (0 == imageD_) imageD_ = conf.img_size_z();
if (imageH_ == 0 && imageW_ == 0) {
imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
imageW_ = conf.img_size();
} else {
getOutput().setFrameHeight(imageH_);
getOutput().setFrameWidth(imageW_);
+ getOutput().setFrameDepth(imageD_);
}
- imgPixels_ = imageH_ * imageW_;
+ imgPixels_ = imageH_ * imageW_ * imageD_;
}
} // namespace paddle
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h
index 230bafc31d..e721d2d267 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.h
+++ b/paddle/gserver/layers/BatchNormBaseLayer.h
@@ -80,6 +80,7 @@ protected:
/// Height or width of input image feature.
/// Both of them are 1 if the input is fully-connected layer.
+ int imageD_;
int imageH_;
int imageW_;
/// Height * Width.
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index 44ba2c4b7d..49a9540c0b 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
}
void CudnnBatchNormLayer::reshape(int batchSize) {
- hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
+ hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
}
void CudnnBatchNormLayer::forward(PassType passType) {
@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
EPS,
batchSize,
channels_,
- imageH_,
+ imageH_ * imageD_,
imageW_);
}
}
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 6a91042f62..d7eee6eaf0 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -24,19 +24,21 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf();
+ size_t inD = img_conf.img_size_z();
size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
size_t inW = img_conf.img_size();
size_t inC = img_conf.channels();
+ inH = inH * inD;
inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4);
auto& reshape_conf = config_.reshape_conf();
- for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
- heightAxis_.push_back(reshape_conf.heightaxis(i));
+ for (int i = 0; i < reshape_conf.height_axis_size(); i++) {
+ heightAxis_.push_back(reshape_conf.height_axis(i));
}
- for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
- widthAxis_.push_back(reshape_conf.widthaxis(i));
+ for (int i = 0; i < reshape_conf.width_axis_size(); i++) {
+ widthAxis_.push_back(reshape_conf.width_axis(i));
}
createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
inDims_.setDim(0, batchSize);
-
+ int d = inputLayers_[0]->getOutput().getFrameDepth();
+ d = (d == 0 ? 1 : d);
int h = inputLayers_[0]->getOutput().getFrameHeight();
- if (h != 0) inDims_.setDim(2, h);
+ if (h != 0) inDims_.setDim(2, h * d);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt();
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index e0c14ad5b5..0e6be2df9e 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1703,6 +1703,55 @@ TEST(Layer, BatchNormalizationLayer) {
#endif
}
+void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
+ TestConfig config;
+ const int CHANNELS = 10;
+ const int IMG_SIZE = 16;
+ const int IMG_SIZE_Y = 8;
+ const int IMG_SIZE_Z = 8;
+ size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z;
+ config.layerConfig.set_type(type);
+ config.layerConfig.set_size(size);
+ config.layerConfig.set_active_type("sigmoid");
+ config.biasSize = CHANNELS;
+ config.inputDefs.push_back({INPUT_DATA,
+ "layer_0",
+ /* dim= */ size,
+ /* paraSize= */ CHANNELS});
+
+ config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+ config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+
+ LayerInputConfig* input = config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+
+ ImageConfig* img_conf = input->mutable_image_conf();
+ img_conf->set_channels(CHANNELS);
+ img_conf->set_img_size(IMG_SIZE);
+ img_conf->set_img_size_y(IMG_SIZE_Y);
+ img_conf->set_img_size_z(IMG_SIZE_Z);
+
+ testLayerGrad(config,
+ "batch_norm",
+ 64,
+ /* trans= */ trans,
+ useGpu,
+ /* useWeight */ true);
+}
+
+TEST(Layer, testBatchNorm3DLayer) {
+ testBatchNorm3DLayer("batch_norm", false, false);
+#ifndef PADDLE_ONLY_CPU
+ testBatchNorm3DLayer("batch_norm", false, true);
+ if (hl_get_cudnn_lib_version() >= int(4000)) {
+ testBatchNorm3DLayer("cudnn_batch_norm", false, true);
+ }
+#endif
+}
+
void testConvOperator(bool isDeconv) {
TestConfig config;
const int NUM_FILTERS = 16;
@@ -2019,10 +2068,10 @@ TEST(Layer, SwitchOrderLayer) {
img->set_img_size_y(16);
ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf();
- reshape->add_heightaxis(0);
- reshape->add_heightaxis(1);
- reshape->add_heightaxis(2);
- reshape->add_widthaxis(3);
+ reshape->add_height_axis(0);
+ reshape->add_height_axis(1);
+ reshape->add_height_axis(2);
+ reshape->add_width_axis(3);
// config softmax layer
config.layerConfig.set_type("switch_order");
diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc
index be956bf3b3..7d9d4fa519 100644
--- a/paddle/operators/identity_op.cc
+++ b/paddle/operators/identity_op.cc
@@ -18,17 +18,20 @@
namespace paddle {
namespace operators {
-// identity is a alias of scale op. This is also a example for creating a alias
-// operator.
+// The identity operator is an alias of the scale operator. This is also an
+// example for creating an alias for an existing operator.
template
class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IdentityOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
- AddInput("X", "input tensor of identity op");
- AddOutput("Out", "output tensor of identity op");
- AddComment("identity operator. Just a alias of scale op which scale = 1.0");
+ AddInput("X", "The input tensor of identity operator.");
+ AddOutput("Out", "The output tensor of identity operator.");
+ AddComment(R"DOC(
+The identity operator is an alias of the scale operator
+with the attribute scale fixed to 1.0.
+)DOC");
}
};
diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc
index f905600bb3..4f380388b1 100644
--- a/paddle/operators/math/im2col_test.cc
+++ b/paddle/operators/math/im2col_test.cc
@@ -74,7 +74,9 @@ void testIm2col() {
#ifndef PADDLE_ONLY_CPU
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
-#endif
+#else
+ PADDLE_THROW("no GPU support");
+#endif // PADDLE_ONLY_CPU
}
im2col(input, output_cfo, stride, stride, padding, padding, context);
im2col_ocf(input, output_ocf, stride, stride, padding, padding, context);
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 28a47cdff2..710a56a0e8 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("Y")->dims();
- PADDLE_ENFORCE_EQ(dim0.size(), 2,
- "input X(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("X"));
- PADDLE_ENFORCE_EQ(dim1.size(), 2,
- "input Y(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("Y"));
+ auto x_dims = ctx.Input("X")->dims();
+ auto y_dims = ctx.Input("Y")->dims();
+ int x_num_col_dims = Attr("x_num_col_dims");
+ int y_num_col_dims = Attr("y_num_col_dims");
+
+ PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
+ "The rank of input tensor X(%s) should be larger than "
+ "`mul_op`'s `x_num_col_dims`.",
+ ctx.op().Input("X"));
+ PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
+ "The rank of input tensor Y(%s) should be larger than "
+ "`mul_op`'s `y_num_col_dims`.",
+ ctx.op().Input("Y"));
+
+ auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+ auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
+
PADDLE_ENFORCE_EQ(
- dim0[1], dim1[0],
+ x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
- ctx.Output("Out")->Resize({dim0[0], dim1[1]});
+ ctx.Output("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
}
};
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
+ AddAttr(
+ "x_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
+ in that case, tensors will be reshaped to a matrix. The matrix's first
+ dimension(column length) will be the product of tensor's last
+ `num_col_dims` dimensions, and the matrix's second dimension(row length)
+ will be the product of tensor's first `rank - num_col_dims` dimensions.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
+ AddAttr(
+ "y_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
+ in that case, tensors will be reshaped to a matrix. Just like input `X`.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
AddComment(R"DOC(
Two Element Mul Operator.
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output(framework::GradVarName("X"));
auto *y_grad = ctx.Output(framework::GradVarName("Y"));
- PADDLE_ENFORCE(x_dims[0] == out_dims[0],
- "Out@GRAD M X N must equal to X dims 0, M ");
- PADDLE_ENFORCE(y_dims[1] == out_dims[1],
- "Out@GRAD M X N must equal to Y dims 1, N ");
+
+ auto x_mat_dims =
+ framework::flatten_to_2d(x_dims, Attr("x_num_col_dims"));
+ auto y_mat_dims =
+ framework::flatten_to_2d(y_dims, Attr("y_num_col_dims"));
+
+ PADDLE_ENFORCE_EQ(
+ x_mat_dims[0], out_dims[0],
+ "The first dimension of Out@GRAD must equal to the first dimension of "
+ "the first operand.");
+ PADDLE_ENFORCE_EQ(
+ y_mat_dims[1], out_dims[1],
+ "The second dimension of Out@GRAD must equal to the second "
+ "dimension of the second operand.");
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 05a79e13b3..3c01f868bd 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -1,7 +1,7 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
+ You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
@@ -31,13 +31,25 @@ template
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto* x = context.Input("X");
- auto* y = context.Input("Y");
- auto* z = context.Output("Out");
+ const Tensor* x = context.Input("X");
+ const Tensor* y = context.Input("Y");
+ Tensor* z = context.Output("Out");
+ const Tensor x_matrix =
+ x->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *x, context.template Attr("x_num_col_dims"))
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *y, context.template Attr("y_num_col_dims"))
+ : *y;
+
z->mutable_data(context.GetPlace());
auto* device_context =
const_cast(context.device_context_);
- math::matmul(*x, false, *y, false, 1, z, 0, device_context);
+ math::matmul(x_matrix, false, y_matrix, false, 1, z, 0,
+ device_context);
}
};
@@ -45,23 +57,39 @@ template
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
- auto* x = ctx.Input("X");
- auto* y = ctx.Input("Y");
- auto* dout = ctx.Input(framework::GradVarName("Out"));
+ int x_num_col_dims = ctx.template Attr("x_num_col_dims");
+ int y_num_col_dims = ctx.template Attr("y_num_col_dims");
+ const Tensor* x = ctx.Input("X");
+ const Tensor* y = ctx.Input("Y");
+ const Tensor x_matrix =
+ x->dims().size() > 2 ? framework::ReshapeToMatrix(*x, x_num_col_dims)
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2 ? framework::ReshapeToMatrix(*y, y_num_col_dims)
+ : *y;
+ const Tensor* dout = ctx.Input(framework::GradVarName("Out"));
- auto* dx = ctx.Output(framework::GradVarName("X"));
- auto* dy = ctx.Output(framework::GradVarName("Y"));
+ Tensor* dx = ctx.Output(framework::GradVarName("X"));
+ Tensor* dy = ctx.Output(framework::GradVarName("Y"));
auto* device_context =
const_cast(ctx.device_context_);
if (dx) {
dx->mutable_data(ctx.GetPlace());
+ Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix(
+ *dx, x_num_col_dims)
+ : *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
- math::matmul(*dout, false, *y, true, 1, dx, 0, device_context);
+ math::matmul(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
+ device_context);
}
if (dy) {
dy->mutable_data(ctx.GetPlace());
+ Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix(
+ *dy, y_num_col_dims)
+ : *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
- math::matmul(*x, true, *dout, false, 1, dy, 0, device_context);
+ math::matmul(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
+ device_context);
}
}
};
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 30b4b40431..fa8f0ff1a8 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("b")->dims();
-
- PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
- PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
- PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
- PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
- ctx.Output("Out")->Resize(ctx.Input("X")->dims());
+ auto x_dims = ctx.Input("X")->dims();
+ auto b_dims = ctx.Input("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
+ PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
+ ctx.Output("Out")->Resize(x_dims);
}
};
@@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
- auto dims0 = ctx.Input("X")->dims();
- auto dims1 = ctx.Input("b")->dims();
- PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
+ auto x_dims = ctx.Input("X")->dims();
+ auto b_dims = ctx.Input("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
auto *dx = ctx.Output(framework::GradVarName("X"));
auto *db = ctx.Output(framework::GradVarName("b"));
- if (dx) dx->Resize(dims0);
- if (db) db->Resize(dims1);
+ if (dx) dx->Resize(x_dims);
+ if (db) db->Resize(b_dims);
}
};
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 4e926d9f29..35774b9409 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,10 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output("Out");
out->mutable_data(context.GetPlace());
-
- auto input = EigenMatrix::From(*context.Input("X"));
- auto bias = EigenVector::From(*context.Input("b"));
- auto output = EigenMatrix::From(*out);
+ int num_col_dims = context.Input("X")->dims().size() -
+ context.Input("b")->dims().size();
+ auto input =
+ EigenMatrix::Reshape(*context.Input("X"), num_col_dims);
+ auto bias = EigenVector::Flatten(*context.Input("b"));
+ auto output = EigenMatrix::Reshape(*out, num_col_dims);
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
@@ -54,12 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel {
auto* dout = context.Input(framework::GradVarName("Out"));
auto* dx = context.Output(framework::GradVarName("X"));
auto* db = context.Output(framework::GradVarName("b"));
+ int num_col_dims = context.Input("X")->dims().size() -
+ context.Input("b")->dims().size();
- auto out_grad = EigenMatrix::From(*dout);
+ auto out_grad = EigenMatrix::Reshape(*dout, num_col_dims);
auto place = context.GetEigenDevice();
+
if (dx) {
dx->mutable_data(context.GetPlace());
- EigenMatrix::From(*dx).device(place) = out_grad;
+ EigenMatrix::Reshape(*dx, num_col_dims).device(place) = out_grad;
}
if (db) {
diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc
index 3d82b34582..ea991f683d 100644
--- a/paddle/operators/scale_op.cc
+++ b/paddle/operators/scale_op.cc
@@ -44,11 +44,13 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
The equation is: Out = scale*X
)DOC");
- AddAttr("scale", "scale of scale operator.").SetDefault(1.0);
+ AddAttr("scale", "The scaling factor of the scale operator.")
+ .SetDefault(1.0);
}
};
-// Scale Op's gradient is scale op, too.
+// The operator to calculate gradients of a scale operator is just the scale
+// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
template
class ScaleGradOp : public NetOp {
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 7d062ad67c..7166b2f60b 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -51,7 +51,7 @@ the other dimensions in the K-dimensional vector input. Then the ratio of the
exponential of the given dimension and the sum of exponential values of all
the other dimensions is the output of the softmax operator.
-For each row `i` and each column `j` in X, we have:
+For each row `i` and each column `j` in input X, we have:
Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j]))
)DOC");
@@ -64,14 +64,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
- "Input(Y@GRAD) should not be null");
- PADDLE_ENFORCE(ctx.Input("Y")->dims() ==
- ctx.Input(framework::GradVarName("Y"))->dims(),
- "the shape of Input(0) and Input(1) should be the same");
+ "Input(Y@GRAD) should be not null.");
+ PADDLE_ENFORCE_EQ(ctx.Input("Y")->dims(),
+ ctx.Input(framework::GradVarName("Y"))->dims(),
+ "Input(Y) and its gradients should have a same shape.");
+
ctx.Output(framework::GradVarName("X"))
- ->Resize(ctx.Input("Y")->dims());
+ ->Resize(ctx.Input("X")->dims());
}
};
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 4fa6b59540..8a3a5ab927 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -28,12 +28,12 @@ template
class SoftmaxKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto input = context.Input("X");
- auto output = context.Output("Y");
- output->mutable_data(context.GetPlace());
+ auto X = context.Input("X");
+ auto Y = context.Output("Y");
+ Y->mutable_data(context.GetPlace());
- auto logits = EigenMatrix::From(*input);
- auto softmax = EigenMatrix::From(*output);
+ auto logits = EigenMatrix::From(*X);
+ auto softmax = EigenMatrix::From(*Y);
const int kBatchDim = 0;
const int kClassDim = 1;
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
new file mode 100644
index 0000000000..5805826ee8
--- /dev/null
+++ b/paddle/operators/sum_op.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sum_op.h"
+#include
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class SumOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto ins = ctx.MultiInput("X");
+ auto *out = ctx.Output("Out");
+ int N = ins.size();
+
+ auto in_dim = ins[0]->dims();
+
+ PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+ for (int i = 1; i < N; i++) {
+ auto dim = ins[i]->dims();
+ PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
+ }
+ out->Resize(in_dim);
+ }
+};
+
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+ AddOutput("Out", "the output tensor of sum operator.");
+ AddComment(R"DOC(
+ Sum the input tensors.
+ )DOC");
+ }
+};
+
+class SumGradOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto outputs = ctx.MultiOutput(framework::GradVarName("X"));
+ auto dims = ctx.Input(framework::GradVarName("Out"))->dims();
+ for (auto output : outputs) {
+ output->Resize(dims);
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel);
+REGISTER_OP_CPU_KERNEL(sum_grad,
+ ops::SumGradKernel);
diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu
new file mode 100644
index 0000000000..a465cf3659
--- /dev/null
+++ b/paddle/operators/sum_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/sum_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel);
+REGISTER_OP_GPU_KERNEL(sum_grad,
+ ops::SumGradKernel);
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
new file mode 100644
index 0000000000..0b1e9ebaa3
--- /dev/null
+++ b/paddle/operators/sum_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template
+using EigenVector = framework::EigenVector;
+
+template
+class SumKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+ auto ins = context.MultiInput("X");
+ auto* out = context.Output("Out");
+ out->mutable_data(context.GetPlace());
+
+ auto place = context.GetEigenDevice();
+ auto result = EigenVector::Flatten(*out);
+
+ int N = ins.size();
+ auto in = EigenVector::Flatten(*(ins[0]));
+ result.device(place) = in;
+ for (int i = 1; i < N; i++) {
+ auto in = EigenVector::Flatten(*(ins[i]));
+ result.device(place) = result + in;
+ }
+ }
+};
+
+template
+class SumGradKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+ auto* input = context.Input(framework::GradVarName("Out"));
+ auto outs = context.MultiOutput(framework::GradVarName("X"));
+ for (auto out : outs) {
+ out->mutable_data(context.GetPlace());
+ }
+
+ auto place = context.GetEigenDevice();
+ auto in = EigenVector::Flatten(*input);
+ for (auto out : outs) {
+ auto result = EigenVector::Flatten(*out);
+ result.device(place) = in;
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc
new file mode 100644
index 0000000000..38d2f0a09a
--- /dev/null
+++ b/paddle/operators/top_k_op.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/top_k_op.h"
+
+namespace paddle {
+namespace operators {
+
+class TopkOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+ "Input of TopkOP must be initialized.");
+ auto *input = ctx.Input("X");
+ const int k = static_cast(ctx.Attr("k"));
+
+ PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
+ PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape");
+ PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
+ "input must have >= k columns");
+
+ framework::DDim dims = input->dims();
+ dims[dims.size() - 1] = k;
+ ctx.Output("Out")->Resize(dims);
+ ctx.Output("Indices")->Resize(dims);
+ }
+};
+
+class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "The input of Topk op");
+ AddOutput("Out", "The output tensor of Topk op");
+ AddOutput("Indices", "The indices of Topk elements of input");
+ AddComment(
+ R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].
+
+ For matrices, computes the top k entries in each row. )DOC");
+ AddAttr("k",
+ "Number of top elements to look for along the last "
+ "dimension (along each row for matrices).")
+ .SetDefault(1);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
+REGISTER_OP_CPU_KERNEL(top_k,
+ ops::TopkKernel);
diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu
new file mode 100644
index 0000000000..afe4d149c5
--- /dev/null
+++ b/paddle/operators/top_k_op.cu
@@ -0,0 +1,318 @@
+/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/assert.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template
+struct Pair {
+ __device__ __forceinline__ Pair() {}
+ __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
+
+ __device__ __forceinline__ void set(T value, int id) {
+ v = value;
+ id = id;
+ }
+
+ __device__ __forceinline__ void operator=(const Pair& in) {
+ v = in.v;
+ id = in.id;
+ }
+
+ __device__ __forceinline__ bool operator<(const T value) const {
+ return (v < value);
+ }
+
+ __device__ __forceinline__ bool operator<(const Pair& in) const {
+ return (v < in.v) || ((v == in.v) && (id > in.id));
+ }
+
+ __device__ __forceinline__ bool operator>(const Pair& in) const {
+ return (v > in.v) || ((v == in.v) && (id < in.id));
+ }
+
+ T v;
+ int id;
+};
+
+template
+__device__ __forceinline__ void AddTo(Pair topk[], const Pair& p,
+ int beam_size) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
+template
+__device__ __forceinline__ void AddTo(Pair topk[], const Pair& p) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* src, int idx,
+ int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+ Pair tmp(src[idx], idx);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* src, int idx,
+ int dim, const Pair& max,
+ int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+ Pair tmp(src[idx], idx);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col,
+ int idx, int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+ Pair tmp(val[idx], col[idx]);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col,
+ int idx, int dim, const Pair& max,
+ int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+ Pair tmp(val[idx], col[idx]);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam,
+ int beam_size, const T* src,
+ bool& firstStep, bool& is_empty,
+ Pair& max, int dim,
+ const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+ GetTopK(topk, src, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+ GetTopK(topk + MaxLength - beam, src, tid, dim, max,
+ length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
+template
+__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam,
+ int beam_size, const T* val,
+ int* col, bool& firstStep,
+ bool& is_empty, Pair& max,
+ int dim, const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+ GetTopK(topk, val, col, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+ GetTopK(topk + MaxLength - beam, val, col, tid, dim, max,
+ length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
+template
+__device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid,
+ Pair topk[], T** topVal,
+ int** topIds, int& beam, int& k,
+ const int tid, const int warp) {
+ while (true) {
+ __syncthreads();
+ if (tid < BlockSize / 2) {
+ if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
+ maxid[tid] = tid + BlockSize / 2;
+ } else {
+ maxid[tid] = tid;
+ }
+ }
+ __syncthreads();
+ for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
+ if (tid < stride) {
+ if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
+ maxid[tid] = maxid[tid + stride];
+ }
+ }
+ __syncthreads();
+ }
+ __syncthreads();
+
+ if (tid == 0) {
+ **topVal = sh_topk[maxid[0]].v;
+ **topIds = sh_topk[maxid[0]].id;
+ (*topVal)++;
+ (*topIds)++;
+ }
+ if (tid == maxid[0]) beam++;
+ if (--k == 0) break;
+ __syncthreads();
+
+ if (tid == maxid[0]) {
+ if (beam < MaxLength) {
+ sh_topk[tid] = topk[beam];
+ }
+ }
+ if (maxid[0] / 32 == warp) {
+ if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break;
+ }
+ }
+}
+
+/**
+ * Each block compute one sample.
+ * In a block:
+ * 1. every thread get top MaxLength value;
+ * 2. merge to sh_topk, block reduce and get max value;
+ * 3. go to the second setp, until one thread's topk value is null;
+ * 4. go to the first setp, until get the topk value.
+ */
+template
+__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
+ const T* src, int lds, int dim, int k) {
+ __shared__ Pair sh_topk[BlockSize];
+ __shared__ int maxid[BlockSize / 2];
+ const int tid = threadIdx.x;
+ const int warp = threadIdx.x / 32;
+ output += blockIdx.x * output_stride;
+ indices += blockIdx.x * k;
+
+ Pair topk[MaxLength];
+ int beam = MaxLength;
+ Pair max;
+ bool is_empty = false;
+ bool firststep = true;
+
+ for (int k = 0; k < MaxLength; k++) {
+ topk[k].set(-INFINITY, -1);
+ }
+ while (k) {
+ ThreadGetTopK(topk, beam, k,
+ src + blockIdx.x * lds, firststep,
+ is_empty, max, dim, tid);
+
+ sh_topk[tid] = topk[0];
+ BlockReduce(sh_topk, maxid, topk, &output,
+ &indices, beam, k, tid, warp);
+ }
+}
+
+template
+class TopkOpCUDAKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "It must use GPUPlace.");
+ auto* input = ctx.Input("X");
+ auto* output = ctx.Output("Out");
+ auto* indices = ctx.Output("Indices");
+ size_t k = static_cast(ctx.Attr("k"));
+
+ const T* input_data = input->data();
+
+ T* output_data = output->mutable_data(ctx.GetPlace());
+ // FIXME(typhoonzero): data is always converted to type T?
+ int* indices_data = indices->mutable_data(ctx.GetPlace());
+
+ size_t input_height = input->dims()[0];
+ size_t input_width = input->dims()[1];
+ if (k > input_width) k = input_width;
+
+ // NOTE: pass lds and dim same to input width.
+ // NOTE: old matrix implementation of stride is different to eigen.
+ // TODO(typhoonzero): launch kernel on specified stream.
+ // TODO(typhoonzero): refine this kernel.
+ dim3 threads(256, 1);
+ dim3 grid(input_height, 1);
+
+ KeMatrixTopK<<>>(
+ output_data, output->dims()[1], indices_data, input_data, input_width,
+ input_width, int(k));
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+REGISTER_OP_GPU_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel);
diff --git a/paddle/operators/top_k_op.h b/paddle/operators/top_k_op.h
new file mode 100644
index 0000000000..ef66acc1d5
--- /dev/null
+++ b/paddle/operators/top_k_op.h
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include
+#include
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template
+using EigenMatrix = framework::EigenMatrix;
+
+template
+class TopkKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ // Get the top k elements of each row of input tensor
+ // FIXME: only deal with matrix(2d tensor).
+ auto* input = ctx.Input("X");
+ auto* output = ctx.Output("Out");
+ auto* indices = ctx.Output("Indices");
+ // k is determined by Attr
+ const size_t k = static_cast(ctx.Attr("k"));
+
+ T* output_data = output->mutable_data(ctx.GetPlace());
+ T* indices_data = indices->mutable_data(ctx.GetPlace());
+
+ auto eg_input = EigenMatrix::From(*input);
+
+ // reshape input to a flattern matrix(like flat_inner_dims)
+ framework::DDim inputdims = input->dims();
+ const size_t row = framework::product(
+ framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+ const size_t col = inputdims[inputdims.size() - 1];
+ Eigen::DSizes flat2dims(row, col);
+ // NOTE: eigen shape doesn't affect paddle tensor.
+ eg_input.reshape(flat2dims);
+
+ for (size_t i = 0; i < row; i++) {
+ std::vector> vec;
+ for (size_t j = 0; j < col; j++) {
+ vec.push_back(std::pair(eg_input(i, j), j));
+ }
+
+ std::partial_sort(
+ vec.begin(), vec.begin() + k, vec.end(),
+ [](const std::pair