(lower_bound));
+ return *this;
+ }
+
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 8aa6728a95..c762811dfc 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -2,20 +2,20 @@
## Motivation
-In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
-
+In neural networks, the backpropagation algorithm follows the chain rule, so we need to compose the gradient operators/expressions together with the chain rule. Every forward network needs a backward network to construct the full computation graph; each operator/expression's backward pass is generated with respect to its forward pass.
+
## Backward Operator Registry
-A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs and output gradients and then calculate its input gradients.
+A backward network is built up with several backward operators. A backward operator takes the forward operator's inputs, outputs, and output gradients, and then calculates its input gradients.
| | forward operator | backward operator
| ---------------------- | ---------------- |------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
- In most cases, there is a one-to-one correspondence between forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
+ In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
-For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro:
+For example, we have got a `mul_op`, and we can register its information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
@@ -27,17 +27,17 @@ REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
## Backward Opeartor Creating
-Given a certain forward operator, we can get its corresponding backward opeartor by calling:
+Given a certain forward operator, we can get its corresponding backward operator by calling:
```cpp
OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op);
-```
+```
The function `BuildGradOp` will sequentially execute following processes:
1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these are not necessary for gradient computing.
+2. Build two maps named `inputs` and `outputs` to temporarily store the backward operator's inputs and outputs. Copy the forward operator's `inputs_` and `outputs_` to map `inputs`, except those that are not necessary for gradient computing.
3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`.
@@ -49,31 +49,31 @@ A backward network is a series of backward operators. The main idea of building
In our design, the network itself is also a kind of operator. So the operators contained by a big network may be some small network.
-given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
+Given a forward network, it generates the backward network. We only care about the gradients: `OutputGradients` and `InputGradients`.
1. Op
- when the input forward network is a Op, return its gradient Operator Immediately.
+ When the input forward network is an Op, return its gradient operator immediately.
2. NetOp
- when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to forward NetOp.
+ when the input forward network is a NetOp, it needs to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to the forward NetOp.
- **shared variable**. As illustrated in the pictures, two operator's `Output` `Gradient` will overwirte their shared input variable.
+ **shared variable**. As illustrated in the pictures, two operators' `Output` `Gradient` will overwrite their shared input variable.
- ![](./images/duplicate_op.png)
+ ![](./images/duplicate_op.png)
- 1. shared variable in two operators.
+ 1. Shared variable in operators.
- Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator replace the overwirte links.
+ Sharing a variable between operators, or using the same input variable in multiple operators, leads to a duplicate gradient variable. As the demo above shows, we need to rename the gradient names recursively and add a generic add operator to replace the overwrite links.
- ![](images/duplicate_op2.png)
+ ![](images/duplicate_op2.png)
- 2. replace shared variable gradient with `Add` Operator
+ 2. Replace shared variable's gradient with `Add` operator.
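Note on the duplicate-gradient handling described in backward.md: the renaming step can be sketched as below. This is a minimal illustration under stated assumptions, not Paddle's implementation. `OpDesc`, `DeduplicateGradients`, and the `@RENAME@` suffix are hypothetical names chosen for the example, and the accumulating `add` op is simply appended at the end of the list rather than placed right after the last writer.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified operator description (not Paddle's OperatorBase).
struct OpDesc {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Renames every gradient variable that is written by more than one op and
// appends an "add" op that accumulates the renamed copies into the original
// variable name.
std::vector<OpDesc> DeduplicateGradients(std::vector<OpDesc> ops) {
  // For each output name, record every (op index, output index) that writes it.
  std::map<std::string, std::vector<std::pair<std::size_t, std::size_t>>>
      writers;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    for (std::size_t j = 0; j < ops[i].outputs.size(); ++j) {
      writers[ops[i].outputs[j]].push_back({i, j});
    }
  }

  for (const auto& entry : writers) {
    if (entry.second.size() < 2) continue;  // written once: nothing to do
    OpDesc add{"add", {}, {entry.first}};
    for (std::size_t k = 0; k < entry.second.size(); ++k) {
      const std::string renamed =
          entry.first + "@RENAME@" + std::to_string(k);
      ops[entry.second[k].first].outputs[entry.second[k].second] = renamed;
      add.inputs.push_back(renamed);
    }
    assert(add.inputs.size() == entry.second.size());
    ops.push_back(add);
  }
  return ops;
}
```

With two backward ops that both write `X@GRAD`, the sketch rewrites their outputs to `X@GRAD@RENAME@0` and `X@GRAD@RENAME@1` and appends one `add` op whose output is `X@GRAD`.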
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 85b7de7974..fc3d508553 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -283,5 +283,14 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list init_list) {
*this = make_ddim(init_list);
}
+
+DDim flatten_to_2d(const DDim& src, int num_col_dims) {
+ int rank = src.size();
+ return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
+ product(slice_ddim(src, num_col_dims, rank))});
+}
+
+DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
+
} // namespace framework
} // namespace paddle
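Note on `flatten_to_2d` and `flatten_to_1d`: the shape arithmetic can be checked in isolation with the sketch below. It assumes a plain `std::vector<int64_t>` in place of `DDim`; `Dims`, `ProductOfRange`, `FlattenTo2D`, and `FlattenTo1D` are hypothetical stand-ins, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for DDim: a plain vector of extents.
using Dims = std::vector<int64_t>;

// Product of dims[begin, end); an empty range yields 1.
int64_t ProductOfRange(const Dims& dims, int begin, int end) {
  assert(0 <= begin && begin <= end &&
         end <= static_cast<int>(dims.size()));
  return std::accumulate(dims.begin() + begin, dims.begin() + end,
                         int64_t{1}, std::multiplies<int64_t>());
}

// Same arithmetic as flatten_to_2d: the first extent is the product of the
// leading num_col_dims extents, the second is the product of the rest.
Dims FlattenTo2D(const Dims& src, int num_col_dims) {
  const int rank = static_cast<int>(src.size());
  return {ProductOfRange(src, 0, num_col_dims),
          ProductOfRange(src, num_col_dims, rank)};
}

// Same arithmetic as flatten_to_1d: a single extent holding all elements.
Dims FlattenTo1D(const Dims& src) {
  return {ProductOfRange(src, 0, static_cast<int>(src.size()))};
}
```

For a shape `{2, 3, 4, 9}` and `num_col_dims = 2`, the sketch yields `{6, 36}`, which is the result the `ReshapeToMatrix` test in this patch asserts.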
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index db30c52394..ca29e7e8c7 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -115,6 +115,12 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
+// Reshape a tensor to a matrix. The matrix's first dimension(column length)
+// will be the product of tensor's first `num_col_dims` dimensions.
+DDim flatten_to_2d(const DDim& src, int num_col_dims);
+
+DDim flatten_to_1d(const DDim& src);
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index 2d8d9ae10c..54bbeafcab 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -63,20 +63,35 @@ struct EigenTensor {
template
-struct EigenMatrix : public EigenTensor {};
+struct EigenMatrix : public EigenTensor {
+ static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+
+ static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+ int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+};
template
struct EigenVector : public EigenTensor {
// Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
};
diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc
index dc1957691b..bc4a2db32c 100644
--- a/paddle/framework/eigen_test.cc
+++ b/paddle/framework/eigen_test.cc
@@ -108,5 +108,24 @@ TEST(Eigen, Matrix) {
}
}
+TEST(Eigen, MatrixReshape) {
+ Tensor t;
+ float* p = t.mutable_data<float>({2, 3, 6, 4}, platform::CPUPlace());
+ for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
+ p[i] = static_cast<float>(i);
+ }
+
+ EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2);
+
+ ASSERT_EQ(2 * 3, em.dimension(0));
+ ASSERT_EQ(6 * 4, em.dimension(1));
+
+ for (int i = 0; i < 2 * 3; i++) {
+ for (int j = 0; j < 6 * 4; j++) {
+ ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
+ }
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/images/duplicate_op2.graffle b/paddle/framework/images/duplicate_op2.graffle
index 2b658085d6..ede3bca30a 100644
Binary files a/paddle/framework/images/duplicate_op2.graffle and b/paddle/framework/images/duplicate_op2.graffle differ
diff --git a/paddle/framework/images/duplicate_op2.png b/paddle/framework/images/duplicate_op2.png
index c5588015d1..4e872dc2ca 100644
Binary files a/paddle/framework/images/duplicate_op2.png and b/paddle/framework/images/duplicate_op2.png differ
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 790cfc4746..e1e122091f 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
CheckAllInputOutputSet();
}
+std::vector OperatorBase::InputVars() const {
+ std::vector ret_val;
+ for (auto& o : inputs_) {
+ ret_val.reserve(ret_val.size() + o.second.size());
+ ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+ }
+ return ret_val;
+}
+
std::vector OperatorBase::OutputVars(bool has_intermediate) const {
std::vector ret_val;
if (has_intermediate) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 9a98d4d3be..4600b06009 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -94,11 +94,14 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
+
//! Get a input with argument's name described in `op_proto`
std::string Input(const std::string& name) const;
//! Get a input which has multiple variables.
const std::vector& Inputs(const std::string& name) const;
+ std::vector InputVars() const;
+
//! Get a output with argument's name described in `op_proto`
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
@@ -311,9 +314,9 @@ class InferShapeContext {
}
template
- std::vector MultiOutput(const std::string& name) const {
+ std::vector MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
- std::vector res;
+ std::vector res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 643f875491..ce938b2143 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -43,6 +43,9 @@ class Tensor {
template
friend struct EigenTensor;
+ template
+ friend struct EigenMatrix;
+
template
friend struct EigenVector;
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 94f436294f..637f04ae00 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -148,5 +148,13 @@ inline Tensor& Tensor::Resize(const DDim& dims) {
inline const DDim& Tensor::dims() const { return dims_; }
+template
+inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+ Tensor res;
+ res.ShareDataWith(src);
+ res.Resize(flatten_to_2d(src.dims(), num_col_dims));
+ return res;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 7db38d5cae..55302ea471 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
}
#endif
}
+
+TEST(Tensor, ReshapeToMatrix) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+ for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
+ src_ptr[i] = i;
+ }
+ Tensor res = ReshapeToMatrix<int>(src, 2);
+ ASSERT_EQ(res.dims()[0], 2 * 3);
+ ASSERT_EQ(res.dims()[1], 4 * 9);
+}
\ No newline at end of file
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp
index 1ceaaaa206..f7a80e23e1 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.cpp
+++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp
@@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() {
const ImageConfig& conf = config_.inputs(0).image_conf();
imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
+ imageD_ = inputLayers_[0]->getOutput().getFrameDepth();
+
+ if (0 == imageD_) imageD_ = conf.img_size_z();
if (imageH_ == 0 && imageW_ == 0) {
imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
imageW_ = conf.img_size();
} else {
getOutput().setFrameHeight(imageH_);
getOutput().setFrameWidth(imageW_);
+ getOutput().setFrameDepth(imageD_);
}
- imgPixels_ = imageH_ * imageW_;
+ imgPixels_ = imageH_ * imageW_ * imageD_;
}
} // namespace paddle
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h
index 230bafc31d..e721d2d267 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.h
+++ b/paddle/gserver/layers/BatchNormBaseLayer.h
@@ -80,6 +80,7 @@ protected:
/// Height or width of input image feature.
/// Both of them are 1 if the input is fully-connected layer.
+ int imageD_;
int imageH_;
int imageW_;
/// Height * Width.
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index 44ba2c4b7d..49a9540c0b 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
}
void CudnnBatchNormLayer::reshape(int batchSize) {
- hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
+ hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
}
void CudnnBatchNormLayer::forward(PassType passType) {
@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
EPS,
batchSize,
channels_,
- imageH_,
+ imageH_ * imageD_,
imageW_);
}
}
diff --git a/paddle/gserver/layers/DetectionOutputLayer.cpp b/paddle/gserver/layers/DetectionOutputLayer.cpp
index 8ab838e191..0cf0a92bf4 100644
--- a/paddle/gserver/layers/DetectionOutputLayer.cpp
+++ b/paddle/gserver/layers/DetectionOutputLayer.cpp
@@ -139,7 +139,13 @@ void DetectionOutputLayer::forward(PassType passType) {
allDecodedBBoxes,
&allIndices);
- resetOutput(numKept, 7);
+ if (numKept > 0) {
+ resetOutput(numKept, 7);
+ } else {
+ MatrixPtr outV = getOutputValue();
+ outV = NULL;
+ return;
+ }
MatrixPtr outV = getOutputValue();
getDetectionOutput(confBuffer_->getData(),
numKept,
diff --git a/paddle/gserver/layers/DetectionUtil.cpp b/paddle/gserver/layers/DetectionUtil.cpp
index 3e61adc66e..d83674f45a 100644
--- a/paddle/gserver/layers/DetectionUtil.cpp
+++ b/paddle/gserver/layers/DetectionUtil.cpp
@@ -469,7 +469,7 @@ size_t getDetectionIndices(
const size_t numClasses,
const size_t backgroundId,
const size_t batchSize,
- const size_t confThreshold,
+ const real confThreshold,
const size_t nmsTopK,
const real nmsThreshold,
const size_t keepTopK,
diff --git a/paddle/gserver/layers/DetectionUtil.h b/paddle/gserver/layers/DetectionUtil.h
index fe4f9f075e..641ed873b4 100644
--- a/paddle/gserver/layers/DetectionUtil.h
+++ b/paddle/gserver/layers/DetectionUtil.h
@@ -275,7 +275,7 @@ size_t getDetectionIndices(
const size_t numClasses,
const size_t backgroundId,
const size_t batchSize,
- const size_t confThreshold,
+ const real confThreshold,
const size_t nmsTopK,
const real nmsThreshold,
const size_t keepTopK,
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 6a91042f62..d7eee6eaf0 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -24,19 +24,21 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf();
+ size_t inD = img_conf.img_size_z();
size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
size_t inW = img_conf.img_size();
size_t inC = img_conf.channels();
+ inH = inH * inD;
inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4);
auto& reshape_conf = config_.reshape_conf();
- for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
- heightAxis_.push_back(reshape_conf.heightaxis(i));
+ for (int i = 0; i < reshape_conf.height_axis_size(); i++) {
+ heightAxis_.push_back(reshape_conf.height_axis(i));
}
- for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
- widthAxis_.push_back(reshape_conf.widthaxis(i));
+ for (int i = 0; i < reshape_conf.width_axis_size(); i++) {
+ widthAxis_.push_back(reshape_conf.width_axis(i));
}
createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
inDims_.setDim(0, batchSize);
-
+ int d = inputLayers_[0]->getOutput().getFrameDepth();
+ d = (d == 0 ? 1 : d);
int h = inputLayers_[0]->getOutput().getFrameHeight();
- if (h != 0) inDims_.setDim(2, h);
+ if (h != 0) inDims_.setDim(2, h * d);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt();
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index e0c14ad5b5..0e6be2df9e 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1703,6 +1703,55 @@ TEST(Layer, BatchNormalizationLayer) {
#endif
}
+void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
+ TestConfig config;
+ const int CHANNELS = 10;
+ const int IMG_SIZE = 16;
+ const int IMG_SIZE_Y = 8;
+ const int IMG_SIZE_Z = 8;
+ size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z;
+ config.layerConfig.set_type(type);
+ config.layerConfig.set_size(size);
+ config.layerConfig.set_active_type("sigmoid");
+ config.biasSize = CHANNELS;
+ config.inputDefs.push_back({INPUT_DATA,
+ "layer_0",
+ /* dim= */ size,
+ /* paraSize= */ CHANNELS});
+
+ config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+ config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+
+ LayerInputConfig* input = config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+
+ ImageConfig* img_conf = input->mutable_image_conf();
+ img_conf->set_channels(CHANNELS);
+ img_conf->set_img_size(IMG_SIZE);
+ img_conf->set_img_size_y(IMG_SIZE_Y);
+ img_conf->set_img_size_z(IMG_SIZE_Z);
+
+ testLayerGrad(config,
+ "batch_norm",
+ 64,
+ /* trans= */ trans,
+ useGpu,
+ /* useWeight */ true);
+}
+
+TEST(Layer, testBatchNorm3DLayer) {
+ testBatchNorm3DLayer("batch_norm", false, false);
+#ifndef PADDLE_ONLY_CPU
+ testBatchNorm3DLayer("batch_norm", false, true);
+ if (hl_get_cudnn_lib_version() >= int(4000)) {
+ testBatchNorm3DLayer("cudnn_batch_norm", false, true);
+ }
+#endif
+}
+
void testConvOperator(bool isDeconv) {
TestConfig config;
const int NUM_FILTERS = 16;
@@ -2019,10 +2068,10 @@ TEST(Layer, SwitchOrderLayer) {
img->set_img_size_y(16);
ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf();
- reshape->add_heightaxis(0);
- reshape->add_heightaxis(1);
- reshape->add_heightaxis(2);
- reshape->add_widthaxis(3);
+ reshape->add_height_axis(0);
+ reshape->add_height_axis(1);
+ reshape->add_height_axis(2);
+ reshape->add_width_axis(3);
// config softmax layer
config.layerConfig.set_type("switch_order");
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 28a47cdff2..710a56a0e8 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("Y")->dims();
- PADDLE_ENFORCE_EQ(dim0.size(), 2,
- "input X(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("X"));
- PADDLE_ENFORCE_EQ(dim1.size(), 2,
- "input Y(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("Y"));
+ auto x_dims = ctx.Input("X")->dims();
+ auto y_dims = ctx.Input("Y")->dims();
+ int x_num_col_dims = Attr("x_num_col_dims");
+ int y_num_col_dims = Attr("y_num_col_dims");
+
+ PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
+ "The rank of input tensor X(%s) should be larger than "
+ "`mul_op`'s `x_num_col_dims`.",
+ ctx.op().Input("X"));
+ PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
+ "The rank of input tensor Y(%s) should be larger than "
+ "`mul_op`'s `y_num_col_dims`.",
+ ctx.op().Input("Y"));
+
+ auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+ auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
+
PADDLE_ENFORCE_EQ(
- dim0[1], dim1[0],
+ x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
- ctx.Output("Out")->Resize({dim0[0], dim1[1]});
+ ctx.Output("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
}
};
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
+ AddAttr(
+ "x_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `X`.
+ In that case, the tensor will be reshaped to a matrix. The matrix's first
+ dimension (column length) will be the product of the tensor's first
+ `num_col_dims` dimensions, and the matrix's second dimension (row length)
+ will be the product of the tensor's last `rank - num_col_dims` dimensions.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
+ AddAttr(
+ "y_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `Y`.
+ In that case, the tensor will be reshaped to a matrix, just like input `X`.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
AddComment(R"DOC(
Two Element Mul Operator.
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output(framework::GradVarName("X"));
auto *y_grad = ctx.Output(framework::GradVarName("Y"));
- PADDLE_ENFORCE(x_dims[0] == out_dims[0],
- "Out@GRAD M X N must equal to X dims 0, M ");
- PADDLE_ENFORCE(y_dims[1] == out_dims[1],
- "Out@GRAD M X N must equal to Y dims 1, N ");
+
+ auto x_mat_dims =
+ framework::flatten_to_2d(x_dims, Attr("x_num_col_dims"));
+ auto y_mat_dims =
+ framework::flatten_to_2d(y_dims, Attr("y_num_col_dims"));
+
+ PADDLE_ENFORCE_EQ(
+ x_mat_dims[0], out_dims[0],
+ "The first dimension of Out@GRAD must equal the first dimension of "
+ "the first operand.");
+ PADDLE_ENFORCE_EQ(
+ y_mat_dims[1], out_dims[1],
+ "The second dimension of Out@GRAD must equal the second "
+ "dimension of the second operand.");
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 05a79e13b3..3c01f868bd 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -1,7 +1,7 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
+ You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
@@ -31,13 +31,25 @@ template
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto* x = context.Input("X");
- auto* y = context.Input("Y");
- auto* z = context.Output("Out");
+ const Tensor* x = context.Input("X");
+ const Tensor* y = context.Input("Y");
+ Tensor* z = context.Output("Out");
+ const Tensor x_matrix =
+ x->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *x, context.template Attr("x_num_col_dims"))
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *y, context.template Attr("y_num_col_dims"))
+ : *y;
+
z->mutable_data(context.GetPlace());
auto* device_context =
const_cast(context.device_context_);
- math::matmul(*x, false, *y, false, 1, z, 0, device_context);
+ math::matmul(x_matrix, false, y_matrix, false, 1, z, 0,
+ device_context);
}
};
@@ -45,23 +57,39 @@ template
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
- auto* x = ctx.Input("X");
- auto* y = ctx.Input("Y");
- auto* dout = ctx.Input(framework::GradVarName("Out"));
+ int x_num_col_dims = ctx.template Attr("x_num_col_dims");
+ int y_num_col_dims = ctx.template Attr("y_num_col_dims");
+ const Tensor* x = ctx.Input("X");
+ const Tensor* y = ctx.Input("Y");
+ const Tensor x_matrix =
+ x->dims().size() > 2 ? framework::ReshapeToMatrix(*x, x_num_col_dims)
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2 ? framework::ReshapeToMatrix(*y, y_num_col_dims)
+ : *y;
+ const Tensor* dout = ctx.Input(framework::GradVarName("Out"));
- auto* dx = ctx.Output(framework::GradVarName("X"));
- auto* dy = ctx.Output(framework::GradVarName("Y"));
+ Tensor* dx = ctx.Output(framework::GradVarName("X"));
+ Tensor* dy = ctx.Output(framework::GradVarName("Y"));
auto* device_context =
const_cast(ctx.device_context_);
if (dx) {
dx->mutable_data(ctx.GetPlace());
+ Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix(
+ *dx, x_num_col_dims)
+ : *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
- math::matmul(*dout, false, *y, true, 1, dx, 0, device_context);
+ math::matmul(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
+ device_context);
}
if (dy) {
dy->mutable_data(ctx.GetPlace());
+ Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix(
+ *dy, y_num_col_dims)
+ : *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
- math::matmul(*x, true, *dout, false, 1, dy, 0, device_context);
+ math::matmul(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
+ device_context);
}
}
};
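Note on the gradient formulas in `MulGradKernel`: the comments `dx = dout * y'` and `dy = x' * dout` can be checked with a naive, dependency-free sketch. It assumes row-major storage and uses a hypothetical `Mat` type and `MatMul` helper; it does not reproduce the signature of `math::matmul`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix, used only for illustration.
struct Mat {
  int rows;
  int cols;
  std::vector<float> data;
  Mat(int r, int c)
      : rows(r), cols(c), data(static_cast<std::size_t>(r) * c, 0.0f) {}
  float& at(int i, int j) {
    return data[static_cast<std::size_t>(i) * cols + j];
  }
  float at(int i, int j) const {
    return data[static_cast<std::size_t>(i) * cols + j];
  }
};

// Naive product op(a) * op(b), where op optionally transposes its argument.
Mat MatMul(const Mat& a, bool trans_a, const Mat& b, bool trans_b) {
  const int m = trans_a ? a.cols : a.rows;
  const int k = trans_a ? a.rows : a.cols;
  const int kb = trans_b ? b.cols : b.rows;
  const int n = trans_b ? b.rows : b.cols;
  assert(k == kb);  // inner dimensions must agree
  Mat out(m, n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        const float av = trans_a ? a.at(p, i) : a.at(i, p);
        const float bv = trans_b ? b.at(j, p) : b.at(p, j);
        acc += av * bv;
      }
      out.at(i, j) = acc;
    }
  }
  return out;
}

// For out = x * y with x: M x K, y: K x N, dout: M x N.
// dx = dout * y' has shape M x K.
Mat MulGradX(const Mat& dout, const Mat& y) {
  return MatMul(dout, false, y, true);
}

// dy = x' * dout has shape K x N.
Mat MulGradY(const Mat& x, const Mat& dout) {
  return MatMul(x, true, dout, false);
}
```

The shapes line up as in the kernel comments: with `x` of size 2 x 3 and `y` of size 3 x 4, `dout` is 2 x 4, `dx` comes out as 2 x 3, and `dy` as 3 x 4.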
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 30b4b40431..fa8f0ff1a8 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("b")->dims();
-
- PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
- PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
- PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
- PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
- ctx.Output("Out")->Resize(ctx.Input("X")->dims());
+ auto x_dims = ctx.Input("X")->dims();
+ auto b_dims = ctx.Input("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
+ PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
+ ctx.Output("Out")->Resize(x_dims);
}
};
@@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
- auto dims0 = ctx.Input("X")->dims();
- auto dims1 = ctx.Input("b")->dims();
- PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
+ auto x_dims = ctx.Input("X")->dims();
+ auto b_dims = ctx.Input("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
auto *dx = ctx.Output(framework::GradVarName("X"));
auto *db = ctx.Output(framework::GradVarName("b"));
- if (dx) dx->Resize(dims0);
- if (db) db->Resize(dims1);
+ if (dx) dx->Resize(x_dims);
+ if (db) db->Resize(b_dims);
}
};
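Note on the new `RowwiseAddOp` shape rule: `num_col_dims` is the rank difference between `X` and `b`, and the trailing dimensions of `X` must equal the shape of `b`. The sketch below restates that rule on plain vectors. `Dims`, `RowwiseAddNumColDims`, and `RowwiseAdd` are hypothetical helpers written for this note, and returning `-1` on a mismatch is an assumption standing in for `PADDLE_ENFORCE`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for DDim.
using Dims = std::vector<int64_t>;

// Returns rank(X) - rank(b) when b's shape matches the trailing dimensions
// of X, and -1 when the shapes are incompatible.
int RowwiseAddNumColDims(const Dims& x_dims, const Dims& b_dims) {
  if (x_dims.size() <= b_dims.size()) return -1;
  const std::size_t num_col_dims = x_dims.size() - b_dims.size();
  for (std::size_t i = 0; i < b_dims.size(); ++i) {
    if (x_dims[num_col_dims + i] != b_dims[i]) return -1;
  }
  return static_cast<int>(num_col_dims);
}

// Adds b to every row of x, where x is a flat row-major buffer viewed as
// (x.size() / b.size()) rows of b.size() elements each.
std::vector<float> RowwiseAdd(const std::vector<float>& x,
                              const std::vector<float>& b) {
  assert(!b.empty() && x.size() % b.size() == 0);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] + b[i % b.size()];
  }
  return out;
}
```

For example, `X` with shape `{2, 3, 4, 5}` and `b` with shape `{4, 5}` give `num_col_dims = 2`, so `X` is treated as a 6 x 20 matrix and `b` as a vector of 20 elements.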
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 4e926d9f29..35774b9409 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,10 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output("Out");
out->mutable_data(context.GetPlace());
-
- auto input = EigenMatrix::From(*context.Input("X"));
- auto bias = EigenVector::From(*context.Input("b"));
- auto output = EigenMatrix::From(*out);
+ int num_col_dims = context.Input("X")->dims().size() -
+ context.Input("b")->dims().size();
+ auto input =
+ EigenMatrix::Reshape(*context.Input("X"), num_col_dims);
+ auto bias = EigenVector::Flatten(*context.Input("b"));
+ auto output = EigenMatrix::Reshape(*out, num_col_dims);
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
@@ -54,12 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel {
auto* dout = context.Input(framework::GradVarName("Out"));
auto* dx = context.Output(framework::GradVarName("X"));
auto* db = context.Output(framework::GradVarName("b"));
+ int num_col_dims = context.Input("X")->dims().size() -
+ context.Input("b")->dims().size();
- auto out_grad = EigenMatrix::From(*dout);
+ auto out_grad = EigenMatrix::Reshape(*dout, num_col_dims);
auto place = context.GetEigenDevice();
+
if (dx) {
dx->mutable_data(context.GetPlace());
- EigenMatrix::From(*dx).device(place) = out_grad;
+ EigenMatrix::Reshape(*dx, num_col_dims).device(place) = out_grad;
}
if (db) {
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
new file mode 100644
index 0000000000..5805826ee8
--- /dev/null
+++ b/paddle/operators/sum_op.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sum_op.h"
+#include
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class SumOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto ins = ctx.MultiInput("X");
+ auto *out = ctx.Output("Out");
+ int N = ins.size();
+ // Validate the input count before dereferencing ins[0].
+ PADDLE_ENFORCE_GT(N, 1, "Input tensors count should be greater than 1.");
+
+ auto in_dim = ins[0]->dims();
+ for (int i = 1; i < N; i++) {
+ auto dim = ins[i]->dims();
+ PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
+ }
+ out->Resize(in_dim);
+ }
+};
+
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+ AddOutput("Out", "the output tensor of sum operator.");
+ AddComment(R"DOC(
+ Sum the input tensors.
+ )DOC");
+ }
+};
+
+class SumGradOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto outputs = ctx.MultiOutput(framework::GradVarName("X"));
+ auto dims = ctx.Input(framework::GradVarName("Out"))->dims();
+ for (auto output : outputs) {
+ output->Resize(dims);
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel);
+REGISTER_OP_CPU_KERNEL(sum_grad,
+ ops::SumGradKernel);
diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu
new file mode 100644
index 0000000000..a465cf3659
--- /dev/null
+++ b/paddle/operators/sum_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/sum_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel);
+REGISTER_OP_GPU_KERNEL(sum_grad,
+ ops::SumGradKernel);
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
new file mode 100644
index 0000000000..0b1e9ebaa3
--- /dev/null
+++ b/paddle/operators/sum_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template
+using EigenVector = framework::EigenVector;
+
+template
+class SumKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+ auto ins = context.MultiInput("X");
+ auto* out = context.Output("Out");
+ out->mutable_data(context.GetPlace());
+
+ auto place = context.GetEigenDevice();
+ auto result = EigenVector::Flatten(*out);
+
+ int N = ins.size();
+ auto in = EigenVector::Flatten(*(ins[0]));
+ result.device(place) = in;
+ for (int i = 1; i < N; i++) {
+ auto in = EigenVector::Flatten(*(ins[i]));
+ result.device(place) = result + in;
+ }
+ }
+};
+
+template
+class SumGradKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+ auto* input = context.Input(framework::GradVarName("Out"));
+ auto outs = context.MultiOutput(framework::GradVarName("X"));
+ for (auto out : outs) {
+ out->mutable_data(context.GetPlace());
+ }
+
+ auto place = context.GetEigenDevice();
+ auto in = EigenVector::Flatten(*input);
+ for (auto out : outs) {
+ auto result = EigenVector::Flatten(*out);
+ result.device(place) = in;
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
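Review note: the forward and gradient kernels in `sum_op.h` reduce to the following element-wise logic. This is a minimal CPU-only sketch on plain `std::vector` data, without Eigen or device placement; `FlatTensor`, `SumForward` and `SumBackward` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a flattened tensor (not a Paddle type).
using FlatTensor = std::vector<float>;

// Forward: Out = X[0] + X[1] + ... + X[N-1], element-wise.
FlatTensor SumForward(const std::vector<FlatTensor>& ins) {
  if (ins.size() < 2) {
    throw std::invalid_argument("need more than one input");
  }
  FlatTensor out = ins[0];
  for (std::size_t i = 1; i < ins.size(); ++i) {
    if (ins[i].size() != out.size()) {
      throw std::invalid_argument("inputs must have the same shape");
    }
    for (std::size_t j = 0; j < out.size(); ++j) {
      out[j] += ins[i][j];
    }
  }
  return out;
}

// Backward: d(Out)/d(X[i]) is the identity, so every input gradient
// is an unchanged copy of the output gradient.
std::vector<FlatTensor> SumBackward(const FlatTensor& dout, std::size_t n) {
  return std::vector<FlatTensor>(n, dout);
}
```

The gradient kernel needs no arithmetic because the sum is linear in each input with coefficient 1, which is why `SumGradKernel` only copies `Out@GRAD` into each `X@GRAD`.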
diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc
new file mode 100644
index 0000000000..38d2f0a09a
--- /dev/null
+++ b/paddle/operators/top_k_op.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/top_k_op.h"
+
+namespace paddle {
+namespace operators {
+
+class TopkOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+ "Input of TopkOp must be initialized.");
+ auto *input = ctx.Input("X");
+ const int k = static_cast(ctx.Attr("k"));
+
+ PADDLE_ENFORCE_GE(k, 1, "k must be >= 1");
+ PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape");
+ PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
+ "input must have >= k columns");
+
+ framework::DDim dims = input->dims();
+ dims[dims.size() - 1] = k;
+ ctx.Output("Out")->Resize(dims);
+ ctx.Output("Indices")->Resize(dims);
+ }
+};
+
+class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "The input of Topk op");
+ AddOutput("Out", "The output tensor of Topk op");
+ AddOutput("Indices", "The indices of Topk elements of input");
+ AddComment(
+ R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].
+
+ For matrices, computes the top k entries in each row. )DOC");
+ AddAttr("k",
+ "Number of top elements to look for along the last "
+ "dimension (along each row for matrices).")
+ .SetDefault(1);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
+REGISTER_OP_CPU_KERNEL(top_k,
+ ops::TopkKernel);
diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu
new file mode 100644
index 0000000000..afe4d149c5
--- /dev/null
+++ b/paddle/operators/top_k_op.cu
@@ -0,0 +1,318 @@
+/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/assert.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template
+struct Pair {
+ __device__ __forceinline__ Pair() {}
+ __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
+
+ __device__ __forceinline__ void set(T value, int id) {
+ v = value;
+ this->id = id;  // the parameter shadows the member; a bare `id = id` is a no-op
+ }
+
+ __device__ __forceinline__ void operator=(const Pair& in) {
+ v = in.v;
+ id = in.id;
+ }
+
+ __device__ __forceinline__ bool operator<(const T value) const {
+ return (v < value);
+ }
+
+ __device__ __forceinline__ bool operator<(const Pair& in) const {
+ return (v < in.v) || ((v == in.v) && (id > in.id));
+ }
+
+ __device__ __forceinline__ bool operator>(const Pair& in) const {
+ return (v > in.v) || ((v == in.v) && (id < in.id));
+ }
+
+ T v;
+ int id;
+};
+
+template
+__device__ __forceinline__ void AddTo(Pair topk[], const Pair& p,
+ int beam_size) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
+template
+__device__ __forceinline__ void AddTo(Pair topk[], const Pair& p) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* src, int idx,
+ int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+ Pair tmp(src[idx], idx);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* src, int idx,
+ int dim, const Pair& max,
+ int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+ Pair tmp(src[idx], idx);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col,
+ int idx, int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+ Pair tmp(val[idx], col[idx]);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col,
+ int idx, int dim, const Pair& max,
+ int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+ Pair tmp(val[idx], col[idx]);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam,
+ int beam_size, const T* src,
+ bool& firstStep, bool& is_empty,
+ Pair& max, int dim,
+ const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+ GetTopK(topk, src, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+ GetTopK(topk + MaxLength - beam, src, tid, dim, max,
+ length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
+template
+__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam,
+ int beam_size, const T* val,
+ int* col, bool& firstStep,
+ bool& is_empty, Pair& max,
+ int dim, const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+ GetTopK(topk, val, col, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+ GetTopK(topk + MaxLength - beam, val, col, tid, dim, max,
+ length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
+template
+__device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid,
+ Pair topk[], T** topVal,
+ int** topIds, int& beam, int& k,
+ const int tid, const int warp) {
+ while (true) {
+ __syncthreads();
+ if (tid < BlockSize / 2) {
+ if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
+ maxid[tid] = tid + BlockSize / 2;
+ } else {
+ maxid[tid] = tid;
+ }
+ }
+ __syncthreads();
+ for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
+ if (tid < stride) {
+ if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
+ maxid[tid] = maxid[tid + stride];
+ }
+ }
+ __syncthreads();
+ }
+ __syncthreads();
+
+ if (tid == 0) {
+ **topVal = sh_topk[maxid[0]].v;
+ **topIds = sh_topk[maxid[0]].id;
+ (*topVal)++;
+ (*topIds)++;
+ }
+ if (tid == maxid[0]) beam++;
+ if (--k == 0) break;
+ __syncthreads();
+
+ if (tid == maxid[0]) {
+ if (beam < MaxLength) {
+ sh_topk[tid] = topk[beam];
+ }
+ }
+ if (maxid[0] / 32 == warp) {
+ if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break;
+ }
+ }
+}
+
+/**
+ * Each block computes one sample.
+ * In a block:
+ * 1. every thread gets its top MaxLength values;
+ * 2. merge into sh_topk, block-reduce, and get the max value;
+ * 3. go to the second step, until one thread's topk values are exhausted;
+ * 4. go to the first step, until the top k values are obtained.
+ */
+template
+__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
+ const T* src, int lds, int dim, int k) {
+ __shared__ Pair sh_topk[BlockSize];
+ __shared__ int maxid[BlockSize / 2];
+ const int tid = threadIdx.x;
+ const int warp = threadIdx.x / 32;
+ output += blockIdx.x * output_stride;
+ indices += blockIdx.x * k;
+
+ Pair topk[MaxLength];
+ int beam = MaxLength;
+ Pair max;
+ bool is_empty = false;
+ bool firststep = true;
+
+ for (int k = 0; k < MaxLength; k++) {
+ topk[k].set(-INFINITY, -1);
+ }
+ while (k) {
+ ThreadGetTopK(topk, beam, k,
+ src + blockIdx.x * lds, firststep,
+ is_empty, max, dim, tid);
+
+ sh_topk[tid] = topk[0];
+ BlockReduce(sh_topk, maxid, topk, &output,
+ &indices, beam, k, tid, warp);
+ }
+}
+
+template
+class TopkOpCUDAKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "It must use GPUPlace.");
+ auto* input = ctx.Input("X");
+ auto* output = ctx.Output("Out");
+ auto* indices = ctx.Output("Indices");
+ size_t k = static_cast(ctx.Attr("k"));
+
+ const T* input_data = input->data();
+
+ T* output_data = output->mutable_data(ctx.GetPlace());
+ // FIXME(typhoonzero): data is always converted to type T?
+ int* indices_data = indices->mutable_data(ctx.GetPlace());
+
+ size_t input_height = input->dims()[0];
+ size_t input_width = input->dims()[1];
+ if (k > input_width) k = input_width;
+
+ // NOTE: lds and dim are both set to the input width.
+ // NOTE: the stride of the old matrix implementation differs from Eigen's.
+ // TODO(typhoonzero): launch kernel on specified stream.
+ // TODO(typhoonzero): refine this kernel.
+ dim3 threads(256, 1);
+ dim3 grid(input_height, 1);
+
+ KeMatrixTopK<<>>(
+ output_data, output->dims()[1], indices_data, input_data, input_width,
+ input_width, int(k));
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+REGISTER_OP_GPU_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel);
diff --git a/paddle/operators/top_k_op.h b/paddle/operators/top_k_op.h
new file mode 100644
index 0000000000..ef66acc1d5
--- /dev/null
+++ b/paddle/operators/top_k_op.h
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include
+#include
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template
+using EigenMatrix = framework::EigenMatrix;
+
+template
+class TopkKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ // Get the top k elements of each row of input tensor
+ // FIXME: only deal with matrix(2d tensor).
+ auto* input = ctx.Input("X");
+ auto* output = ctx.Output("Out");
+ auto* indices = ctx.Output("Indices");
+ // k is determined by Attr
+ const size_t k = static_cast(ctx.Attr("k"));
+
+ T* output_data = output->mutable_data(ctx.GetPlace());
+ T* indices_data = indices->mutable_data(ctx.GetPlace());
+
+ auto eg_input = EigenMatrix::From(*input);
+
+ // reshape input to a flattened matrix (like flat_inner_dims)
+ framework::DDim inputdims = input->dims();
+ const size_t row = framework::product(
+ framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+ const size_t col = inputdims[inputdims.size() - 1];
+ Eigen::DSizes flat2dims(row, col);
+ // NOTE: eigen shape doesn't affect paddle tensor.
+ eg_input.reshape(flat2dims);
+
+ for (size_t i = 0; i < row; i++) {
+ std::vector> vec;
+ for (size_t j = 0; j < col; j++) {
+ vec.push_back(std::pair(eg_input(i, j), j));
+ }
+
+ std::partial_sort(
+ vec.begin(), vec.begin() + k, vec.end(),
+ [](const std::pair& l, const std::pair& r) {
+ return l.first > r.first;
+ });
+ for (size_t j = 0; j < k; j++) {
+ output_data[i * k + j] = vec[j].first;
+ indices_data[i * k + j] = vec[j].second;
+ }
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
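Review note: the CPU kernel's per-row selection can be sketched in isolation as below. This is an illustrative sketch rather than the kernel itself: `TopKResult` and `TopKRow` are hypothetical names, clamping `k` to the row length is an added safeguard, and the tie-break on the lower index is an assumption chosen to mirror the `Pair` comparison in `top_k_op.cu` (the CPU lambda in this diff compares values only).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical result holder (not a Paddle type).
struct TopKResult {
  std::vector<float> values;
  std::vector<std::size_t> indices;
};

// Returns the k largest entries of one row, in descending order,
// together with their original column indices.
TopKResult TopKRow(const std::vector<float>& row, std::size_t k) {
  k = std::min(k, row.size());  // safeguard: never sort past the end
  std::vector<std::pair<float, std::size_t>> vec;
  vec.reserve(row.size());
  for (std::size_t j = 0; j < row.size(); ++j) {
    vec.push_back(std::make_pair(row[j], j));
  }
  // Only the first k positions need to be ordered.
  std::partial_sort(
      vec.begin(), vec.begin() + k, vec.end(),
      [](const std::pair<float, std::size_t>& l,
         const std::pair<float, std::size_t>& r) {
        if (l.first != r.first) return l.first > r.first;
        return l.second < r.second;  // assumption: lower index wins ties
      });
  TopKResult result;
  for (std::size_t j = 0; j < k; ++j) {
    result.values.push_back(vec[j].first);
    result.indices.push_back(vec[j].second);
  }
  return result;
}
```

`std::partial_sort` costs roughly O(n log k) per row, which is cheaper than a full sort when `k` is much smaller than the row width.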
diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h
index 81448897e9..64fcbd93b6 100644
--- a/paddle/platform/enforce.h
+++ b/paddle/platform/enforce.h
@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/string/printf.h"
#include "paddle/string/to_string.h"
-#ifdef __GNUC__
-#include // for __cxa_demangle
-#endif
-
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
@@ -46,19 +42,6 @@ limitations under the License. */
namespace paddle {
namespace platform {
-namespace {
-#ifdef __GNUC__
-inline std::string demangle(std::string name) {
- int status = -4; // some arbitrary value to eliminate the compiler warning
- std::unique_ptr res{
- abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
- return (status == 0) ? res.get() : name;
-}
-#else
-inline std::string demangle(std::string name) { return name; }
-#endif
-}
-
struct EnforceNotMet : public std::exception {
std::exception_ptr exp_;
std::string err_str_;
@@ -79,7 +62,7 @@ struct EnforceNotMet : public std::exception {
Dl_info info;
for (int i = 0; i < size; ++i) {
if (dladdr(call_stack[i], &info)) {
- auto demangled = demangle(info.dli_sname);
+ auto demangled = info.dli_sname;
auto addr_offset = static_cast(call_stack[i]) -
static_cast(info.dli_saddr);
sout << string::Sprintf("%-3d %*0p %s + %zd\n", i,
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index c708f471e3..c85fffb559 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -50,7 +50,9 @@ USE_OP(minus);
USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
+USE_OP(top_k);
USE_OP(squared_l2_distance);
+USE_OP(sum);
namespace paddle {
namespace framework {
@@ -216,7 +218,10 @@ All parameter, weight, gradient are variables in Paddle.
-> std::map> {
return op.Outputs();
})
+ .def("output_vars",
+ [](const OperatorBase &op) { return op.OutputVars(true); })
.def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
+ .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
.def("__str__", &OperatorBase::DebugString)
.def("no_intermediate_outputs",
[](const OperatorBase &op) { return op.OutputVars(false); })
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index 1798642022..2ac455d771 100644
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -30,6 +30,8 @@ Configuring cmake in /paddle/build ...
-DCMAKE_BUILD_TYPE=Release
-DWITH_DOC=OFF
-DWITH_GPU=${WITH_GPU:-OFF}
+ -DWITH_MKLDNN=${WITH_MKLDNN:-ON}
+ -DWITH_MKLML=${WITH_MKLML:-ON}
-DWITH_AVX=${WITH_AVX:-OFF}
-DWITH_GOLANG=${WITH_GOLANG:-ON}
-DWITH_SWIG_PY=ON
@@ -37,7 +39,7 @@ Configuring cmake in /paddle/build ...
-DWITH_PYTHON=${WITH_PYTHON:-ON}
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON}
-DCUDNN_ROOT=/usr/
- -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
+ -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
-DWITH_TESTING=${WITH_TESTING:-ON}
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
========================================
@@ -50,6 +52,8 @@ cmake .. \
-DCMAKE_BUILD_TYPE=Release \
-DWITH_DOC=OFF \
-DWITH_GPU=${WITH_GPU:-OFF} \
+ -DWITH_MKLDNN=${WITH_MKLDNN:-ON} \
+ -DWITH_MKLML=${WITH_MKLML:-ON} \
-DWITH_AVX=${WITH_AVX:-OFF} \
-DWITH_GOLANG=${WITH_GOLANG:-ON} \
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON} \
diff --git a/paddle/utils/Util.cpp b/paddle/utils/Util.cpp
index b18b73e06a..2755fdd9cd 100644
--- a/paddle/utils/Util.cpp
+++ b/paddle/utils/Util.cpp
@@ -320,6 +320,9 @@ void loadFileList(const std::string& fileListFileName,
}
double getMemoryUsage() {
+#if defined(__ANDROID__)
+ return 0.0;
+#else
FILE* fp = fopen("/proc/meminfo", "r");
CHECK(fp) << "failed to fopen /proc/meminfo";
size_t bufsize = 256 * sizeof(char);
@@ -357,6 +360,7 @@ double getMemoryUsage() {
delete[] buf;
double usedMem = 1.0 - 1.0 * (freeMem + bufMem + cacheMem) / totalMem;
return usedMem;
+#endif
}
SyncThreadPool* getGlobalSyncThreadPool() {
diff --git a/paddle/utils/Util.h b/paddle/utils/Util.h
index 613844669d..22ce2534d3 100644
--- a/paddle/utils/Util.h
+++ b/paddle/utils/Util.h
@@ -33,6 +33,13 @@ limitations under the License. */
#include "Flags.h"
#include "hl_gpu.h"
+#if defined(__ANDROID__) && (__ANDROID_API__ < 21)
+inline int rand_r(unsigned int* seedp) {
+ (void)seedp;
+ return rand();
+}
+#endif
+
/**
* Loop over the elements in a container
* TODO(yuyang18): It's this foreach useful? Why not use C++ 11 foreach,
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 0f44d8cb8d..ebf0911d6e 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -271,6 +271,7 @@ message ImageConfig {
// The size of input feature map.
required uint32 img_size = 8;
optional uint32 img_size_y = 9;
+ optional uint32 img_size_z = 10 [ default = 1 ];
}
message PriorBoxConfig {
@@ -288,8 +289,8 @@ message PadConfig {
}
message ReshapeConfig {
- repeated uint32 heightAxis = 1;
- repeated uint32 widthAxis = 2;
+ repeated uint32 height_axis = 1;
+ repeated uint32 width_axis = 2;
}
message MultiBoxLossConfig {
@@ -519,6 +520,7 @@ message LayerConfig {
// for HuberRegressionLoss
optional double delta = 57 [ default = 1.0 ];
+ // for 3D data
optional uint64 depth = 58 [ default = 1 ];
// for switch order layer
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 11dc84ae20..356e1d8b6f 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1332,6 +1332,12 @@ def parse_image(image, input_layer_name, image_conf):
get_img_size(input_layer_name, image_conf.channels)
+def parse_image3d(image, input_layer_name, image_conf):
+ image_conf.channels = image.channels
+ image_conf.img_size, image_conf.img_size_y, image_conf.img_size_z = \
+ get_img3d_size(input_layer_name, image_conf.channels)
+
+
def parse_norm(norm, input_layer_name, norm_conf):
norm_conf.norm_type = norm.norm_type
config_assert(
@@ -2365,9 +2371,11 @@ class BatchNormLayer(LayerBase):
name,
inputs,
bias=True,
+ img3D=False,
use_global_stats=True,
moving_average_fraction=0.9,
batch_norm_type=None,
+ mean_var_names=None,
**xargs):
if inputs is None:
inputs = []
@@ -2409,24 +2417,69 @@ class BatchNormLayer(LayerBase):
input_layer = self.get_input_layer(0)
image_conf = self.config.inputs[0].image_conf
- parse_image(self.inputs[0].image, input_layer.name, image_conf)
-
- # Only pass the width and height of input to batch_norm layer
- # when either of it is non-zero.
- if input_layer.width != 0 or input_layer.height != 0:
- self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
- image_conf.channels, False)
+ if img3D:
+ parse_image3d(self.inputs[0].image, input_layer.name, image_conf)
+ # Only pass the width and height of input to batch_norm layer
+ # when either of them is non-zero.
+ if input_layer.width != 0 or input_layer.height != 0:
+ self.set_cnn_layer(
+ input_layer_name=name,
+ depth=image_conf.img_size_z,
+ height=image_conf.img_size_y,
+ width=image_conf.img_size,
+ channels=image_conf.channels,
+ is_print=True)
+ else:
+ self.set_layer_size(input_layer.size)
else:
- self.set_layer_size(input_layer.size)
+ parse_image(self.inputs[0].image, input_layer.name, image_conf)
+ # Only pass the width and height of input to batch_norm layer
+ # when either of them is non-zero.
+ if input_layer.width != 0 or input_layer.height != 0:
+ self.set_cnn_layer(
+ input_layer_name=name,
+ height=image_conf.img_size_y,
+ width=image_conf.img_size,
+ channels=image_conf.channels,
+ is_print=True)
+ else:
+ self.set_layer_size(input_layer.size)
psize = self.calc_parameter_size(image_conf)
dims = [1, psize]
+ if mean_var_names is not None:
+ assert len(mean_var_names) == 2
+ self.inputs[1].parameter_name = mean_var_names[0]
+ self.inputs[2].parameter_name = mean_var_names[1]
+
self.create_input_parameter(0, psize)
self.create_input_parameter(1, psize, dims)
self.create_input_parameter(2, psize, dims)
self.create_bias_parameter(bias, psize)
+ def set_cnn_layer(self,
+ input_layer_name,
+ depth=None,
+ height=None,
+ width=None,
+ channels=None,
+ is_print=True):
+ depthIsNone = False
+ if depth is None:
+ depth = 1
+ depthIsNone = True
+ size = depth * height * width * channels
+ self.set_layer_size(size)
+ self.set_layer_height_width(height, width)
+ self.set_layer_depth(depth)
+ if is_print and depthIsNone:
+ print("output for %s: c = %d, h = %d, w = %d, size = %d" %
+ (input_layer_name, channels, height, width, size))
+ elif is_print:
+ print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
+ (input_layer_name, channels, depth, height, width, size))
+
def calc_parameter_size(self, image_conf):
return image_conf.channels
@@ -2688,9 +2741,20 @@ class AddToLayer(LayerBase):
super(AddToLayer, self).__init__(
name, 'addto', 0, inputs=inputs, **xargs)
config_assert(len(inputs) > 0, 'inputs cannot be empty for AddToLayer')
- for input_index in xrange(len(self.inputs)):
- input_layer = self.get_input_layer(input_index)
- self.set_layer_size(input_layer.size)
+
+ if len(self.inputs) > 1:
+ for input_index in xrange(len(self.inputs)):
+ assert self.get_input_layer(0).height == self.get_input_layer(
+ input_index).height
+ assert self.get_input_layer(0).width == self.get_input_layer(
+ input_index).width
+ assert self.get_input_layer(0).depth == self.get_input_layer(
+ input_index).depth
+
+ self.set_layer_size(self.get_input_layer(0).size)
+ self.set_layer_height_width(self.get_input_layer(0).height,
+ self.get_input_layer(0).width)
+ self.set_layer_depth(self.get_input_layer(0).depth)
self.create_bias_parameter(bias, self.config.size)
@@ -3370,11 +3434,20 @@ class ConcatenateLayer(LayerBase):
name, 'concat', 0, inputs=inputs, **xargs)
size = 0
for input_index in xrange(len(self.inputs)):
+ assert self.get_input_layer(0).height == self.get_input_layer(
+ input_index).height
+ assert self.get_input_layer(0).width == self.get_input_layer(
+ input_index).width
+ assert self.get_input_layer(0).depth == self.get_input_layer(
+ input_index).depth
input_layer = self.get_input_layer(input_index)
input = self.inputs[input_index]
if self.config.size == 0:
size += input_layer.size
+ self.set_layer_height_width(self.get_input_layer(0).height,
+ self.get_input_layer(0).width)
+ self.set_layer_depth(self.get_input_layer(0).depth)
self.set_layer_size(size)
@@ -3675,8 +3748,8 @@ class SwitchOrderLayer(LayerBase):
def __init__(self, name, inputs, reshape, **xargs):
super(SwitchOrderLayer, self).__init__(
name, 'switch_order', 0, inputs=inputs, **xargs)
- self.config.reshape_conf.heightAxis.extend(reshape['height'])
- self.config.reshape_conf.widthAxis.extend(reshape['width'])
+ self.config.reshape_conf.height_axis.extend(reshape['height'])
+ self.config.reshape_conf.width_axis.extend(reshape['width'])
# Deprecated, use a new layer specific class instead
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index cba45bd3af..4b1d80d3db 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -354,6 +354,10 @@ class LayerOutput(object):
def height(self):
return cp.g_layer_map[self.full_name].height
+ @property
+ def depth(self):
+ return cp.g_layer_map[self.full_name].depth
+
def set_input(self, input):
"""
Set the input for a memory layer. Can only be used for memory layer
@@ -943,7 +947,7 @@ def data_layer(name, size, depth=None, height=None, width=None,
if height is not None and width is not None:
num_filters = size / (width * height * depth)
assert num_filters * width * height * depth == size, \
- "size=%s width=%s height=%s depth=%s" % (size, width, height, depth)
+ "size=%s width=%s height=%s depth=%s" % (size, width, height, depth)
return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters)
@@ -1219,7 +1223,8 @@ def detection_output_layer(input_loc,
name=None):
"""
Apply the NMS to the output of network and compute the predict bounding
- box location.
+ box location. The output of this layer could be None if there is no valid
+ bounding box.
:param name: The Layer Name.
:type name: basestring
@@ -2953,13 +2958,15 @@ def img_cmrnorm_layer(input,
def batch_norm_layer(input,
act=None,
name=None,
+ img3D=False,
num_channels=None,
bias_attr=None,
param_attr=None,
layer_attr=None,
batch_norm_type=None,
moving_average_fraction=0.9,
- use_global_stats=None):
+ use_global_stats=None,
+ mean_var_names=None):
"""
Batch Normalization Layer. The notation of this layer as follow.
@@ -3026,6 +3033,8 @@ def batch_norm_layer(input,
:math:`runningMean = newMean*(1-factor)
+ runningMean*factor`
:type moving_average_fraction: float.
+ :param mean_var_names: Parameter names to use for the mean and variance inputs, given as [mean name, variance name].
+ :type mean_var_names: list of str
:return: LayerOutput object.
:rtype: LayerOutput
"""
@@ -3039,6 +3048,7 @@ def batch_norm_layer(input,
(batch_norm_type == "cudnn_batch_norm")
l = Layer(
name=name,
+ img3D=img3D,
inputs=Input(
input.name, image=Image(channels=num_channels), **param_attr.attr),
active_type=act.name,
@@ -3047,6 +3057,7 @@ def batch_norm_layer(input,
bias=ParamAttr.to_bias(bias_attr),
moving_average_fraction=moving_average_fraction,
use_global_stats=use_global_stats,
+ mean_var_names=mean_var_names,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
@@ -6410,7 +6421,7 @@ def gated_unit_layer(input,
@wrap_name_default('switch_order')
def switch_order_layer(input,
name=None,
- reshape=None,
+ reshape_axis=None,
act=None,
layer_attr=None):
"""
@@ -6421,8 +6432,9 @@ def switch_order_layer(input,
The example usage is:
.. code-block:: python
+ reshape_axis = 3
+ switch = switch_order(input=layer, name='switch', reshape_axis=reshape_axis)
reshape = {'height':[ 0, 1, 2], 'width':[3]}
- switch = switch_order(input=layer, name='switch', reshape=reshape)
:param input: The input layer.
:type input: LayerOutput
@@ -6434,6 +6446,11 @@ def switch_order_layer(input,
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
+ assert reshape_axis is not None and 0 < reshape_axis < 4
+ height = list(range(reshape_axis))
+ width = list(range(reshape_axis, 4))
+ reshape = {'height': height, 'width': width}
+
l = Layer(
name=name,
inputs=input.name,
@@ -6444,6 +6461,7 @@ def switch_order_layer(input,
return LayerOutput(
name=name,
layer_type=LayerType.SWITCH_ORDER_LAYER,
+ activation=act,
parents=input,
size=l.config.size)
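Note on the switch_order_layer hunks above: the old `reshape` dict argument is replaced by a single split point, `reshape_axis`, and the dict is now derived inside the layer. A minimal sketch of that derivation in plain Python with no Paddle dependency; the helper name `build_reshape` is hypothetical and exists only for illustration.

```python
def build_reshape(reshape_axis):
    """Split the four axes 0..3 at reshape_axis (hypothetical helper)."""
    # Same bounds as the assertion in the patch: at least one axis must
    # end up on each side of the split.
    assert reshape_axis is not None and 0 < reshape_axis < 4
    height = list(range(reshape_axis))
    width = list(range(reshape_axis, 4))
    return {'height': height, 'width': width}


print(build_reshape(3))  # {'height': [0, 1, 2], 'width': [3]}
```

With `reshape_axis=3` this reproduces the dict that the previous docstring example passed by hand.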
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index df872a90ff..8a204a96f3 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
-test_conv3d_layer test_deconv3d_layer)
+test_conv3d_layer test_deconv3d_layer test_BatchNorm3D)
export whole_configs=(test_split_datasource)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
index 1a577b8d9b..5ddf6052df 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
@@ -62,6 +62,7 @@ layers {
moving_average_fraction: 0.9
height: 227
width: 227
+ depth: 1
}
layers {
name: "__crmnorm_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
index 2818389b16..c0252b945b 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
@@ -62,6 +62,7 @@ layers {
moving_average_fraction: 0.9
height: 256
width: 256
+ depth: 1
}
layers {
name: "__crmnorm_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
new file mode 100644
index 0000000000..832ed24a31
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
@@ -0,0 +1,92 @@
+type: "nn"
+layers {
+ name: "data3D"
+ type: "data"
+ size: 360
+ active_type: ""
+ height: 6
+ width: 20
+ depth: 3
+}
+layers {
+ name: "__batch_norm_0__"
+ type: "batch_norm"
+ size: 360
+ active_type: "relu"
+ inputs {
+ input_layer_name: "data3D"
+ input_parameter_name: "___batch_norm_0__.w0"
+ image_conf {
+ channels: 1
+ img_size: 20
+ img_size_y: 6
+ img_size_z: 3
+ }
+ }
+ inputs {
+ input_layer_name: "data3D"
+ input_parameter_name: "___batch_norm_0__.w1"
+ }
+ inputs {
+ input_layer_name: "data3D"
+ input_parameter_name: "___batch_norm_0__.w2"
+ }
+ bias_parameter_name: "___batch_norm_0__.wbias"
+ moving_average_fraction: 0.9
+ height: 6
+ width: 20
+ depth: 3
+}
+parameters {
+ name: "___batch_norm_0__.w0"
+ size: 1
+ initial_mean: 1.0
+ initial_std: 0.0
+ initial_strategy: 0
+ initial_smart: false
+}
+parameters {
+ name: "___batch_norm_0__.w1"
+ size: 1
+ initial_mean: 0.0
+ initial_std: 0.0
+ dims: 1
+ dims: 1
+ initial_strategy: 0
+ initial_smart: false
+ is_static: true
+ is_shared: true
+}
+parameters {
+ name: "___batch_norm_0__.w2"
+ size: 1
+ initial_mean: 0.0
+ initial_std: 0.0
+ dims: 1
+ dims: 1
+ initial_strategy: 0
+ initial_smart: false
+ is_static: true
+ is_shared: true
+}
+parameters {
+ name: "___batch_norm_0__.wbias"
+ size: 1
+ initial_mean: 0.0
+ initial_std: 0.0
+ dims: 1
+ dims: 1
+ initial_strategy: 0
+ initial_smart: false
+}
+input_layer_names: "data3D"
+output_layer_names: "__batch_norm_0__"
+sub_models {
+ name: "root"
+ layer_names: "data3D"
+ layer_names: "__batch_norm_0__"
+ input_layer_names: "data3D"
+ output_layer_names: "__batch_norm_0__"
+ is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
index b110e91498..8a1399efad 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
@@ -74,6 +74,9 @@ layers {
inputs {
input_layer_name: "__bidirectional_gru_0___bw"
}
+ height: 0
+ width: 0
+ depth: 1
}
parameters {
name: "___bidirectional_gru_0___fw_transform.w0"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
index 8133aa9c8d..046037936a 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
@@ -16,6 +16,9 @@ layers {
inputs {
input_layer_name: "data"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_1__"
@@ -28,6 +31,9 @@ layers {
inputs {
input_layer_name: "__addto_0__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_2__"
@@ -40,6 +46,9 @@ layers {
inputs {
input_layer_name: "__addto_1__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_3__"
@@ -52,6 +61,9 @@ layers {
inputs {
input_layer_name: "__addto_2__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_4__"
@@ -64,6 +76,9 @@ layers {
inputs {
input_layer_name: "__addto_3__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_5__"
@@ -76,6 +91,9 @@ layers {
inputs {
input_layer_name: "__addto_4__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_6__"
@@ -88,6 +106,9 @@ layers {
inputs {
input_layer_name: "__addto_5__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_7__"
@@ -100,6 +121,9 @@ layers {
inputs {
input_layer_name: "__addto_6__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_8__"
@@ -112,6 +136,9 @@ layers {
inputs {
input_layer_name: "__addto_7__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_9__"
@@ -124,6 +151,9 @@ layers {
inputs {
input_layer_name: "__addto_8__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_10__"
@@ -136,6 +166,9 @@ layers {
inputs {
input_layer_name: "__addto_9__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_11__"
@@ -148,6 +181,9 @@ layers {
inputs {
input_layer_name: "__addto_10__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_12__"
@@ -160,6 +196,9 @@ layers {
inputs {
input_layer_name: "__addto_11__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_13__"
@@ -172,6 +211,9 @@ layers {
inputs {
input_layer_name: "__addto_12__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_14__"
@@ -184,6 +226,9 @@ layers {
inputs {
input_layer_name: "__addto_13__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_15__"
@@ -196,6 +241,9 @@ layers {
inputs {
input_layer_name: "__addto_14__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_16__"
@@ -208,6 +256,9 @@ layers {
inputs {
input_layer_name: "__addto_15__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_17__"
@@ -220,6 +271,9 @@ layers {
inputs {
input_layer_name: "__addto_16__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_18__"
@@ -232,6 +286,9 @@ layers {
inputs {
input_layer_name: "__addto_17__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_19__"
@@ -244,6 +301,9 @@ layers {
inputs {
input_layer_name: "__addto_18__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_20__"
@@ -256,6 +316,9 @@ layers {
inputs {
input_layer_name: "__addto_19__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_21__"
@@ -268,6 +331,9 @@ layers {
inputs {
input_layer_name: "__addto_20__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_22__"
@@ -280,6 +346,9 @@ layers {
inputs {
input_layer_name: "__addto_21__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_23__"
@@ -292,6 +361,9 @@ layers {
inputs {
input_layer_name: "__addto_22__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_24__"
@@ -304,6 +376,9 @@ layers {
inputs {
input_layer_name: "__addto_23__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_25__"
@@ -316,6 +391,9 @@ layers {
inputs {
input_layer_name: "__addto_24__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_26__"
@@ -328,6 +406,9 @@ layers {
inputs {
input_layer_name: "__addto_25__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_27__"
@@ -340,6 +421,9 @@ layers {
inputs {
input_layer_name: "__addto_26__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_28__"
@@ -352,6 +436,9 @@ layers {
inputs {
input_layer_name: "__addto_27__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_29__"
@@ -364,6 +451,9 @@ layers {
inputs {
input_layer_name: "__addto_28__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_30__"
@@ -376,6 +466,9 @@ layers {
inputs {
input_layer_name: "__addto_29__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_31__"
@@ -388,6 +481,9 @@ layers {
inputs {
input_layer_name: "__addto_30__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__fc_layer_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
index d0ad388165..7a2f3eab38 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
@@ -22,6 +22,9 @@ layers {
inputs {
input_layer_name: "b"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__concat_0__"
@@ -34,6 +37,9 @@ layers {
inputs {
input_layer_name: "b"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__concat_1__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py b/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py
new file mode 100644
index 0000000000..a991b22252
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py
@@ -0,0 +1,11 @@
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=1000, learning_rate=1e-4)
+
+#data = data_layer(name='data', size=180, width=30, height=6)
+#batchNorm = batch_norm_layer(data, num_channels=1)
+#outputs(batchNorm)
+
+data3D = data_layer(name='data3D', size=120 * 3, width=20, height=6, depth=3)
+batchNorm3D = batch_norm_layer(data3D, num_channels=1, img3D=True)
+outputs(batchNorm3D)
diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py
index 7589cc9917..e66bf67d79 100644
--- a/python/paddle/v2/event.py
+++ b/python/paddle/v2/event.py
@@ -53,10 +53,13 @@ class BeginPass(object):
class EndPass(WithMetric):
"""
Event On One Pass Training Complete.
+ To get the output of a specific layer, call "event.gm.getLayerOutputs('predict_layer')"
+ in your event_handler callback.
"""
- def __init__(self, pass_id, evaluator):
+ def __init__(self, pass_id, evaluator, gm):
self.pass_id = pass_id
+ self.gm = gm
WithMetric.__init__(self, evaluator)
@@ -73,10 +76,13 @@ class BeginIteration(object):
class EndIteration(WithMetric):
"""
Event On One Batch Training Complete.
+ To get the output of a specific layer, call "event.gm.getLayerOutputs('predict_layer')"
+ in your event_handler callback.
"""
- def __init__(self, pass_id, batch_id, cost, evaluator):
+ def __init__(self, pass_id, batch_id, cost, evaluator, gm):
self.pass_id = pass_id
self.batch_id = batch_id
self.cost = cost
+ self.gm = gm
WithMetric.__init__(self, evaluator)
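Note on the event.py hunks above: both EndPass and EndIteration now carry the gradient machine as `gm`, so a handler can query layer outputs directly. A rough sketch of such a handler follows. `FakeGradientMachine` and the reduced `EndIteration` are hypothetical stand-ins so the snippet runs without Paddle; the real return value of `getLayerOutputs` is not shown in this patch and is replaced here by a placeholder.

```python
class FakeGradientMachine(object):
    """Hypothetical stand-in for the object passed as `gm`."""

    def getLayerOutputs(self, layer_name):
        # Placeholder result; the real structure is not part of this patch.
        return {"layer": layer_name}


class EndIteration(object):
    """Reduced copy of the event: only the fields the handler reads."""

    def __init__(self, pass_id, batch_id, cost, gm):
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        self.gm = gm


seen = []


def event_handler(event):
    # React only to end-of-batch events and record the layer output.
    if isinstance(event, EndIteration):
        seen.append(event.gm.getLayerOutputs('predict_layer'))


event_handler(EndIteration(0, 0, 0.5, FakeGradientMachine()))
```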
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index c1585bcffc..4e91924a50 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -142,8 +142,8 @@ def create_op_creation_method(op_proto):
return OpInfo(
method=__impl__,
name=op_proto.type,
- inputs=[var.name for var in op_proto.inputs],
- outputs=[var.name for var in op_proto.outputs],
+ inputs=[(var.name, var.duplicable) for var in op_proto.inputs],
+ outputs=[(var.name, var.duplicable) for var in op_proto.outputs],
attrs=[attr.name for attr in op_proto.attrs])
@@ -180,9 +180,15 @@ class OperatorFactory(object):
return self.op_methods.get(t)
def get_op_input_names(self, type):
+ return [name for name, _ in self.get_op_info(type).inputs]
+
+ def get_op_inputs(self, type):
return self.get_op_info(type).inputs
def get_op_output_names(self, type):
+ return [name for name, _ in self.get_op_info(type).outputs]
+
+ def get_op_outputs(self, type):
return self.get_op_info(type).outputs
def get_op_attr_names(self, type):
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index aceab794bb..20be26cbdd 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -19,6 +19,7 @@ py_test(test_scatter_op SRCS test_scatter_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(test_fc_op SRCS test_fc_op.py)
py_test(test_minus_op SRCS test_minus_op.py)
+py_test(test_top_k_op SRCS test_top_k_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
@@ -34,5 +35,6 @@ py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
py_test(test_lookup_table SRCS test_lookup_table.py)
py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
+py_test(test_sum_op SRCS test_sum_op.py)
py_test(mnist SRCS mnist.py)
py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
new file mode 100644
index 0000000000..3a6a5dca4c
--- /dev/null
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -0,0 +1,275 @@
+import unittest
+import numpy as np
+import itertools
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator
+
+
+def grad_var_name(var_name):
+ return var_name + "@GRAD"
+
+
+def create_op(scope, op_type, inputs, outputs, attrs=None):
+ kwargs = dict()
+
+ for in_name, in_dup in Operator.get_op_inputs(op_type):
+ if in_name in inputs:
+ kwargs[in_name] = []
+ if in_dup:
+ sub_in = inputs[in_name]
+ for sub_in_name in sub_in:
+ var = scope.new_var(sub_in_name)
+ kwargs[in_name].append(sub_in_name)
+ else:
+ var = scope.new_var(in_name)
+ kwargs[in_name].append(in_name)
+
+    for out_name, out_dup in Operator.get_op_outputs(op_type):
+        if out_name in outputs:
+            kwargs[out_name] = []
+            if out_dup:
+                for sub_out_name in outputs[out_name]:
+                    scope.new_var(sub_out_name)
+                    kwargs[out_name].append(sub_out_name)
+            else:
+                scope.new_var(out_name)
+                kwargs[out_name].append(out_name)
+
+    for attr_name in Operator.get_op_attr_names(op_type):
+        if attrs is not None and attr_name in attrs:
+            kwargs[attr_name] = attrs[attr_name]
+    return Operator(op_type, **kwargs)
+
+
+def set_input(scope, op, inputs, place):
+ for in_name, in_dup in Operator.get_op_inputs(op.type()):
+ if in_name in inputs:
+ if in_dup:
+ sub_in = inputs[in_name]
+ for sub_in_name in sub_in:
+ var = scope.find_var(sub_in_name)
+ tensor = var.get_tensor()
+ arr = sub_in[sub_in_name]
+ tensor.set_dims(arr.shape)
+ tensor.set(arr, place)
+ else:
+ var = scope.find_var(in_name)
+ tensor = var.get_tensor()
+ arr = inputs[in_name]
+ tensor.set_dims(arr.shape)
+ tensor.set(arr, place)
+
+
+def set_output_grad(scope, op, outputs, place):
+ for out_name, out_dup in Operator.get_op_outputs(op.type()):
+ if out_name in outputs:
+ if out_dup:
+ sub_out = outputs[out_name]
+ for sub_out_name in sub_out:
+ out_tensor = scope.find_var(sub_out_name).get_tensor()
+ grad_tensor = scope.new_var(grad_var_name(
+ sub_out_name)).get_tensor()
+ grad_tensor.set_dims(out_tensor.shape())
+ data = np.ones(out_tensor.shape(), dtype=np.float32)
+ grad_tensor.set(data, place)
+ else:
+ out_tensor = scope.find_var(out_name).get_tensor()
+ grad_tensor = scope.new_var(grad_var_name(out_name)).get_tensor(
+ )
+ grad_tensor.set_dims(out_tensor.shape())
+ data = np.ones(out_tensor.shape(), dtype=np.float32)
+ grad_tensor.set(data, place)
+
+
+def get_numeric_gradient(scope,
+ op,
+ inputs,
+ input_to_check,
+ output_name,
+ delta=0.005,
+ in_place=False):
+
+ set_input(scope, op, inputs, core.CPUPlace())
+ op.infer_shape(scope)
+
+ tensor_to_check = scope.find_var(input_to_check).get_tensor()
+
+ def product(dim):
+ return reduce(lambda a, b: a * b, dim, 1)
+
+ ctx = core.DeviceContext.create(core.CPUPlace())
+
+ def get_output():
+ op.run(scope, ctx)
+ return np.array(scope.find_var(output_name).get_tensor()).sum()
+
+ tensor_to_check = scope.find_var(input_to_check).get_tensor()
+ tensor_size = product(tensor_to_check.get_dims())
+ gradient_flat = np.zeros(shape=(tensor_size, ), dtype='float32')
+ # we only compute gradient of one element each time.
+ # we use a for loop to compute the gradient of every element.
+ for i in xrange(tensor_size):
+ if in_place:
+            set_input(scope, op, inputs, core.CPUPlace())
+
+        # get one input element through its index i.
+ origin = tensor_to_check.get_float_element(i)
+ # add delta to it, run op and then get the sum of the result tensor.
+ x_pos = origin + delta
+ tensor_to_check.set_float_element(i, x_pos)
+ y_pos = get_output()
+
+ if in_place:
+            set_input(scope, op, inputs, core.CPUPlace())
+
+ x_neg = origin - delta
+ tensor_to_check.set_float_element(i, x_neg)
+ y_neg = get_output()
+
+ tensor_to_check.set_float_element(i, origin)
+ gradient_flat[i] = (y_pos - y_neg) / delta / 2
+
+ return gradient_flat.reshape(tensor_to_check.get_dims())
+
+
+def get_backward_op(scope, op, no_grad_set):
+ backward_op = core.Operator.backward(op, no_grad_set)
+ for input in backward_op.input_vars():
+ var = scope.new_var(input)
+ var.get_tensor()
+ for output in backward_op.output_vars():
+ var = scope.new_var(output)
+ var.get_tensor()
+ return backward_op
+
+
+def get_gradient(scope, op, inputs, outputs, grad_name, place,
+ no_grad_set=None):
+ ctx = core.DeviceContext.create(place)
+
+ set_input(scope, op, inputs, place)
+
+ op.infer_shape(scope)
+ op.run(scope, ctx)
+
+ if no_grad_set is None:
+ no_grad_set = set()
+
+ backward_op = get_backward_op(scope, op, no_grad_set)
+ set_output_grad(scope, op, outputs, place)
+
+ backward_op.infer_shape(scope)
+ backward_op.run(scope, ctx)
+
+ out = np.array(scope.find_var(grad_name).get_tensor())
+ return out
+
+
+class OpTest(unittest.TestCase):
+ def check_output_with_place(self, place):
+ self.scope = core.Scope()
+ self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
+ if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
+ return
+ set_input(self.scope, self.op, self.inputs, place)
+ self.op.infer_shape(self.scope)
+ ctx = core.DeviceContext.create(place)
+ self.op.run(self.scope, ctx)
+
+ for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
+ if out_dup:
+ sub_out = self.outputs[out_name]
+ for sub_out_name in sub_out:
+ actual = np.array(
+ self.scope.find_var(sub_out_name).get_tensor())
+ expect = sub_out[sub_out_name]
+                    self.assertTrue(
+                        np.allclose(
+                            actual, expect, atol=1e-05),
+                        "output name: " + out_name + " has diff")
+            else:
+                actual = np.array(self.scope.find_var(out_name).get_tensor())
+                expect = self.outputs[out_name]
+                self.assertTrue(
+                    np.allclose(
+                        actual, expect, atol=1e-05),
+                    "output name: " + out_name + " has diff")
+
+ def check_output(self):
+ places = [core.CPUPlace()]
+ if core.is_compile_gpu():
+ places.append(core.GPUPlace(0))
+ for place in places:
+ self.check_output_with_place(place)
+
+ def __assert_is_close(self, numeric_grads, analytic_grads, names,
+ max_relative_error, msg_prefix):
+
+ for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
+ abs_a = np.abs(a)
+ abs_a[abs_a < 1e-3] = 1
+
+ diff_mat = np.abs(a - b) / abs_a
+ max_diff = np.max(diff_mat)
+
+ def err_msg():
+ offset = np.argmax(diff_mat > max_relative_error)
+ return "%s Variable %s max gradient diff %f over limit %f, the first " \
+ "error element is %d" % (
+ msg_prefix, name, max_diff, max_relative_error, offset)
+
+ self.assertLessEqual(max_diff, max_relative_error, err_msg())
+
+ def check_grad(self,
+ inputs_to_check,
+ output_name,
+ no_grad_set=None,
+ in_place=False,
+ max_relative_error=0.005):
+ self.scope = core.Scope()
+ self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
+ if no_grad_set is None:
+ no_grad_set = set()
+
+ numeric_grads = [
+ get_numeric_gradient(
+ self.scope,
+ self.op,
+ self.inputs,
+ input_to_check,
+ output_name,
+ in_place=in_place) for input_to_check in inputs_to_check
+ ]
+ grad_names = [
+ grad_var_name(input_to_check) for input_to_check in inputs_to_check
+ ]
+
+ cpu_place = core.CPUPlace()
+ cpu_analytic_grads = [
+ get_gradient(self.scope, self.op, self.inputs, self.outputs,
+ grad_name, cpu_place, no_grad_set)
+ for grad_name in grad_names
+ ]
+
+ self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
+ max_relative_error,
+ "Gradient Check On %s" % str(cpu_place))
+
+ if core.is_compile_gpu() and self.op.support_gpu():
+ gpu_place = core.GPUPlace(0)
+ gpu_analytic_grads = [
+ get_gradient(self.scope, self.op, self.inputs, self.outputs,
+ grad_name, gpu_place, no_grad_set)
+ for grad_name in grad_names
+ ]
+
+ self.__assert_is_close(numeric_grads, gpu_analytic_grads,
+ grad_names, max_relative_error,
+ "Gradient Check On %s" % str(gpu_place))
+
+ for c_grad, g_grad, name in itertools.izip(
+ cpu_analytic_grads, gpu_analytic_grads, grad_names):
+ self.assertTrue(
+ np.allclose(
+ c_grad, g_grad, atol=1e-4),
+ "output name: " + name + " has diff")
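Note on op_test.py above: `get_numeric_gradient` perturbs one input element at a time by `delta`, reruns the op, and takes a central difference of the summed output; `__assert_is_close` then compares the result with the analytic gradient using a relative error whose denominator is clamped to 1 for small values. A minimal sketch of the same idea on plain Python lists, with no Paddle or NumPy dependency. The names `numeric_gradient`, `max_relative_error` and `sum_of_squares` are hypothetical and not part of the patch.

```python
def numeric_gradient(f, x, delta=0.005):
    """Central-difference gradient of the scalar function f at the list x."""
    grad = [0.0] * len(x)
    for i in range(len(x)):
        origin = x[i]
        x[i] = origin + delta
        y_pos = f(x)
        x[i] = origin - delta
        y_neg = f(x)
        x[i] = origin  # restore the element before moving on
        grad[i] = (y_pos - y_neg) / delta / 2
    return grad


def max_relative_error(numeric, analytic):
    """Largest |a - b| / |a|, treating |a| below 1e-3 as 1, as the patch does."""
    worst = 0.0
    for a, b in zip(numeric, analytic):
        denom = abs(a) if abs(a) >= 1e-3 else 1.0
        worst = max(worst, abs(a - b) / denom)
    return worst


def sum_of_squares(v):
    return sum(e * e for e in v)


x = [1.0, -2.0, 3.0]
numeric = numeric_gradient(sum_of_squares, x)
analytic = [2 * e for e in x]  # d/dx_i of sum(x_i^2)
assert max_relative_error(numeric, analytic) <= 0.005
```

Clamping the denominator keeps near-zero gradients from inflating the relative error, which is why the patch overwrites small entries of `abs_a` with 1 before dividing.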
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index d4277f2a42..fb6a440e23 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -1,36 +1,27 @@
import unittest
import numpy
-from op_test_util import OpTestMeta
-from gradient_checker import GradientChecker, create_op
+from op_test import OpTest
-class TestCrossEntropy(unittest.TestCase):
- __metaclass__ = OpTestMeta
-
+class TestCrossEntropy(OpTest):
def setUp(self):
- self.type = "onehot_cross_entropy"
+ self.op_type = "onehot_cross_entropy"
batch_size = 30
class_num = 10
- X = numpy.random.random((batch_size, class_num)).astype("float32")
- label = 5 * numpy.ones(batch_size).astype("int32")
+ X = numpy.random.uniform(0.1, 1.0,
+ [batch_size, class_num]).astype("float32")
+ label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
self.inputs = {'X': X, 'label': label}
Y = []
for i in range(0, batch_size):
Y.append(-numpy.log(X[i][label[i]]))
self.outputs = {'Y': numpy.array(Y).astype("float32")}
+ def test_check_output(self):
+ self.check_output()
-class CrossEntropyGradOpTest(GradientChecker):
def test_check_grad(self):
- op = create_op("onehot_cross_entropy")
- batch_size = 30
- class_num = 10
- inputs = {
- "X": numpy.random.uniform(
- 0.1, 1.0, [batch_size, class_num]).astype("float32"),
- "label": (class_num / 2) * numpy.ones(batch_size).astype("int32")
- }
- self.check_grad(op, inputs, set("X"), "Y")
+ self.check_grad(["X"], "Y")
if __name__ == "__main__":
diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py
index 19eb464baa..4b7ce92c0f 100644
--- a/python/paddle/v2/framework/tests/test_lookup_table.py
+++ b/python/paddle/v2/framework/tests/test_lookup_table.py
@@ -4,7 +4,7 @@ from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
-class TestSigmoidOp(unittest.TestCase):
+class TestLookupTableOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
@@ -15,7 +15,7 @@ class TestSigmoidOp(unittest.TestCase):
self.outputs = {'Out': table[ids]}
-class TestSigmoidGradOp(GradientChecker):
+class TestLookupTableGradOp(GradientChecker):
def test_grad(self):
op = create_op('lookup_table')
table = np.random.random((17, 31)).astype('float32')
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index b58e4266d1..8c827e242e 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -2,6 +2,7 @@ import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
+from paddle.v2.framework.op import Operator
class TestMulOp(unittest.TestCase):
@@ -16,6 +17,22 @@ class TestMulOp(unittest.TestCase):
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
+class TestMulOp2(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "mul"
+ self.inputs = {
+ 'X': np.random.random((15, 4, 12, 10)).astype("float32"),
+ 'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32")
+ }
+ self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2}
+ self.outputs = {
+ 'Out': np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10),
+ self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9))
+ }
+
+
class TestMulGradOp(GradientChecker):
def setUp(self):
self.op = create_op("mul")
@@ -49,7 +66,38 @@ class TestMulGradOp(GradientChecker):
no_grad_set={"Y"})
-# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
+class TestMulGradTest2(GradientChecker):
+ def setUp(self):
+ self.op = Operator(
+ "mul", X="X", Y="Y", Out="Out", x_num_col_dims=2, y_num_col_dims=2)
+ self.inputs = {
+ "X": np.random.random((15, 4, 12, 10)).astype("float32"),
+ "Y": np.random.random((4, 30, 8, 2, 9)).astype("float32")
+ }
+
+ def test_cpu_gpu_compare(self):
+ self.compare_grad(self.op, self.inputs)
+
+ def test_normal(self):
+ self.check_grad(
+ self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+ def test_ignore_x(self):
+ self.check_grad(
+ self.op,
+ self.inputs, ["Y"],
+ "Out",
+ max_relative_error=0.5,
+ no_grad_set={"X"})
+
+ def test_ignore_y(self):
+ self.check_grad(
+ self.op,
+ self.inputs, ["X"],
+ "Out",
+ max_relative_error=0.5,
+ no_grad_set={"Y"})
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index 2ddb85e2e7..8378c1cd21 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -16,6 +16,18 @@ class TestRowwiseAddOp(unittest.TestCase):
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
+class TestRowwiseAddOp2(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "rowwise_add"
+ self.inputs = {
+ 'X': np.random.random((13, 6, 7, 8)).astype("float32"),
+ 'b': np.random.random((7, 8)).astype("float32")
+ }
+ self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
+
+
class TestRowwiseAddGradOp(GradientChecker):
def setUp(self):
self.op = create_op("rowwise_add")
@@ -34,5 +46,23 @@ class TestRowwiseAddGradOp(GradientChecker):
self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
+class TestRowwiseAddGradOp2(GradientChecker):
+ def setUp(self):
+ self.op = create_op("rowwise_add")
+ self.inputs = {
+ "X": np.random.uniform(0.1, 1, [2, 3, 2, 5]).astype("float32"),
+ "b": np.random.uniform(0.1, 1, [2, 5]).astype("float32")
+ }
+
+ def test_normal(self):
+ self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
+
+ def test_ignore_b(self):
+ self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
+
+ def test_ignore_x(self):
+ self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
index 273c2e5ab1..2316e49eff 100644
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -1,27 +1,21 @@
import unittest
import numpy as np
-from op_test_util import OpTestMeta
-from gradient_checker import GradientChecker, create_op
+from op_test import OpTest
-class TestSigmoidOp(unittest.TestCase):
- __metaclass__ = OpTestMeta
-
+class TestSigmoid(OpTest):
def setUp(self):
- self.type = "sigmoid"
- self.inputs = {'X': np.random.random((15, 31)).astype("float32")}
+ self.op_type = "sigmoid"
+ self.inputs = {
+ 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+ }
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
+ def test_check_output(self):
+ self.check_output()
-class TestSigmoidGradOp(GradientChecker):
- def test_grad(self):
- op = create_op("sigmoid")
- inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")}
- # compare gpu and cpu results for backward op.
- # this test will be skiped if only compiling CPU version.
- self.compare_grad(op, inputs)
- # check gradients
- self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007)
+ def test_check_grad(self):
+ self.check_grad(["X"], "Y", max_relative_error=0.007)
if __name__ == '__main__':
diff --git a/python/paddle/v2/framework/tests/test_sum_op.py b/python/paddle/v2/framework/tests/test_sum_op.py
new file mode 100644
index 0000000000..66417d70e8
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_sum_op.py
@@ -0,0 +1,24 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSumOp(OpTest):
+ def setUp(self):
+ self.op_type = "sum"
+ x0 = np.random.random((3, 4)).astype('float32')
+ x1 = np.random.random((3, 4)).astype('float32')
+ x2 = np.random.random((3, 4)).astype('float32')
+ self.inputs = {"X": {"x0": x0, "x1": x1, "x2": x2}}
+ y = x0 + x1 + x2
+ self.outputs = {'Out': y}
+
+ def test_check_output(self):
+ self.check_output()
+
+ def test_check_grad(self):
+ self.check_grad(["x0"], "Out")
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_top_k_op.py b/python/paddle/v2/framework/tests/test_top_k_op.py
new file mode 100644
index 0000000000..e841d96d26
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_top_k_op.py
@@ -0,0 +1,52 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta
+
+
+class TestTopkOp(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "top_k"
+ k = 1
+ input = np.random.random((32, 84)).astype("float32")
+ output = np.ndarray((32, k))
+ indices = np.ndarray((32, k))
+
+ self.inputs = {'X': input}
+ self.attrs = {'k': k}
+
+ for rowid in xrange(32):
+ row = input[rowid]
+ output[rowid] = np.sort(row)[-k:]
+ indices[rowid] = row.argsort()[-k:]
+
+ self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp3d(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "top_k"
+ k = 1
+ input = np.random.random((32, 2, 84)).astype("float32")
+ input_flat_2d = input.reshape(64, 84)
+ output = np.ndarray((64, k))
+ indices = np.ndarray((64, k)).astype("int")
+
+ # FIXME: should use 'X': input for a 3d input
+ self.inputs = {'X': input_flat_2d}
+ self.attrs = {'k': k}
+
+ for rowid in xrange(64):
+ row = input_flat_2d[rowid]
+ output[rowid] = np.sort(row)[-k:]
+ indices[rowid] = row.argsort()[-k:]
+
+ self.outputs = {'Out': output, 'Indices': indices}
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 0654a30104..ca95ef13bd 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -174,13 +174,18 @@ class SGD(object):
pass_id=pass_id,
batch_id=batch_id,
cost=cost,
- evaluator=batch_evaluator))
+ evaluator=batch_evaluator,
+ gm=self.__gradient_machine__))
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
self.__parameter_updater__.finishPass()
pass_evaluator.finish()
- event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
+ event_handler(
+ v2_event.EndPass(
+ pass_id,
+ evaluator=pass_evaluator,
+ gm=self.__gradient_machine__))
self.__gradient_machine__.finish()
def test(self, reader, feeding=None):