From e76fa85cc842a5b9f21e66f64cd34d6f7fa00719 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 4 Sep 2017 16:55:29 -0700 Subject: [PATCH 01/11] WIP --- paddle/framework/attribute.h | 17 ++++++++++ paddle/framework/ddim.cc | 30 ++++++++++------- paddle/framework/ddim.h | 2 ++ paddle/framework/eigen.h | 13 +++++++- paddle/framework/eigen_test.cc | 20 ++++++++++++ paddle/framework/tensor.h | 3 ++ paddle/framework/tensor_impl.h | 12 +++++++ paddle/framework/tensor_test.cc | 13 ++++++++ paddle/operators/mul_op.cc | 57 +++++++++++++++++++++++++-------- 9 files changed, 140 insertions(+), 27 deletions(-) diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 08b47cabd4..7da34e3f2b 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -51,6 +51,18 @@ class LargerThanChecker { T lower_bound_; }; +template +class EqualLargerThanChecker { + public: + explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} + void operator()(T& value) const { + PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail"); + } + + private: + T lower_bound_; +}; + // we can provide users more common Checker, like 'LessThanChecker', // 'BetweenChecker'... @@ -114,6 +126,11 @@ class TypedAttrChecker { return *this; } + TypedAttrChecker& EqualLargerThan(const T& lower_bound) { + value_checkers_.push_back(EqualLargerThanChecker(lower_bound)); + return *this; + } + // we can add more common limits, like LessThan(), Between()... TypedAttrChecker& SetDefault(const T& default_value) { diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index cfd3e8dfde..c32d66f41c 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -195,18 +195,6 @@ std::vector vectorize(const DDim& ddim) { return result; } -struct ProductVisitor : public boost::static_visitor { - template - ssize_t operator()(const Dim& dim) { - return product(dim); - } -}; - -ssize_t product(const DDim& ddim) { - ProductVisitor visitor; - return boost::apply_visitor(visitor, ddim); -} - struct SliceVectorizeVisitor : public boost::static_visitor<> { std::vector& vector; int begin; @@ -247,6 +235,24 @@ DDim slice_ddim(const DDim& dim, int begin, int end) { return make_ddim(vec); } +struct ProductVisitor : public boost::static_visitor { + template + ssize_t operator()(const Dim& dim) { + return product(dim); + } +}; + +ssize_t product(const DDim& ddim) { + ProductVisitor visitor; + return boost::apply_visitor(visitor, ddim); +} + +ssize_t product(const DDim& ddim, int begin, int end) { + ProductVisitor visitor; + DDim sliced_ddim = slice_ddim(ddim, begin, end); + return boost::apply_visitor(visitor, sliced_ddim); +} + /// \cond HIDDEN struct ArityVisitor : boost::static_visitor { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 95f294b627..7a02af6b8a 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -96,6 +96,8 @@ std::vector vectorize(const DDim& ddim); ssize_t product(const DDim& ddim); +ssize_t product(const DDim& ddim, int begin, int end); + /** * \brief Slice a ddim * diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index a4667cc51f..47551634a6 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -63,7 +63,18 @@ struct EigenTensor { template -struct EigenMatrix : public EigenTensor {}; +struct EigenMatrix : public EigenTensor { + static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) { + int rank = tensor.dims_.size(); + PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, + "`num_row_dims` must be between (0, rank_of_tensor)."); + return EigenMatrix::From( + tensor, make_ddim({static_cast( + product(tensor.dims_, 0, rank - num_row_dims)), + static_cast(product( + tensor.dims_, rank - num_row_dims, rank))})); + } +}; template diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index dc1957691b..bae82fdb7d 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -108,5 +108,25 @@ TEST(Eigen, Matrix) { } } +TEST(Eigen, MatrixReshape) { + Tensor t; + float* p = + t.mutable_data(make_ddim({2, 3, 6, 4}), platform::CPUPlace()); + for (int i = 0; i < 2 * 3 * 6 * 4; ++i) { + p[i] = static_cast(i); + } + + EigenMatrix::Type em = EigenMatrix::Reshape(t, 2); + + ASSERT_EQ(2 * 3, em.dimension(0)); + ASSERT_EQ(6 * 4, em.dimension(1)); + + for (int i = 0; i < 2 * 3; i++) { + for (int j = 0; j < 6 * 4; j++) { + ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f); + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 643f875491..ce938b2143 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -43,6 +43,9 @@ class Tensor { template friend struct EigenTensor; + template + friend struct EigenMatrix; + template friend struct EigenVector; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 7893e233b7..7c47c389a1 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -148,5 +148,17 @@ inline Tensor& Tensor::Resize(const DDim& dims) { inline const DDim& Tensor::dims() const { return dims_; } +template +inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { + Tensor res; + res.ShareDataWith(src); + DDim src_dim = src.dims(); + int rank = src_dim.size(); + res.Resize(make_ddim( + {static_cast(product(src_dim, 0, rank - num_row_dims)), + static_cast(product(src_dim, rank - num_row_dims, rank))})); + return res; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 7db38d5cae..cdd68b303c 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) { } #endif } + +TEST(Tensor, FlattenToMatrix) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + int* src_ptr = src.mutable_data(make_ddim({2, 3, 4, 9}), CPUPlace()); + for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { + src_ptr[i] = i; + } + Tensor res = FlattenToMatrix(src, 2); + ASSERT_EQ(res.dims()[0], 2 * 3); + ASSERT_EQ(res.dims()[1], 4 * 9); +} \ No newline at end of file diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 559d19e6bd..f668008a10 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -25,18 +25,26 @@ class MulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto dim0 = ctx.Input("X")->dims(); - auto dim1 = ctx.Input("Y")->dims(); - PADDLE_ENFORCE_EQ(dim0.size(), 2, - "input X(%s) should be a tensor with 2 dims, a matrix", - ctx.op_.Input("X")); - PADDLE_ENFORCE_EQ(dim1.size(), 2, - "input Y(%s) should be a tensor with 2 dims, a matrix", - ctx.op_.Input("Y")); + auto x_dim = ctx.Input("X")->dims(); + auto y_dim = ctx.Input("Y")->dims(); + int x_num_row_dims = GetAttr("X_num_raw_dims"); + int y_num_row_dims = GetAttr("Y_num_raw_dims"); + + PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, + "The rank of input tensor X(%s) should be larger than " + "`mul_op`'s `X_num_raw_dims`.", + ctx.op_.Input("X")); + PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, + "The rank of input tensor Y(%s) should be larger than " + "`mul_op`'s `Y_num_raw_dims`.", + ctx.op_.Input("Y")); PADDLE_ENFORCE_EQ( - dim0[1], dim1[0], + product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()), + product(y_dim, 0, y_dim.size() - y_num_row_dims), "First matrix's width must be equal with second matrix's height."); - ctx.Output("Out")->Resize({dim0[0], dim1[1]}); + ctx.Output("Out")->Resize( + {product(x_dim, 0, x_dim.size() - x_num_row_dims), + product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())}); } }; @@ -47,6 +55,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The first input of mul op"); AddInput("Y", "The second input of mul op"); AddOutput("Out", "The output of mul op"); + AddAttr( + "x_num_row_dims", + "mul_op can take tensors with more than two dimensions as input `X`, " + "in that case, tensors will be flattened to a matrix. The matrix's " + "second dimension(row length) will be the product of tensor's last " + "`num_row_dims` dimensions, and the matrix's first dimension(column " + "length) will be the product of tensor's first `rank - num_row_dims` " + "dimensions.") + .SetDefault(1) + .EqualLargerThan(1); + AddAttr( + "y_num_row_dims", + "mul_op can take tensors with more than two dimensions as input `Y`, " + "in that case, tensors will be flattened to a matrix. Just like input " + "`X`.") + .SetDefault(1) + .EqualLargerThan(1); AddComment(R"DOC( Two Element Mul Operator. @@ -70,10 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel { auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); - PADDLE_ENFORCE(x_dims[0] == out_dims[0], - "Out@GRAD M X N must equal to X dims 0, M "); - PADDLE_ENFORCE(y_dims[1] == out_dims[1], - "Out@GRAD M X N must equal to Y dims 1, N "); + PADDLE_ENFORCE( + product(x_dim, 0, x_dims.size() - x_num_row_dims) == out_dims[0], + "The first dimension of Out@GRAD must equal to the first dimension of " + "the first operand."); + PADDLE_ENFORCE(product(y_dim, y_dims.size() - y_num_row_dims, + y_dims.size()) == out_dims[1], + "The second dimension of Out@GRAD must equal to the second " + "dimension of the second operand."); x_grad->Resize(x_dims); y_grad->Resize(y_dims); From af0264aa6b420f1401823792854c3a5c1e889cd2 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 5 Sep 2017 10:50:58 -0700 Subject: [PATCH 02/11] Add global function `FalttenToMatrix` and add `axis` for MulOp --- paddle/operators/mul_op.cc | 25 ++++++++++--------- paddle/operators/mul_op.h | 50 +++++++++++++++++++++++++++++--------- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index d301a8619f..be1782bb6b 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -27,24 +27,25 @@ class MulOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { auto x_dim = ctx.Input("X")->dims(); auto y_dim = ctx.Input("Y")->dims(); - int x_num_row_dims = GetAttr("X_num_raw_dims"); - int y_num_row_dims = GetAttr("Y_num_raw_dims"); + int x_num_row_dims = GetAttr("x_num_row_dims"); + int y_num_row_dims = GetAttr("y_num_row_dims"); PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, "The rank of input tensor X(%s) should be larger than " - "`mul_op`'s `X_num_raw_dims`.", + "`mul_op`'s `x_num_row_dims`.", ctx.op().Input("X")); PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, "The rank of input tensor Y(%s) should be larger than " - "`mul_op`'s `Y_num_raw_dims`.", + "`mul_op`'s `y_num_row_dims`.", ctx.op().Input("Y")); PADDLE_ENFORCE_EQ( product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()), product(y_dim, 0, y_dim.size() - y_num_row_dims), "First matrix's width must be equal with second matrix's height."); ctx.Output("Out")->Resize( - {product(x_dim, 0, x_dim.size() - x_num_row_dims), - product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())}); + {static_cast(product(x_dim, 0, x_dim.size() - x_num_row_dims)), + static_cast( + product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))}); } }; @@ -96,13 +97,15 @@ class MulOpGrad : public framework::OperatorWithKernel { auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); PADDLE_ENFORCE( - product(x_dim, 0, x_dims.size() - x_num_row_dims) == out_dims[0], + product(x_dims, 0, x_dims.size() - GetAttr("x_num_row_dims")) == + out_dims[0], "The first dimension of Out@GRAD must equal to the first dimension of " "the first operand."); - PADDLE_ENFORCE(product(y_dim, y_dims.size() - y_num_row_dims, - y_dims.size()) == out_dims[1], - "The second dimension of Out@GRAD must equal to the second " - "dimension of the second operand."); + PADDLE_ENFORCE( + product(y_dims, y_dims.size() - GetAttr("y_num_row_dims"), + y_dims.size()) == out_dims[1], + "The second dimension of Out@GRAD must equal to the second " + "dimension of the second operand."); x_grad->Resize(x_dims); y_grad->Resize(y_dims); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 8facc02814..73a53798e0 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -31,13 +31,25 @@ template class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Y = context.Input("Y"); - auto* Z = context.Output("Out"); + const Tensor* X = context.Input("X"); + const Tensor* Y = context.Input("Y"); + Tensor* Z = context.Output("Out"); + const Tensor X_matrix = + X->dims().size() > 2 + ? framework::FlattenToMatrix( + *X, context.template GetAttr("x_num_row_dims")) + : *X; + const Tensor Y_matrix = + Y->dims().size() > 2 + ? framework::FlattenToMatrix( + *Y, context.template GetAttr("y_num_row_dims")) + : *Y; + Z->mutable_data(context.GetPlace()); auto* device_context = const_cast(context.device_context_); - math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + math::matmul(X_matrix, false, Y_matrix, false, 1, Z, 0, + device_context); } }; @@ -45,20 +57,36 @@ template class MulGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* X = ctx.Input("X"); - auto* Y = ctx.Input("Y"); - auto* dOut = ctx.Input(framework::GradVarName("Out")); + int x_num_row_dims = ctx.template GetAttr("x_num_row_dims"); + int y_num_row_dims = ctx.template GetAttr("y_num_row_dims"); + const Tensor* X = ctx.Input("X"); + const Tensor* Y = ctx.Input("Y"); + const Tensor X_matrix = + X->dims().size() > 2 ? framework::FlattenToMatrix(*X, x_num_row_dims) + : *X; + const Tensor Y_matrix = + Y->dims().size() > 2 ? framework::FlattenToMatrix(*Y, y_num_row_dims) + : *Y; + const Tensor* dOut = ctx.Input(framework::GradVarName("Out")); - auto* dX = ctx.Output(framework::GradVarName("X")); - auto* dY = ctx.Output(framework::GradVarName("Y")); + Tensor* dX = ctx.Output(framework::GradVarName("X")); + Tensor* dY = ctx.Output(framework::GradVarName("Y")); dX->mutable_data(ctx.GetPlace()); dY->mutable_data(ctx.GetPlace()); + Tensor dX_matrix = dX->dims().size() > 2 + ? framework::FlattenToMatrix(*dX, x_num_row_dims) + : *dX; + Tensor dY_matrix = dY->dims().size() > 2 + ? framework::FlattenToMatrix(*dY, y_num_row_dims) + : *dY; auto* device_context = const_cast(ctx.device_context_); // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N - math::matmul(*dOut, false, *Y, true, 1, dX, 0, device_context); + math::matmul(*dOut, false, Y_matrix, true, 1, &dX_matrix, 0, + device_context); // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K - math::matmul(*X, true, *dOut, false, 1, dY, 0, device_context); + math::matmul(X_matrix, true, *dOut, false, 1, &dY_matrix, 0, + device_context); } }; From d71396bf870966a14638d5ea108b2b2f8babbe2f Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 5 Sep 2017 15:20:00 -0700 Subject: [PATCH 03/11] Add global function `flatten_to_2d()` --- paddle/framework/ddim.cc | 14 +++++++------ paddle/framework/ddim.h | 4 ++-- paddle/framework/eigen.h | 7 ++----- paddle/framework/tensor_impl.h | 6 +----- paddle/operators/mul_op.cc | 36 +++++++++++++++++++--------------- 5 files changed, 33 insertions(+), 34 deletions(-) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index c32d66f41c..47d1f30116 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -247,12 +247,6 @@ ssize_t product(const DDim& ddim) { return boost::apply_visitor(visitor, ddim); } -ssize_t product(const DDim& ddim, int begin, int end) { - ProductVisitor visitor; - DDim sliced_ddim = slice_ddim(ddim, begin, end); - return boost::apply_visitor(visitor, sliced_ddim); -} - /// \cond HIDDEN struct ArityVisitor : boost::static_visitor { @@ -289,5 +283,13 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { DDim::DDim(std::initializer_list init_list) { *this = make_ddim(init_list); } + +DDim flatten_to_2d(const DDim& src, int num_row_dims) { + int rank = src.size(); + return make_ddim( + {static_cast(product(slice_ddim(src, 0, rank - num_row_dims))), + static_cast(product(slice_ddim(src, rank - num_row_dims, rank)))}); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 7a02af6b8a..cf786d140e 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -96,8 +96,6 @@ std::vector vectorize(const DDim& ddim); ssize_t product(const DDim& ddim); -ssize_t product(const DDim& ddim, int begin, int end); - /** * \brief Slice a ddim * @@ -117,6 +115,8 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); +DDim flatten_to_2d(const DDim& src, int num_row_dims); + } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 47551634a6..656aef4212 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -68,11 +68,8 @@ struct EigenMatrix : public EigenTensor { int rank = tensor.dims_.size(); PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, "`num_row_dims` must be between (0, rank_of_tensor)."); - return EigenMatrix::From( - tensor, make_ddim({static_cast( - product(tensor.dims_, 0, rank - num_row_dims)), - static_cast(product( - tensor.dims_, rank - num_row_dims, rank))})); + return EigenMatrix::From(tensor, + flatten_to_2d(tensor.dims(), num_row_dims)); } }; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 7c47c389a1..d32fe78f42 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -152,11 +152,7 @@ template inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { Tensor res; res.ShareDataWith(src); - DDim src_dim = src.dims(); - int rank = src_dim.size(); - res.Resize(make_ddim( - {static_cast(product(src_dim, 0, rank - num_row_dims)), - static_cast(product(src_dim, rank - num_row_dims, rank))})); + res.Resize(flatten_to_2d(src.dims(), num_row_dims)); return res; } diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 935fe889e5..dfc22decdc 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -25,27 +25,27 @@ class MulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto x_dim = ctx.Input("X")->dims(); - auto y_dim = ctx.Input("Y")->dims(); + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); int x_num_row_dims = GetAttr("x_num_row_dims"); int y_num_row_dims = GetAttr("y_num_row_dims"); - PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, + PADDLE_ENFORCE(x_dims.size() > x_num_row_dims, "The rank of input tensor X(%s) should be larger than " "`mul_op`'s `x_num_row_dims`.", ctx.op().Input("X")); - PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, + PADDLE_ENFORCE(y_dims.size() > y_num_row_dims, "The rank of input tensor Y(%s) should be larger than " "`mul_op`'s `y_num_row_dims`.", ctx.op().Input("Y")); + + auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims); + auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims); + PADDLE_ENFORCE_EQ( - product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()), - product(y_dim, 0, y_dim.size() - y_num_row_dims), + x_mat_dims[1], y_mat_dims[0], "First matrix's width must be equal with second matrix's height."); - ctx.Output("Out")->Resize( - {static_cast(product(x_dim, 0, x_dim.size() - x_num_row_dims)), - static_cast( - product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))}); + ctx.Output("Out")->Resize({x_mat_dims[0], y_mat_dims[1]}); } }; @@ -96,14 +96,18 @@ class MulOpGrad : public framework::OperatorWithKernel { auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); - PADDLE_ENFORCE( - product(x_dims, 0, x_dims.size() - GetAttr("x_num_row_dims")) == - out_dims[0], + + auto x_mat_dims = + framework::flatten_to_2d(x_dims, GetAttr("x_num_row_dims")); + auto y_mat_dims = + framework::flatten_to_2d(y_dims, GetAttr("y_num_row_dims")); + + PADDLE_ENFORCE_EQ( + x_mat_dims[0], out_dims[0], "The first dimension of Out@GRAD must equal to the first dimension of " "the first operand."); - PADDLE_ENFORCE( - product(y_dims, y_dims.size() - GetAttr("y_num_row_dims"), - y_dims.size()) == out_dims[1], + PADDLE_ENFORCE_EQ( + y_mat_dims[1], out_dims[1], "The second dimension of Out@GRAD must equal to the second " "dimension of the second operand."); From e168fc44321b01de2e8447bacdf79b41816d067e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 5 Sep 2017 16:39:09 -0700 Subject: [PATCH 04/11] Add unit tests for cases that `mul_op` takes tensors as inputs --- .../paddle/v2/framework/tests/test_mul_op.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index b58e4266d1..3ea73d94b2 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -2,6 +2,7 @@ import unittest import numpy as np from gradient_checker import GradientChecker, create_op from op_test_util import OpTestMeta +from paddle.v2.framework.op import Operator class TestMulOp(unittest.TestCase): @@ -16,6 +17,22 @@ class TestMulOp(unittest.TestCase): self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} +class TestMulOp2(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "mul" + self.inputs = { + 'X': np.random.random((15, 4, 12, 10)).astype("float32"), + 'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32") + } + self.attrs = {'x_num_row_dims': 2, 'y_num_row_dims': 3} + self.outputs = { + 'Out': np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10), + self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9)) + } + + class TestMulGradOp(GradientChecker): def setUp(self): self.op = create_op("mul") @@ -49,6 +66,39 @@ class TestMulGradOp(GradientChecker): no_grad_set={"Y"}) +class TestMulGradTest2(GradientChecker): + def setUp(self): + self.op = Operator( + "mul", X="X", Y="Y", Out="Out", x_num_row_dims=2, y_num_row_dims=3) + self.inputs = { + "X": np.random.random((15, 4, 12, 10)).astype("float32"), + "Y": np.random.random((4, 30, 8, 2, 9)).astype("float32") + } + + def test_cpu_gpu_compare(self): + self.compare_grad(self.op, self.inputs) + + def test_normal(self): + self.check_grad( + self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5) + + def test_ignore_x(self): + self.check_grad( + self.op, + self.inputs, ["Y"], + "Out", + max_relative_error=0.5, + no_grad_set={"X"}) + + def test_ignore_y(self): + self.check_grad( + self.op, + self.inputs, ["X"], + "Out", + max_relative_error=0.5, + no_grad_set={"Y"}) + + # TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library if __name__ == '__main__': From 256d6a33d53f258cb6b57cd2bbfa4e5a58df642b Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 5 Sep 2017 20:20:42 -0700 Subject: [PATCH 05/11] Add axis for rowwise_add_op --- paddle/framework/ddim.cc | 4 ++ paddle/framework/ddim.h | 2 + paddle/framework/eigen.h | 15 ++++++-- paddle/operators/rowwise_add_op.cc | 38 ++++++++++++------- paddle/operators/rowwise_add_op.h | 15 +++++--- .../v2/framework/tests/test_rowwise_add_op.py | 30 +++++++++++++++ 6 files changed, 81 insertions(+), 23 deletions(-) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 47d1f30116..972dac7073 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -291,5 +291,9 @@ DDim flatten_to_2d(const DDim& src, int num_row_dims) { static_cast(product(slice_ddim(src, rank - num_row_dims, rank)))}); } +DDim flatten_to_1d(const DDim& src) { + return make_ddim({static_cast(product(src))}); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index cf786d140e..8f1269d9a1 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -117,6 +117,8 @@ std::ostream& operator<<(std::ostream&, const DDim&); DDim flatten_to_2d(const DDim& src, int num_row_dims); +DDim flatten_to_1d(const DDim& src); + } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 656aef4212..c6f42251da 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -71,6 +71,15 @@ struct EigenMatrix : public EigenTensor { return EigenMatrix::From(tensor, flatten_to_2d(tensor.dims(), num_row_dims)); } + + static typename EigenMatrix::ConstType Reshape(const Tensor& tensor, + int num_row_dims) { + int rank = tensor.dims_.size(); + PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, + "`num_row_dims` must be between (0, rank_of_tensor)."); + return EigenMatrix::From(tensor, + flatten_to_2d(tensor.dims(), num_row_dims)); + } }; template { // Flatten reshapes a Tensor into an EigenVector. static typename EigenVector::Type Flatten(Tensor& tensor) { - return EigenVector::From( - tensor, make_ddim({static_cast(product(tensor.dims_))})); + return EigenVector::From(tensor, {static_cast(product(tensor.dims_))}); } static typename EigenVector::ConstType Flatten(const Tensor& tensor) { - return EigenVector::From( - tensor, make_ddim({static_cast(product(tensor.dims_))})); + return EigenVector::From(tensor, {static_cast(product(tensor.dims_))}); } }; diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 30b4b40431..209281a45b 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto dim0 = ctx.Input("X")->dims(); - auto dim1 = ctx.Input("b")->dims(); - - PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix"); - PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); - PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); - PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1"); - ctx.Output("Out")->Resize(ctx.Input("X")->dims()); + auto x_dims = ctx.Input("X")->dims(); + auto b_dims = ctx.Input("b")->dims(); + PADDLE_ENFORCE_GT( + x_dims.size(), b_dims.size(), + "The rank of input `X` must be larger than the one of input `b`."); + + int num_row_dims = b_dims.size(); + + PADDLE_ENFORCE_EQ(framework::slice_ddim( + x_dims, x_dims.size() - num_row_dims, x_dims.size()), + b_dims, "The width of two operands must be same"); + PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1"); + ctx.Output("Out")->Resize(x_dims); } }; @@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); - auto dims0 = ctx.Input("X")->dims(); - auto dims1 = ctx.Input("b")->dims(); - PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1") + auto x_dims = ctx.Input("X")->dims(); + auto b_dims = ctx.Input("b")->dims(); + PADDLE_ENFORCE_GT( + x_dims.size(), b_dims.size(), + "The rank of input `X` must be larger than the one of input `b`."); + + int num_row_dims = b_dims.size(); + PADDLE_ENFORCE_EQ(framework::slice_ddim( + x_dims, x_dims.size() - num_row_dims, x_dims.size()), + b_dims, "The width of two operands must be same"); auto *dx = ctx.Output(framework::GradVarName("X")); auto *db = ctx.Output(framework::GradVarName("b")); - if (dx) dx->Resize(dims0); - if (db) db->Resize(dims1); + if (dx) dx->Resize(x_dims); + if (db) db->Resize(b_dims); } }; diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 4e926d9f29..a52a53a7d2 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -33,10 +33,11 @@ class RowwiseAddKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto out = context.Output("Out"); out->mutable_data(context.GetPlace()); - - auto input = EigenMatrix::From(*context.Input("X")); - auto bias = EigenVector::From(*context.Input("b")); - auto output = EigenMatrix::From(*out); + int num_row_dims = context.Input("b")->dims().size(); + auto input = + EigenMatrix::Reshape(*context.Input("X"), num_row_dims); + auto bias = EigenVector::Flatten(*context.Input("b")); + auto output = EigenMatrix::Reshape(*out, num_row_dims); const int bias_size = bias.dimension(0); const int rest_size = input.size() / bias_size; @@ -54,12 +55,14 @@ class RowwiseAddGradKernel : public framework::OpKernel { auto* dout = context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); auto* db = context.Output(framework::GradVarName("b")); + int num_row_dims = context.Input("b")->dims().size(); - auto out_grad = EigenMatrix::From(*dout); + auto out_grad = EigenMatrix::Reshape(*dout, num_row_dims); auto place = context.GetEigenDevice(); + if (dx) { dx->mutable_data(context.GetPlace()); - EigenMatrix::From(*dx).device(place) = out_grad; + EigenMatrix::Reshape(*dx, num_row_dims).device(place) = out_grad; } if (db) { diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index 2ddb85e2e7..8378c1cd21 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -16,6 +16,18 @@ class TestRowwiseAddOp(unittest.TestCase): self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])} +class TestRowwiseAddOp2(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "rowwise_add" + self.inputs = { + 'X': np.random.random((13, 6, 7, 8)).astype("float32"), + 'b': np.random.random((7, 8)).astype("float32") + } + self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])} + + class TestRowwiseAddGradOp(GradientChecker): def setUp(self): self.op = create_op("rowwise_add") @@ -34,5 +46,23 @@ class TestRowwiseAddGradOp(GradientChecker): self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"}) +class TestRowwiseAddGradOp2(GradientChecker): + def setUp(self): + self.op = create_op("rowwise_add") + self.inputs = { + "X": np.random.uniform(0.1, 1, [2, 3, 2, 5]).astype("float32"), + "b": np.random.uniform(0.1, 1, [2, 5]).astype("float32") + } + + def test_normal(self): + self.check_grad(self.op, self.inputs, ["X", "b"], "Out") + + def test_ignore_b(self): + self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"}) + + def test_ignore_x(self): + self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"}) + + if __name__ == '__main__': unittest.main() From f2a66ffabbc704f2049addcb319620fde427e844 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 6 Sep 2017 10:53:27 -0700 Subject: [PATCH 06/11] Follow comments --- paddle/framework/attribute.h | 4 +-- paddle/framework/ddim.cc | 8 +++-- paddle/framework/ddim.h | 2 +- paddle/framework/eigen.h | 16 ++++----- paddle/framework/tensor_impl.h | 4 +-- paddle/framework/tensor_test.cc | 4 +-- paddle/operators/mul_op.cc | 34 +++++++++--------- paddle/operators/mul_op.h | 36 +++++++++---------- paddle/operators/rowwise_add_op.cc | 16 ++++----- paddle/operators/rowwise_add_op.h | 14 ++++---- .../paddle/v2/framework/tests/test_mul_op.py | 4 +-- 11 files changed, 73 insertions(+), 69 deletions(-) diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 7da34e3f2b..31e8218743 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -44,7 +44,7 @@ class LargerThanChecker { public: explicit LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { - PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail"); + PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails."); } private: @@ -56,7 +56,7 @@ class EqualLargerThanChecker { public: explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { - PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail"); + PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fails."); } private: diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 972dac7073..499d4ecbf1 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -284,11 +284,13 @@ DDim::DDim(std::initializer_list init_list) { *this = make_ddim(init_list); } -DDim flatten_to_2d(const DDim& src, int num_row_dims) { +// Reshape a tensor to a matrix. The matrix's first dimension(column length) +// will be the product of tensor's first `num_col_dims` dimensions +DDim flatten_to_2d(const DDim& src, int num_col_dims) { int rank = src.size(); return make_ddim( - {static_cast(product(slice_ddim(src, 0, rank - num_row_dims))), - static_cast(product(slice_ddim(src, rank - num_row_dims, rank)))}); + {static_cast(product(slice_ddim(src, 0, num_col_dims))), + static_cast(product(slice_ddim(src, num_col_dims, rank)))}); } DDim flatten_to_1d(const DDim& src) { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 8f1269d9a1..2dbd5f5f70 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -115,7 +115,7 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); -DDim flatten_to_2d(const DDim& src, int num_row_dims); +DDim flatten_to_2d(const DDim& src, int num_col_dims); DDim flatten_to_1d(const DDim& src); diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index c6f42251da..4b798cd5ae 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -64,21 +64,21 @@ struct EigenTensor { template struct EigenMatrix : public EigenTensor { - static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) { + static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) { int rank = tensor.dims_.size(); - PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, - "`num_row_dims` must be between (0, rank_of_tensor)."); + PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank, + "`num_col_dims` must be between (0, rank_of_tensor)."); return EigenMatrix::From(tensor, - flatten_to_2d(tensor.dims(), num_row_dims)); + flatten_to_2d(tensor.dims(), num_col_dims)); } static typename EigenMatrix::ConstType Reshape(const Tensor& tensor, - int num_row_dims) { + int num_col_dims) { int rank = tensor.dims_.size(); - PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, - "`num_row_dims` must be between (0, rank_of_tensor)."); + PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank, + "`num_col_dims` must be between (0, rank_of_tensor)."); return EigenMatrix::From(tensor, - flatten_to_2d(tensor.dims(), num_row_dims)); + flatten_to_2d(tensor.dims(), num_col_dims)); } }; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index d32fe78f42..f1a7275899 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -149,10 +149,10 @@ inline Tensor& Tensor::Resize(const DDim& dims) { inline const DDim& Tensor::dims() const { return dims_; } template -inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { +inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { Tensor res; res.ShareDataWith(src); - res.Resize(flatten_to_2d(src.dims(), num_row_dims)); + res.Resize(flatten_to_2d(src.dims(), num_col_dims)); return res; } diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index cdd68b303c..a2c2d19dc7 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -263,7 +263,7 @@ TEST(Tensor, CopyFrom) { #endif } -TEST(Tensor, FlattenToMatrix) { +TEST(Tensor, ReshapeToMatrix) { using namespace paddle::framework; using namespace paddle::platform; Tensor src; @@ -271,7 +271,7 @@ TEST(Tensor, FlattenToMatrix) { for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { src_ptr[i] = i; } - Tensor res = FlattenToMatrix(src, 2); + Tensor res = ReshapeToMatrix(src, 2); ASSERT_EQ(res.dims()[0], 2 * 3); ASSERT_EQ(res.dims()[1], 4 * 9); } \ No newline at end of file diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index dfc22decdc..fb96d322e9 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -27,20 +27,20 @@ class MulOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); - int x_num_row_dims = GetAttr("x_num_row_dims"); - int y_num_row_dims = GetAttr("y_num_row_dims"); + int x_num_col_dims = GetAttr("x_num_col_dims"); + int y_num_col_dims = GetAttr("y_num_col_dims"); - PADDLE_ENFORCE(x_dims.size() > x_num_row_dims, + PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, "The rank of input tensor X(%s) should be larger than " - "`mul_op`'s `x_num_row_dims`.", + "`mul_op`'s `x_num_col_dims`.", ctx.op().Input("X")); - PADDLE_ENFORCE(y_dims.size() > y_num_row_dims, + PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, "The rank of input tensor Y(%s) should be larger than " - "`mul_op`'s `y_num_row_dims`.", + "`mul_op`'s `y_num_col_dims`.", ctx.op().Input("Y")); - auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims); - auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims); + auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); + auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); PADDLE_ENFORCE_EQ( x_mat_dims[1], y_mat_dims[0], @@ -57,19 +57,19 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Y", "The second input of mul op"); AddOutput("Out", "The output of mul op"); AddAttr( - "x_num_row_dims", + "x_num_col_dims", "mul_op can take tensors with more than two dimensions as input `X`, " - "in that case, tensors will be flattened to a matrix. The matrix's " - "second dimension(row length) will be the product of tensor's last " - "`num_row_dims` dimensions, and the matrix's first dimension(column " - "length) will be the product of tensor's first `rank - num_row_dims` " + "in that case, tensors will be reshaped to a matrix. The matrix's " + "first dimension(column length) will be the product of tensor's last " + "`num_col_dims` dimensions, and the matrix's second dimension(row " + "length) will be the product of tensor's first `rank - num_col_dims` " "dimensions.") .SetDefault(1) .EqualLargerThan(1); AddAttr( - "y_num_row_dims", + "y_num_col_dims", "mul_op can take tensors with more than two dimensions as input `Y`, " - "in that case, tensors will be flattened to a matrix. Just like input " + "in that case, tensors will be reshaped to a matrix. Just like input " "`X`.") .SetDefault(1) .EqualLargerThan(1); @@ -98,9 +98,9 @@ class MulOpGrad : public framework::OperatorWithKernel { auto *y_grad = ctx.Output(framework::GradVarName("Y")); auto x_mat_dims = - framework::flatten_to_2d(x_dims, GetAttr("x_num_row_dims")); + framework::flatten_to_2d(x_dims, GetAttr("x_num_col_dims")); auto y_mat_dims = - framework::flatten_to_2d(y_dims, GetAttr("y_num_row_dims")); + framework::flatten_to_2d(y_dims, GetAttr("y_num_col_dims")); PADDLE_ENFORCE_EQ( x_mat_dims[0], out_dims[0], diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 62557bb839..6656ecaf1a 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -1,14 +1,14 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - you may obtain a copy of the License at + You may not use this file except in compliance with the License. + You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANy KIND, either express or implied. + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ @@ -33,22 +33,22 @@ class MulKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { const Tensor* x = context.Input("X"); const Tensor* y = context.Input("Y"); - Tensor* Z = context.Output("Out"); + Tensor* z = context.Output("Out"); const Tensor x_matrix = x->dims().size() > 2 - ? framework::FlattenToMatrix( - *x, context.template GetAttr("x_num_row_dims")) + ? framework::ReshapeToMatrix( + *x, context.template GetAttr("x_num_col_dims")) : *x; const Tensor y_matrix = y->dims().size() > 2 - ? framework::FlattenToMatrix( - *y, context.template GetAttr("y_num_row_dims")) + ? framework::ReshapeToMatrix( + *y, context.template GetAttr("y_num_col_dims")) : *y; - Z->mutable_data(context.GetPlace()); + z->mutable_data(context.GetPlace()); auto* device_context = const_cast(context.device_context_); - math::matmul(x_matrix, false, y_matrix, false, 1, Z, 0, + math::matmul(x_matrix, false, y_matrix, false, 1, z, 0, device_context); } }; @@ -57,15 +57,15 @@ template class MulGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - int x_num_row_dims = ctx.template GetAttr("x_num_row_dims"); - int y_num_row_dims = ctx.template GetAttr("y_num_row_dims"); + int x_num_col_dims = ctx.template GetAttr("x_num_col_dims"); + int y_num_col_dims = ctx.template GetAttr("y_num_col_dims"); const Tensor* x = ctx.Input("X"); const Tensor* y = ctx.Input("Y"); const Tensor x_matrix = - x->dims().size() > 2 ? framework::FlattenToMatrix(*x, x_num_row_dims) + x->dims().size() > 2 ? framework::ReshapeToMatrix(*x, x_num_col_dims) : *x; const Tensor y_matrix = - y->dims().size() > 2 ? framework::FlattenToMatrix(*y, y_num_row_dims) + y->dims().size() > 2 ? framework::ReshapeToMatrix(*y, y_num_col_dims) : *y; const Tensor* dout = ctx.Input(framework::GradVarName("Out")); @@ -75,8 +75,8 @@ class MulGradKernel : public framework::OpKernel { const_cast(ctx.device_context_); if (dx) { dx->mutable_data(ctx.GetPlace()); - Tensor dx_matrix = dx->dims().size() > 2 ? framework::FlattenToMatrix( - *dx, x_num_row_dims) + Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix( + *dx, x_num_col_dims) : *dx; // dx = dout * y'. dx: M x K, dout : M x N, y : K x N math::matmul(*dout, false, y_matrix, true, 1, &dx_matrix, 0, @@ -84,8 +84,8 @@ class MulGradKernel : public framework::OpKernel { } if (dy) { dy->mutable_data(ctx.GetPlace()); - Tensor dy_matrix = dy->dims().size() > 2 ? framework::FlattenToMatrix( - *dy, y_num_row_dims) + Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix( + *dy, y_num_col_dims) : *dy; // dy = x' * dout. dy K x N, dout : M x N, x : M x K math::matmul(x_matrix, true, *dout, false, 1, &dy_matrix, 0, diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 209281a45b..fa8f0ff1a8 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -31,11 +31,11 @@ class RowwiseAddOp : public framework::OperatorWithKernel { x_dims.size(), b_dims.size(), "The rank of input `X` must be larger than the one of input `b`."); - int num_row_dims = b_dims.size(); + int num_col_dims = x_dims.size() - b_dims.size(); - PADDLE_ENFORCE_EQ(framework::slice_ddim( - x_dims, x_dims.size() - num_row_dims, x_dims.size()), - b_dims, "The width of two operands must be same"); + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, + "The width of two operands must be same"); PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1"); ctx.Output("Out")->Resize(x_dims); } @@ -72,10 +72,10 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { x_dims.size(), b_dims.size(), "The rank of input `X` must be larger than the one of input `b`."); - int num_row_dims = b_dims.size(); - PADDLE_ENFORCE_EQ(framework::slice_ddim( - x_dims, x_dims.size() - num_row_dims, x_dims.size()), - b_dims, "The width of two operands must be same"); + int num_col_dims = x_dims.size() - b_dims.size(); + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, + "The width of two operands must be same"); auto *dx = ctx.Output(framework::GradVarName("X")); auto *db = ctx.Output(framework::GradVarName("b")); if (dx) dx->Resize(x_dims); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index a52a53a7d2..35774b9409 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -33,11 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto out = context.Output("Out"); out->mutable_data(context.GetPlace()); - int num_row_dims = context.Input("b")->dims().size(); + int num_col_dims = context.Input("X")->dims().size() - + context.Input("b")->dims().size(); auto input = - EigenMatrix::Reshape(*context.Input("X"), num_row_dims); + EigenMatrix::Reshape(*context.Input("X"), num_col_dims); auto bias = EigenVector::Flatten(*context.Input("b")); - auto output = EigenMatrix::Reshape(*out, num_row_dims); + auto output = EigenMatrix::Reshape(*out, num_col_dims); const int bias_size = bias.dimension(0); const int rest_size = input.size() / bias_size; @@ -55,14 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel { auto* dout = context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); auto* db = context.Output(framework::GradVarName("b")); - int num_row_dims = context.Input("b")->dims().size(); + int num_col_dims = context.Input("X")->dims().size() - + context.Input("b")->dims().size(); - auto out_grad = EigenMatrix::Reshape(*dout, num_row_dims); + auto out_grad = EigenMatrix::Reshape(*dout, num_col_dims); auto place = context.GetEigenDevice(); if (dx) { dx->mutable_data(context.GetPlace()); - EigenMatrix::Reshape(*dx, num_row_dims).device(place) = out_grad; + EigenMatrix::Reshape(*dx, num_col_dims).device(place) = out_grad; } if (db) { diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index 3ea73d94b2..d8057f4ffa 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -26,7 +26,7 @@ class TestMulOp2(unittest.TestCase): 'X': np.random.random((15, 4, 12, 10)).astype("float32"), 'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32") } - self.attrs = {'x_num_row_dims': 2, 'y_num_row_dims': 3} + self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2} self.outputs = { 'Out': np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10), self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9)) @@ -69,7 +69,7 @@ class TestMulGradOp(GradientChecker): class TestMulGradTest2(GradientChecker): def setUp(self): self.op = Operator( - "mul", X="X", Y="Y", Out="Out", x_num_row_dims=2, y_num_row_dims=3) + "mul", X="X", Y="Y", Out="Out", x_num_col_dims=2, y_num_col_dims=2) self.inputs = { "X": np.random.random((15, 4, 12, 10)).astype("float32"), "Y": np.random.random((4, 30, 8, 2, 9)).astype("float32") From 3d62c6dac48863cb0685d526b7c44b80aa3d5caa Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 6 Sep 2017 11:50:48 -0700 Subject: [PATCH 07/11] Fix bug --- paddle/framework/ddim.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 0e2cd3acc8..58e20625ce 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -288,14 +288,11 @@ DDim::DDim(std::initializer_list init_list) { // will be the product of tensor's first `num_col_dims` dimensions DDim flatten_to_2d(const DDim& src, int num_col_dims) { int rank = src.size(); - return make_ddim( - {static_cast(product(slice_ddim(src, 0, num_col_dims))), - static_cast(product(slice_ddim(src, num_col_dims, rank)))}); + return make_ddim({product(slice_ddim(src, 0, num_col_dims)), + product(slice_ddim(src, num_col_dims, rank))}); } -DDim flatten_to_1d(const DDim& src) { - return make_ddim({static_cast(product(src))}); -} +DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } } // namespace framework } // namespace paddle From 5aacd64b94644562b72c920397b7907561ba6675 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 6 Sep 2017 18:44:18 -0700 Subject: [PATCH 08/11] Follow comments --- paddle/framework/eigen.h | 4 ++-- paddle/framework/eigen_test.cc | 3 +-- paddle/framework/tensor_test.cc | 2 +- paddle/operators/mul_op.cc | 26 +++++++++++++------------- paddle/operators/mul_op.h | 8 ++++---- 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 0438e758e0..54bbeafcab 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -87,11 +87,11 @@ template { // Flatten reshapes a Tensor into an EigenVector. static typename EigenVector::Type Flatten(Tensor& tensor) { - return EigenVector::From(tensor, {static_cast(product(tensor.dims_))}); + return EigenVector::From(tensor, {product(tensor.dims_)}); } static typename EigenVector::ConstType Flatten(const Tensor& tensor) { - return EigenVector::From(tensor, {static_cast(product(tensor.dims_))}); + return EigenVector::From(tensor, {product(tensor.dims_)}); } }; diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index bae82fdb7d..bc4a2db32c 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -110,8 +110,7 @@ TEST(Eigen, Matrix) { TEST(Eigen, MatrixReshape) { Tensor t; - float* p = - t.mutable_data(make_ddim({2, 3, 6, 4}), platform::CPUPlace()); + float* p = t.mutable_data({2, 3, 6, 4}, platform::CPUPlace()); for (int i = 0; i < 2 * 3 * 6 * 4; ++i) { p[i] = static_cast(i); } diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index a2c2d19dc7..55302ea471 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -267,7 +267,7 @@ TEST(Tensor, ReshapeToMatrix) { using namespace paddle::framework; using namespace paddle::platform; Tensor src; - int* src_ptr = src.mutable_data(make_ddim({2, 3, 4, 9}), CPUPlace()); + int* src_ptr = src.mutable_data({2, 3, 4, 9}, CPUPlace()); for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { src_ptr[i] = i; } diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index fb96d322e9..34595adedd 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -27,8 +27,8 @@ class MulOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); - int x_num_col_dims = GetAttr("x_num_col_dims"); - int y_num_col_dims = GetAttr("y_num_col_dims"); + int x_num_col_dims = Attr("x_num_col_dims"); + int y_num_col_dims = Attr("y_num_col_dims"); PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, "The rank of input tensor X(%s) should be larger than " @@ -58,19 +58,19 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output of mul op"); AddAttr( "x_num_col_dims", - "mul_op can take tensors with more than two dimensions as input `X`, " - "in that case, tensors will be reshaped to a matrix. The matrix's " - "first dimension(column length) will be the product of tensor's last " - "`num_col_dims` dimensions, and the matrix's second dimension(row " - "length) will be the product of tensor's first `rank - num_col_dims` " - "dimensions.") + R"DOC(mul_op can take tensors with more than two dimensions as input `X`, + in that case, tensors will be reshaped to a matrix. The matrix's first + dimension(column length) will be the product of tensor's last + `num_col_dims` dimensions, and the matrix's second dimension(row length) + will be the product of tensor's first `rank - num_col_dims` dimensions. + )DOC") .SetDefault(1) .EqualLargerThan(1); AddAttr( "y_num_col_dims", - "mul_op can take tensors with more than two dimensions as input `Y`, " - "in that case, tensors will be reshaped to a matrix. Just like input " - "`X`.") + R"DOC(mul_op can take tensors with more than two dimensions as input `Y`, + in that case, tensors will be reshaped to a matrix. Just like input `X`. + )DOC") .SetDefault(1) .EqualLargerThan(1); AddComment(R"DOC( @@ -98,9 +98,9 @@ class MulOpGrad : public framework::OperatorWithKernel { auto *y_grad = ctx.Output(framework::GradVarName("Y")); auto x_mat_dims = - framework::flatten_to_2d(x_dims, GetAttr("x_num_col_dims")); + framework::flatten_to_2d(x_dims, Attr("x_num_col_dims")); auto y_mat_dims = - framework::flatten_to_2d(y_dims, GetAttr("y_num_col_dims")); + framework::flatten_to_2d(y_dims, Attr("y_num_col_dims")); PADDLE_ENFORCE_EQ( x_mat_dims[0], out_dims[0], diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 6656ecaf1a..3c01f868bd 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -37,12 +37,12 @@ class MulKernel : public framework::OpKernel { const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix( - *x, context.template GetAttr("x_num_col_dims")) + *x, context.template Attr("x_num_col_dims")) : *x; const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix( - *y, context.template GetAttr("y_num_col_dims")) + *y, context.template Attr("y_num_col_dims")) : *y; z->mutable_data(context.GetPlace()); @@ -57,8 +57,8 @@ template class MulGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - int x_num_col_dims = ctx.template GetAttr("x_num_col_dims"); - int y_num_col_dims = ctx.template GetAttr("y_num_col_dims"); + int x_num_col_dims = ctx.template Attr("x_num_col_dims"); + int y_num_col_dims = ctx.template Attr("y_num_col_dims"); const Tensor* x = ctx.Input("X"); const Tensor* y = ctx.Input("Y"); const Tensor x_matrix = From b7444306ba498cef508f90565a96661ebbe2ea3d Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 6 Sep 2017 19:30:06 -0700 Subject: [PATCH 09/11] Follow comments --- paddle/framework/attribute.h | 10 +++++----- paddle/operators/mul_op.cc | 4 ++-- python/paddle/v2/framework/tests/test_mul_op.py | 2 -- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 6968ffd838..2b788a76ca 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -53,11 +53,11 @@ class GreaterThanChecker { }; template -class EqualLargerThanChecker { +class EqualGreaterThanChecker { public: - explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} + explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { - PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fails."); + PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails."); } private: @@ -127,8 +127,8 @@ class TypedAttrChecker { return *this; } - TypedAttrChecker& EqualLargerThan(const T& lower_bound) { - value_checkers_.push_back(EqualLargerThanChecker(lower_bound)); + TypedAttrChecker& EqualGreaterThan(const T& lower_bound) { + value_checkers_.push_back(EqualGreaterThanChecker(lower_bound)); return *this; } diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 34595adedd..710a56a0e8 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -65,14 +65,14 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { will be the product of tensor's first `rank - num_col_dims` dimensions. )DOC") .SetDefault(1) - .EqualLargerThan(1); + .EqualGreaterThan(1); AddAttr( "y_num_col_dims", R"DOC(mul_op can take tensors with more than two dimensions as input `Y`, in that case, tensors will be reshaped to a matrix. Just like input `X`. )DOC") .SetDefault(1) - .EqualLargerThan(1); + .EqualGreaterThan(1); AddComment(R"DOC( Two Element Mul Operator. diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index d8057f4ffa..8c827e242e 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -99,7 +99,5 @@ class TestMulGradTest2(GradientChecker): no_grad_set={"Y"}) -# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library - if __name__ == '__main__': unittest.main() From 1d9a4d2e500416c2ba408174884a1d9102e45dae Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 6 Sep 2017 22:12:02 -0700 Subject: [PATCH 10/11] Move some comments to .h file --- paddle/framework/ddim.cc | 2 -- paddle/framework/ddim.h | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 58e20625ce..fc3d508553 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -284,8 +284,6 @@ DDim::DDim(std::initializer_list init_list) { *this = make_ddim(init_list); } -// Reshape a tensor to a matrix. The matrix's first dimension(column length) -// will be the product of tensor's first `num_col_dims` dimensions DDim flatten_to_2d(const DDim& src, int num_col_dims) { int rank = src.size(); return make_ddim({product(slice_ddim(src, 0, num_col_dims)), diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 1e7ca46bd9..48e14d16e3 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -115,6 +115,8 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); +// Reshape a tensor to a matrix. The matrix's first dimension(column length) +// will be the product of tensor's first `num_col_dims` dimensions DDim flatten_to_2d(const DDim& src, int num_col_dims); DDim flatten_to_1d(const DDim& src); From b6a46667de43325c38668c755189237b1b8a64d3 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 7 Sep 2017 14:37:29 -0700 Subject: [PATCH 11/11] test --- paddle/framework/ddim.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 48e14d16e3..ca29e7e8c7 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -116,7 +116,7 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); // Reshape a tensor to a matrix. The matrix's first dimension(column length) -// will be the product of tensor's first `num_col_dims` dimensions +// will be the product of tensor's first `num_col_dims` dimensions. DDim flatten_to_2d(const DDim& src, int num_col_dims); DDim flatten_to_1d(const DDim& src);