From b7444306ba498cef508f90565a96661ebbe2ea3d Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 6 Sep 2017 19:30:06 -0700 Subject: [PATCH] Follow comments --- paddle/framework/attribute.h | 10 +++++----- paddle/operators/mul_op.cc | 4 ++-- python/paddle/v2/framework/tests/test_mul_op.py | 2 -- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 6968ffd838..2b788a76ca 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -53,11 +53,11 @@ class GreaterThanChecker { }; template -class EqualLargerThanChecker { +class EqualGreaterThanChecker { public: - explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} + explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { - PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fails."); + PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails."); } private: @@ -127,8 +127,8 @@ class TypedAttrChecker { return *this; } - TypedAttrChecker& EqualLargerThan(const T& lower_bound) { - value_checkers_.push_back(EqualLargerThanChecker(lower_bound)); + TypedAttrChecker& EqualGreaterThan(const T& lower_bound) { + value_checkers_.push_back(EqualGreaterThanChecker(lower_bound)); return *this; } diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 34595adedd..710a56a0e8 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -65,14 +65,14 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { will be the product of tensor's first `rank - num_col_dims` dimensions. )DOC") .SetDefault(1) - .EqualLargerThan(1); + .EqualGreaterThan(1); AddAttr( "y_num_col_dims", R"DOC(mul_op can take tensors with more than two dimensions as input `Y`, in that case, tensors will be reshaped to a matrix. Just like input `X`. )DOC") .SetDefault(1) - .EqualLargerThan(1); + .EqualGreaterThan(1); AddComment(R"DOC( Two Element Mul Operator. diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index d8057f4ffa..8c827e242e 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -99,7 +99,5 @@ class TestMulGradTest2(GradientChecker): no_grad_set={"Y"}) -# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library - if __name__ == '__main__': unittest.main()