From 3994e91a678b8547af77b6b7f4629f122b0d9f07 Mon Sep 17 00:00:00 2001 From: guosheng Date: Fri, 8 Sep 2017 18:39:01 +0800 Subject: [PATCH 1/7] Add reduce_op --- paddle/operators/reduce_op.cc | 207 +++++++++++++++ paddle/operators/reduce_op.cu | 46 ++++ paddle/operators/reduce_op.h | 251 ++++++++++++++++++ .../v2/framework/tests/test_reduce_op.py | 92 +++++++ 4 files changed, 596 insertions(+) create mode 100644 paddle/operators/reduce_op.cc create mode 100644 paddle/operators/reduce_op.cu create mode 100644 paddle/operators/reduce_op.h create mode 100644 python/paddle/v2/framework/tests/test_reduce_op.py diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc new file mode 100644 index 0000000000..ea4bfc50b2 --- /dev/null +++ b/paddle/operators/reduce_op.cc @@ -0,0 +1,207 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/reduce_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +using framework::DDim; + +class ReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto x_rank = x_dims.size(); + PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported"); + int dim = static_cast(ctx.Attr("dim")); + if (dim < 0) dim = x_rank + dim; + PADDLE_ENFORCE_LT( + dim, x_rank, + "The dim should be in the range [-rank(input), rank(input)]"); + bool keep_dim = true; // TODO; + auto dims_vector = vectorize(x_dims); + if (keep_dim || x_rank == 1) { + dims_vector[dim] = 1; + } else { + dims_vector.erase(dims_vector.begin() + dim); + } + auto out_dims = framework::make_ddim(dims_vector); + ctx.Output("Out")->Resize(out_dims); + } +}; + +class ReduceGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto x_rank = x_dims.size(); + PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported"); + int dim = static_cast(ctx.Attr("dim")); + if (dim < 0) dim = x_rank + dim; + PADDLE_ENFORCE_LT( + dim, x_rank, + "The dim should be in the range [-rank(input), rank(input)]"); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + if (x_grad) x_grad->Resize(x_dims); + } +}; + +class ReduceSumOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReduceSumOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); + AddOutput("Out", "(Tensor) The result tensor."); + AddComment(R"DOC( +ReduceMean operator computes the sum of input tensor along the given dimension. +The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. +)DOC"); + AddAttr("dim", + "(int, default 0) The dimension to reduce. " + "Must be in the range [-rank(input), rank(input)]") + .SetDefault(0); + AddAttr("keep_dim", + "(bool, default fasle) " + "If true, retain the reduced dimension with length 1.") + .SetDefault(false); + } +}; + +class ReduceMeanOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReduceMeanOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); + AddOutput("Out", "(Tensor) The result tensor."); + AddComment(R"DOC( +ReduceMean operator computes the mean of input tensor along the given dimension. +The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. +)DOC"); + AddAttr("dim", + "(int, default 0) The dimension to reduce. " + "Must be in the range [-rank(input), rank(input)]") + .SetDefault(0); + AddAttr("keep_dim", + "(bool, default fasle) " + "If true, retain the reduced dimension with length 1.") + .SetDefault(false); + } +}; + +class ReduceMaxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReduceMaxOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); + AddOutput("Out", "(Tensor) The result tensor."); + AddComment(R"DOC( +ReduceMax operator computes the maximum of input tensor along the given dimension. +The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. +)DOC"); + AddAttr("dim", + "(int, default 0) The dimension to reduce. " + "Must be in the range [-rank(input), rank(input)]") + .SetDefault(0); + AddAttr("keep_dim", + "(bool, default fasle) " + "If true, retain the reduced dimension with length 1.") + .SetDefault(false); + } +}; + +class ReduceMinOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReduceMinOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); + AddOutput("Out", "(Tensor) The result tensor."); + AddComment(R"DOC( +ReduceMin operator computes the minimum of input tensor along the given dimension. +The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. +)DOC"); + AddAttr("dim", + "(int, default 0) The dimension to reduce. " + "Must be in the range [-rank(input), rank(input)]") + .SetDefault(0); + AddAttr("keep_dim", + "(bool, default fasle) " + "If true, retain the reduced dimension with length 1.") + .SetDefault(false); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad, + ops::ReduceGradOp); +REGISTER_OP_CPU_KERNEL( + reduce_sum, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_sum_grad, + ops::ReduceGradKernel); + +REGISTER_OP(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker, + reduce_mean_grad, ops::ReduceGradOp); +REGISTER_OP_CPU_KERNEL( + reduce_mean, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_mean_grad, + ops::ReduceGradKernel); + +REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad, + ops::ReduceGradOp); +REGISTER_OP_CPU_KERNEL( + reduce_max, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_max_grad, + ops::ReduceGradKernel); + +REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_min_grad, + ops::ReduceGradOp); +REGISTER_OP_CPU_KERNEL( + reduce_min, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_min_grad, + ops::ReduceGradKernel); diff --git a/paddle/operators/reduce_op.cu b/paddle/operators/reduce_op.cu new file mode 100644 index 0000000000..9effc17ed3 --- /dev/null +++ b/paddle/operators/reduce_op.cu @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/reduce_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + reduce_sum, + ops::ReduceKernel); +REGISTER_OP_GPU_KERNEL(reduce_sum_grad, + ops::ReduceGradEigenKernel); + +REGISTER_OP_GPU_KERNEL( + reduce_mean, + ops::ReduceKernel); +REGISTER_OP_GPU_KERNEL(reduce_mean_grad, + ops::ReduceGradKernel); + +REGISTER_OP_GPU_KERNEL( + reduce_max, + ops::ReduceKernel); +REGISTER_OP_GPU_KERNEL(reduce_max_grad, + ops::ReduceGradKernel); + +REGISTER_OP_GPU_KERNEL( + reduce_min, + ops::ReduceKernel); +REGISTER_OP_GPU_KERNEL(reduce_min_grad, + ops::ReduceGradKernel); \ No newline at end of file diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h new file mode 100644 index 0000000000..9fd7d335ac --- /dev/null +++ b/paddle/operators/reduce_op.h @@ -0,0 +1,251 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/operators/math/math_function.h" + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; +template +using EigenTensor = framework::EigenTensor; + +struct SumFunctor { + template + void operator()(const Place& place, In& in, Out& out, const Dim& dim) { + out.device(place) = in.sum(dim); + } +}; + +struct SumGradFunctor { + template + void operator()(const Place& place, In_Const& in, In& in_grad, Out& out, + Out& out_grad, const Dim& dim, int size) { + in_grad.device(place) = out_grad.broadcast(dim); + } +}; + +struct MeanFunctor { + template + void operator()(const Place& place, In& in, Out& out, const Dim& dim) { + out.device(place) = in.mean(dim); + } +}; + +struct MeanGradFunctor { + template + void operator()(const Place& place, In_Const& in, In& in_grad, Out& out, + Out& out_grad, const Dim& dim, int size) { + in_grad.device(place) = out_grad.broadcast(dim) / in_grad.constant(size); + } +}; + +struct MaxFunctor { + template + void operator()(const Place& place, In& in, Out& out, const Dim& dim) { + out.device(place) = in.maximum(dim); + } +}; + +struct MinFunctor { + template + void operator()(const Place& place, In& in, Out& out, const Dim& dim) { + out.device(place) = in.minimum(dim); + } +}; + +struct MaxOrMinGradFunctor { + template + void operator()(const Place& place, In_Const& in, In& in_grad, Out& out, + Out& out_grad, const Dim& dim, int size) { + auto equals = in == out.broadcast(dim); + auto ones = in_grad.constant(1); + auto zeros = in_grad.constant(0); + in_grad.device(place) = + out_grad.broadcast(dim) * equals.select(ones, zeros); + } +}; + +template +class ReduceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + int rank = context.Input("X")->dims().size(); + switch (rank) { + case 1: + ReduceCompute<1>(context); + break; + case 2: + ReduceCompute<2>(context); + break; + case 3: + ReduceCompute<3>(context); + break; + case 4: + ReduceCompute<4>(context); + break; + case 5: + ReduceCompute<5>(context); + break; + case 6: + ReduceCompute<6>(context); + break; + } + } + + private: + template + void ReduceCompute(const framework::ExecutionContext& context) const { + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + + auto x = EigenTensor::From(*input); + auto x_rank = static_cast(x.dimensions().size()); + int dim = static_cast(context.Attr("dim")); + if (dim < 0) dim = x_rank + dim; + auto reduce_dim = Eigen::array({{dim}}); + // construct the squeezed output tensor + bool keep_dim = true; // static_cast(context.Attr("keep_dim")); + DDim dims = output->dims(); + auto dims_vector = vectorize(dims); + if (keep_dim && x_rank > 1) { + dims_vector.erase(dims_vector.begin() + dim); + dims = framework::make_ddim(dims_vector); + } + auto out = EigenTensor < T, D == 1 ? 1 : (D - 1) > ::From(*output, dims); + auto& place = context.GetEigenDevice(); + Functor functor; + functor(place, x, out, reduce_dim); + } +}; + +template +class ReduceGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + int rank = context.Input("X")->dims().size(); + switch (rank) { + case 1: + ReduceCompute<1>(context); + break; + case 2: + ReduceCompute<2>(context); + break; + case 3: + ReduceCompute<3>(context); + break; + case 4: + ReduceCompute<4>(context); + break; + case 5: + ReduceCompute<5>(context); + break; + case 6: + ReduceCompute<6>(context); + break; + } + } + + private: + template + void ReduceCompute(const framework::ExecutionContext& context) const { + auto* input0 = context.Input("X"); + auto* input1 = context.Input("Out"); + auto* input2 = context.Input(framework::GradVarName("Out")); + auto* output = context.Output(framework::GradVarName("X")); + + if (output != nullptr) { + output->mutable_data(context.GetPlace()); + auto x = EigenTensor::From(*input0); + auto x_grad = EigenTensor::From(*output); + auto x_rank = static_cast(x.dimensions().size()); + int dim = static_cast(context.Attr("dim")); + if (dim < 0) dim = x_rank + dim; + DDim dims = input0->dims(); + dims[dim] = 1; + auto x_reduce = EigenTensor::From(*input1, dims); + auto x_reduce_grad = EigenTensor::From(*input2, dims); + + Eigen::array braodcast_dim; + for (size_t i = 0; i < D; ++i) braodcast_dim[i] = 1; + braodcast_dim[dim] = input0->dims()[dim]; + auto& place = context.GetEigenDevice(); + Functor functor; + functor(place, x, x_grad, x_reduce, x_reduce_grad, braodcast_dim, + braodcast_dim[dim]); + } + } +}; + +// For EigenTensor unsupported reduce +template +class ReduceGradEigenFreeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Input("Out"); + auto* x_grad = context.Output(framework::GradVarName("X")); + auto* out_grad = context.Input(framework::GradVarName("Out")); + if (x_grad != nullptr) { + DDim dims = x->dims(); + int rank = dims.size(); + int dim = static_cast(context.Attr("dim")); + if (dim < 0) dim = rank + dim; + + auto* x_data = x->data(); + auto* x_grad_data = x_grad->mutable_data(context.GetPlace()); + auto* out_data = out->data(); + auto* out_grad_data = out_grad->data(); + + int outer_count = 1; + int inner_count = 1; + int mid_count = dims[dim]; + for (int i = 0; i < dim; ++i) { + outer_count *= dims[i]; + } + for (int i = dim + 1; i < rank; ++i) { + inner_count *= dims[i]; + } + + int x_offset = 0; // offset on raw data + int out_offset = 0; // offset on reduced data + Functor functor; + for (int i = 0; i < outer_count; ++i) { + for (int j = 0; j < inner_count; ++j) { + out_offset = inner_count * i + j; + for (int k = 0; k < mid_count; ++k) { + x_offset = (inner_count * mid_count) * i + inner_count * k + j; + functor(x_data + x_offset, x_grad_data + x_offset, + out_data + out_offset, out_grad_data + out_offset, + mid_count); + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_reduce_op.py b/python/paddle/v2/framework/tests/test_reduce_op.py new file mode 100644 index 0000000000..49ef8eabd2 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_reduce_op.py @@ -0,0 +1,92 @@ +import unittest +import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta +from paddle.v2.framework.op import Operator + + +class TestSumOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "reduce_sum" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {'dim': -2} + out = self.inputs['X'].sum(axis=self.attrs['dim']) + self.outputs = {'Out': out} + + +class TestSumGradOp(GradientChecker): + def test_normal(self): + op = Operator("reduce_sum", X="X", Out="Out", dim=-2) + # use small size to decrease the error of numerical calculation + inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.check_grad(op, inputs, set(["X"]), "Out") + + def test_1d_tensor(self): + op = Operator("reduce_sum", X="X", Out="Out", dim=0) + # use small size to decrease the error of numerical calculation + inputs = {'X': np.random.random(10).astype("float32")} + self.check_grad(op, inputs, set(["X"]), "Out") + + +class TestKeepdimSumOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "reduce_sum" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {'dim': -2} + out = self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True) + self.outputs = {'Out': out} + + +class TestMeanOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "reduce_mean" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {'dim': -1} + out = self.inputs['X'].mean(axis=self.attrs['dim']) + self.outputs = {'Out': out} + + +class TestMeanGradOp(GradientChecker): + def test_normal(self): + op = Operator("reduce_mean", X="X", Out="Out", dim=-2) + # use small size to decrease the error of numerical calculation + inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.check_grad(op, inputs, set(["X"]), "Out") + + def test_1d_tensor(self): + op = Operator("reduce_mean", X="X", Out="Out", dim=0) + # use small size to decrease the error of numerical calculation + inputs = {'X': np.random.random(10).astype("float32")} + self.check_grad(op, inputs, set(["X"]), "Out") + + +class TestMaxOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "reduce_max" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {'dim': -1} + out = self.inputs['X'].max(axis=self.attrs['dim']) + self.outputs = {'Out': out} + + +class TestMinOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "reduce_max" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {'dim': -2} + out = self.inputs['X'].min(axis=self.attrs['dim']) + self.outputs = {'Out': out} + + +if __name__ == '__main__': + unittest.main() From c8d877195b9763ec2da9eb480bb6858cee834359 Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 14 Sep 2017 01:11:31 +0800 Subject: [PATCH 2/7] Revise the reduce_op unit test accordingly --- paddle/operators/reduce_op.cc | 56 +++++---- paddle/operators/reduce_op.cu | 4 +- paddle/operators/reduce_op.h | 2 +- .../v2/framework/tests/test_reduce_op.py | 113 +++++++++--------- 4 files changed, 89 insertions(+), 86 deletions(-) diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index ea4bfc50b2..20e6319730 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -30,12 +30,14 @@ class ReduceOp : public framework::OperatorWithKernel { auto x_dims = ctx.Input("X")->dims(); auto x_rank = x_dims.size(); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported"); - int dim = static_cast(ctx.Attr("dim")); + int dim = ctx.Attr("dim"); if (dim < 0) dim = x_rank + dim; PADDLE_ENFORCE_LT( dim, x_rank, - "The dim should be in the range [-rank(input), rank(input)]"); - bool keep_dim = true; // TODO; + "The dim should be in the range [-rank(input), rank(input))"); + PADDLE_ENFORCE_GE(ctx.Attr("keep_dim"), 0, "keep_dim must be 0 or 1"); + PADDLE_ENFORCE_LE(ctx.Attr("keep_dim"), 1, "keep_dim must be 0 or 1"); + bool keep_dim = ctx.Attr("keep_dim") == 1; auto dims_vector = vectorize(x_dims); if (keep_dim || x_rank == 1) { dims_vector[dim] = 1; @@ -59,11 +61,11 @@ class ReduceGradOp : public framework::OperatorWithKernel { auto x_dims = ctx.Input("X")->dims(); auto x_rank = x_dims.size(); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported"); - int dim = static_cast(ctx.Attr("dim")); + int dim = ctx.Attr("dim"); if (dim < 0) dim = x_rank + dim; PADDLE_ENFORCE_LT( dim, x_rank, - "The dim should be in the range [-rank(input), rank(input)]"); + "The dim should be in the range [-rank(input), rank(input))"); auto *x_grad = ctx.Output(framework::GradVarName("X")); if (x_grad) x_grad->Resize(x_dims); } @@ -84,12 +86,13 @@ The result tensor has 1 fewer dimension than the input unless `keep_dim` is true )DOC"); AddAttr("dim", "(int, default 0) The dimension to reduce. " - "Must be in the range [-rank(input), rank(input)]") + "Must be in the range [-rank(input), rank(input))") + .SetDefault(0); + AddAttr( + "keep_dim", + "(int, default 0) " + "Must be 0 or 1. If 1, retain the reduced dimension with length 1.") .SetDefault(0); - AddAttr("keep_dim", - "(bool, default fasle) " - "If true, retain the reduced dimension with length 1.") - .SetDefault(false); } }; @@ -108,12 +111,13 @@ The result tensor has 1 fewer dimension than the input unless `keep_dim` is true )DOC"); AddAttr("dim", "(int, default 0) The dimension to reduce. " - "Must be in the range [-rank(input), rank(input)]") + "Must be in the range [-rank(input), rank(input))") + .SetDefault(0); + AddAttr( + "keep_dim", + "(int, default 0) " + "Must be 0 or 1. If 1, retain the reduced dimension with length 1.") .SetDefault(0); - AddAttr("keep_dim", - "(bool, default fasle) " - "If true, retain the reduced dimension with length 1.") - .SetDefault(false); } }; @@ -132,12 +136,13 @@ The result tensor has 1 fewer dimension than the input unless `keep_dim` is true )DOC"); AddAttr("dim", "(int, default 0) The dimension to reduce. " - "Must be in the range [-rank(input), rank(input)]") + "Must be in the range [-rank(input), rank(input))") + .SetDefault(0); + AddAttr( + "keep_dim", + "(int, default 0) " + "Must be 0 or 1. If 1, retain the reduced dimension with length 1.") .SetDefault(0); - AddAttr("keep_dim", - "(bool, default fasle) " - "If true, retain the reduced dimension with length 1.") - .SetDefault(false); } }; @@ -156,12 +161,13 @@ The result tensor has 1 fewer dimension than the input unless `keep_dim` is true )DOC"); AddAttr("dim", "(int, default 0) The dimension to reduce. " - "Must be in the range [-rank(input), rank(input)]") + "Must be in the range [-rank(input), rank(input))") + .SetDefault(0); + AddAttr( + "keep_dim", + "(int, default 0) " + "Must be 0 or 1. If 1, retain the reduced dimension with length 1.") .SetDefault(0); - AddAttr("keep_dim", - "(bool, default fasle) " - "If true, retain the reduced dimension with length 1.") - .SetDefault(false); } }; diff --git a/paddle/operators/reduce_op.cu b/paddle/operators/reduce_op.cu index 9effc17ed3..2dffea3a3a 100644 --- a/paddle/operators/reduce_op.cu +++ b/paddle/operators/reduce_op.cu @@ -21,8 +21,8 @@ REGISTER_OP_GPU_KERNEL( reduce_sum, ops::ReduceKernel); REGISTER_OP_GPU_KERNEL(reduce_sum_grad, - ops::ReduceGradEigenKernel); + ops::ReduceGradKernel); REGISTER_OP_GPU_KERNEL( reduce_mean, diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h index 9fd7d335ac..0d62fa7d15 100644 --- a/paddle/operators/reduce_op.h +++ b/paddle/operators/reduce_op.h @@ -127,7 +127,7 @@ class ReduceKernel : public framework::OpKernel { if (dim < 0) dim = x_rank + dim; auto reduce_dim = Eigen::array({{dim}}); // construct the squeezed output tensor - bool keep_dim = true; // static_cast(context.Attr("keep_dim")); + bool keep_dim = context.Attr("keep_dim") == 1; DDim dims = output->dims(); auto dims_vector = vectorize(dims); if (keep_dim && x_rank > 1) { diff --git a/python/paddle/v2/framework/tests/test_reduce_op.py b/python/paddle/v2/framework/tests/test_reduce_op.py index 49ef8eabd2..58951f2902 100644 --- a/python/paddle/v2/framework/tests/test_reduce_op.py +++ b/python/paddle/v2/framework/tests/test_reduce_op.py @@ -1,91 +1,88 @@ import unittest import numpy as np -from gradient_checker import GradientChecker, create_op -from op_test_util import OpTestMeta -from paddle.v2.framework.op import Operator +from op_test import OpTest -class TestSumOp(unittest.TestCase): - __metaclass__ = OpTestMeta - +class TestSumOp(OpTest): def setUp(self): - self.type = "reduce_sum" + self.op_type = "reduce_sum" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.attrs = {'dim': -2} - out = self.inputs['X'].sum(axis=self.attrs['dim']) - self.outputs = {'Out': out} + self.outputs = {'Out': self.inputs['X'].sum(axis=0)} + def test_check_output(self): + self.check_output() -class TestSumGradOp(GradientChecker): - def test_normal(self): - op = Operator("reduce_sum", X="X", Out="Out", dim=-2) - # use small size to decrease the error of numerical calculation - inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.check_grad(op, inputs, set(["X"]), "Out") + def test_check_grad(self): + self.check_grad(['X'], 'Out') - def test_1d_tensor(self): - op = Operator("reduce_sum", X="X", Out="Out", dim=0) - # use small size to decrease the error of numerical calculation - inputs = {'X': np.random.random(10).astype("float32")} - self.check_grad(op, inputs, set(["X"]), "Out") +class TestMeanOp(OpTest): + def setUp(self): + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")} + self.attrs = {'dim': 1} + self.outputs = {'Out': self.inputs['X'].mean(axis=self.attrs['dim'])} -class TestKeepdimSumOp(unittest.TestCase): - __metaclass__ = OpTestMeta + def test_check_output(self): + self.check_output() - def setUp(self): - self.type = "reduce_sum" - self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.attrs = {'dim': -2} - out = self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True) - self.outputs = {'Out': out} + def test_check_grad(self): + self.check_grad(['X'], 'Out') -class TestMeanOp(unittest.TestCase): - __metaclass__ = OpTestMeta +class TestMaxOp(OpTest): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): - self.type = "reduce_mean" + self.op_type = "reduce_max" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} self.attrs = {'dim': -1} - out = self.inputs['X'].mean(axis=self.attrs['dim']) - self.outputs = {'Out': out} + self.outputs = {'Out': self.inputs['X'].max(axis=self.attrs['dim'])} + + def test_check_output(self): + self.check_output() -class TestMeanGradOp(GradientChecker): - def test_normal(self): - op = Operator("reduce_mean", X="X", Out="Out", dim=-2) - # use small size to decrease the error of numerical calculation - inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.check_grad(op, inputs, set(["X"]), "Out") +class TestMinOp(OpTest): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" - def test_1d_tensor(self): - op = Operator("reduce_mean", X="X", Out="Out", dim=0) - # use small size to decrease the error of numerical calculation - inputs = {'X': np.random.random(10).astype("float32")} - self.check_grad(op, inputs, set(["X"]), "Out") + def setUp(self): + self.op_type = "reduce_min" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {'dim': 2} + self.outputs = {'Out': self.inputs['X'].min(axis=self.attrs['dim'])} + def test_check_output(self): + self.check_output() -class TestMaxOp(unittest.TestCase): - __metaclass__ = OpTestMeta +class TestKeepDimReduce(OpTest): def setUp(self): - self.type = "reduce_max" + self.op_type = "reduce_sum" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.attrs = {'dim': -1} - out = self.inputs['X'].max(axis=self.attrs['dim']) - self.outputs = {'Out': out} + self.attrs = {'dim': -2, 'keep_dim': 1} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True) + } + + def test_check_output(self): + self.check_output() + def test_check_grad(self): + self.check_grad(['X'], 'Out') -class TestMinOp(unittest.TestCase): - __metaclass__ = OpTestMeta +class Test1DReduce(OpTest): def setUp(self): - self.type = "reduce_max" - self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.attrs = {'dim': -2} - out = self.inputs['X'].min(axis=self.attrs['dim']) - self.outputs = {'Out': out} + self.op_type = "reduce_sum" + self.inputs = {'X': np.random.random(20).astype("float32")} + self.outputs = {'Out': self.inputs['X'].sum(axis=0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') if __name__ == '__main__': From 630273d45361c7832d1dabbd9e44c4ae6cdb3864 Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 14 Sep 2017 15:21:29 +0800 Subject: [PATCH 3/7] Fix reduce_op according to CI log --- paddle/operators/reduce_op.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h index 0d62fa7d15..f0d4e1f95c 100644 --- a/paddle/operators/reduce_op.h +++ b/paddle/operators/reduce_op.h @@ -14,8 +14,6 @@ #pragma once -#include "paddle/operators/math/math_function.h" - #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" From 8b3bf28c6b5da73d919b0414361473bee638f414 Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 21 Sep 2017 11:12:32 +0800 Subject: [PATCH 4/7] Refine reduce_op and follow comments --- paddle/operators/CMakeLists.txt | 7 ++ paddle/operators/reduce_op.cc | 147 ++++++++++++++------------------ paddle/operators/reduce_op.h | 63 +++++++------- 3 files changed, 103 insertions(+), 114 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f8b0bce681..eec0d0b595 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -61,6 +61,13 @@ function(op_library TARGET) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(sigmoid);\n") endif() + + # reduce_op contains several operators + if ("${TARGET}" STREQUAL "reduce_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n") + endif() # pybind USE_NO_KERNEL_OP file(READ ${TARGET}.cc TARGET_CONTENT) diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index 20e6319730..89f54fe74b 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -18,7 +18,7 @@ namespace paddle { namespace operators { using framework::Tensor; -using framework::DDim; +using framework::LoDTensor; class ReduceOp : public framework::OperatorWithKernel { public: @@ -26,18 +26,19 @@ class ReduceOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of ReduceOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ReduceOp should not be null."); auto x_dims = ctx.Input("X")->dims(); auto x_rank = x_dims.size(); - PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported"); + PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); int dim = ctx.Attr("dim"); if (dim < 0) dim = x_rank + dim; PADDLE_ENFORCE_LT( dim, x_rank, - "The dim should be in the range [-rank(input), rank(input))"); - PADDLE_ENFORCE_GE(ctx.Attr("keep_dim"), 0, "keep_dim must be 0 or 1"); - PADDLE_ENFORCE_LE(ctx.Attr("keep_dim"), 1, "keep_dim must be 0 or 1"); - bool keep_dim = ctx.Attr("keep_dim") == 1; + "The dim should be in the range [-rank(input), rank(input))."); + bool keep_dim = ctx.Attr("keep_dim"); auto dims_vector = vectorize(x_dims); if (keep_dim || x_rank == 1) { dims_vector[dim] = 1; @@ -45,7 +46,7 @@ class ReduceOp : public framework::OperatorWithKernel { dims_vector.erase(dims_vector.begin() + dim); } auto out_dims = framework::make_ddim(dims_vector); - ctx.Output("Out")->Resize(out_dims); + ctx.Output("Out")->Resize(out_dims); } }; @@ -55,119 +56,101 @@ class ReduceGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); + "Input(Out@GRAD) should not be null."); auto x_dims = ctx.Input("X")->dims(); auto x_rank = x_dims.size(); - PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported"); + PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); int dim = ctx.Attr("dim"); if (dim < 0) dim = x_rank + dim; PADDLE_ENFORCE_LT( dim, x_rank, - "The dim should be in the range [-rank(input), rank(input))"); - auto *x_grad = ctx.Output(framework::GradVarName("X")); + "The dim should be in the range [-rank(input), rank(input))."); + auto *x_grad = + ctx.Output(framework::GradVarName("X")); if (x_grad) x_grad->Resize(x_dims); } }; -class ReduceSumOpMaker : public framework::OpProtoAndCheckerMaker { +class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { public: - ReduceSumOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + ReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "X", "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); AddOutput("Out", "(Tensor) The result tensor."); - AddComment(R"DOC( -ReduceMean operator computes the sum of input tensor along the given dimension. -The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. -)DOC"); AddAttr("dim", "(int, default 0) The dimension to reduce. " "Must be in the range [-rank(input), rank(input))") .SetDefault(0); - AddAttr( - "keep_dim", - "(int, default 0) " - "Must be 0 or 1. If 1, retain the reduced dimension with length 1.") - .SetDefault(0); + AddAttr("keep_dim", + "(bool, default false) " + "If true, retain the reduced dimension with length 1.") + .SetDefault(false); + comment_ = R"DOC( +{ReduceOP} operator computes the {reduce} of input tensor along the given dimension. +The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. +)DOC"; + AddComment(comment_); + } + + protected: + std::string comment_; + + void Replace(std::string &src, std::string from, std::string to) { + std::size_t len_from = std::strlen(from.c_str()); + std::size_t len_to = std::strlen(to.c_str()); + for (std::size_t pos = src.find(from); pos != std::string::npos; + pos = src.find(from, pos + len_to)) { + src.replace(pos, len_from, to); + } + } + + void SetComment(std::string name, std::string op) { + Replace(comment_, "{ReduceOP}", name); + Replace(comment_, "{reduce}", op); } }; -class ReduceMeanOpMaker : public framework::OpProtoAndCheckerMaker { +class ReduceSumOpMaker : public ReduceOpMaker { + public: + ReduceSumOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : ReduceOpMaker(proto, op_checker) { + SetComment("ReduceSum", "sum"); + AddComment(comment_); + } +}; + +class ReduceMeanOpMaker : public ReduceOpMaker { public: ReduceMeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "X", - "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); - AddOutput("Out", "(Tensor) The result tensor."); - AddComment(R"DOC( -ReduceMean operator computes the mean of input tensor along the given dimension. -The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. -)DOC"); - AddAttr("dim", - "(int, default 0) The dimension to reduce. " - "Must be in the range [-rank(input), rank(input))") - .SetDefault(0); - AddAttr( - "keep_dim", - "(int, default 0) " - "Must be 0 or 1. If 1, retain the reduced dimension with length 1.") - .SetDefault(0); + : ReduceOpMaker(proto, op_checker) { + SetComment("ReduceMean", "mean"); + AddComment(comment_); } }; -class ReduceMaxOpMaker : public framework::OpProtoAndCheckerMaker { +class ReduceMaxOpMaker : public ReduceOpMaker { public: ReduceMaxOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "X", - "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); - AddOutput("Out", "(Tensor) The result tensor."); - AddComment(R"DOC( -ReduceMax operator computes the maximum of input tensor along the given dimension. -The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. -)DOC"); - AddAttr("dim", - "(int, default 0) The dimension to reduce. " - "Must be in the range [-rank(input), rank(input))") - .SetDefault(0); - AddAttr( - "keep_dim", - "(int, default 0) " - "Must be 0 or 1. If 1, retain the reduced dimension with length 1.") - .SetDefault(0); + : ReduceOpMaker(proto, op_checker) { + SetComment("ReduceMax", "max"); + AddComment(comment_); } }; -class ReduceMinOpMaker : public framework::OpProtoAndCheckerMaker { +class ReduceMinOpMaker : public ReduceOpMaker { public: ReduceMinOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "X", - "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); - AddOutput("Out", "(Tensor) The result tensor."); - AddComment(R"DOC( -ReduceMin operator computes the minimum of input tensor along the given dimension. -The result tensor has 1 fewer dimension than the input unless `keep_dim` is true. -)DOC"); - AddAttr("dim", - "(int, default 0) The dimension to reduce. " - "Must be in the range [-rank(input), rank(input))") - .SetDefault(0); - AddAttr( - "keep_dim", - "(int, default 0) " - "Must be 0 or 1. If 1, retain the reduced dimension with length 1.") - .SetDefault(0); + : ReduceOpMaker(proto, op_checker) { + SetComment("ReduceMin", "min"); + AddComment(comment_); } }; diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h index f0d4e1f95c..972bd7bd46 100644 --- a/paddle/operators/reduce_op.h +++ b/paddle/operators/reduce_op.h @@ -27,61 +27,60 @@ template ; struct SumFunctor { - template - void operator()(const Place& place, In& in, Out& out, const Dim& dim) { - out.device(place) = in.sum(dim); + template + void operator()(const Place& place, X& x, Y& y, const Dim& dim) { + y.device(place) = x.sum(dim); } }; struct SumGradFunctor { - template - void operator()(const Place& place, In_Const& in, In& in_grad, Out& out, - Out& out_grad, const Dim& dim, int size) { - in_grad.device(place) = out_grad.broadcast(dim); + void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy, + const Dim& dim, int size) { + dx.device(place) = dy.broadcast(dim); } }; struct MeanFunctor { - template - void operator()(const Place& place, In& in, Out& out, const Dim& dim) { - out.device(place) = in.mean(dim); + template + void operator()(const Place& place, X& x, Y& y, const Dim& dim) { + y.device(place) = x.mean(dim); } }; struct MeanGradFunctor { - template - void operator()(const Place& place, In_Const& in, In& in_grad, Out& out, - Out& out_grad, const Dim& dim, int size) { - in_grad.device(place) = out_grad.broadcast(dim) / in_grad.constant(size); + void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy, + const Dim& dim, int size) { + dx.device(place) = dy.broadcast(dim) / dx.constant(size); } }; struct MaxFunctor { - template - void operator()(const Place& place, In& in, Out& out, const Dim& dim) { - out.device(place) = in.maximum(dim); + template + void operator()(const Place& place, X& x, Y& y, const Dim& dim) { + y.device(place) = x.maximum(dim); } }; struct MinFunctor { - template - void operator()(const Place& place, In& in, Out& out, const Dim& dim) { - out.device(place) = in.minimum(dim); + template + void operator()(const Place& place, X& x, Y& y, const Dim& dim) { + y.device(place) = x.minimum(dim); } }; struct MaxOrMinGradFunctor { - template - void operator()(const Place& place, In_Const& in, In& in_grad, Out& out, - Out& out_grad, const Dim& dim, int size) { - auto equals = in == out.broadcast(dim); - auto ones = in_grad.constant(1); - auto zeros = in_grad.constant(0); - in_grad.device(place) = - out_grad.broadcast(dim) * equals.select(ones, zeros); + void operator()(const Place& place, X& x, Y& y, DX& dx, DY& dy, + const Dim& dim, int size) { + auto equals = x == y.broadcast(dim); + auto ones = dx.constant(1); + auto zeros = dx.constant(0); + dx.device(place) = dy.broadcast(dim) * equals.select(ones, zeros); } }; @@ -125,7 +124,7 @@ class ReduceKernel : public framework::OpKernel { if (dim < 0) dim = x_rank + dim; auto reduce_dim = Eigen::array({{dim}}); // construct the squeezed output tensor - bool keep_dim = context.Attr("keep_dim") == 1; + bool keep_dim = context.Attr("keep_dim"); DDim dims = output->dims(); auto dims_vector = vectorize(dims); if (keep_dim && x_rank > 1) { @@ -191,7 +190,7 @@ class ReduceGradKernel : public framework::OpKernel { braodcast_dim[dim] = input0->dims()[dim]; auto& place = context.GetEigenDevice(); Functor functor; - functor(place, x, x_grad, x_reduce, x_reduce_grad, braodcast_dim, + functor(place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim, braodcast_dim[dim]); } } @@ -235,8 +234,8 @@ class ReduceGradEigenFreeKernel : public framework::OpKernel { out_offset = inner_count * i + j; for (int k = 0; k < mid_count; ++k) { x_offset = (inner_count * mid_count) * i + inner_count * k + j; - functor(x_data + x_offset, x_grad_data + x_offset, - out_data + out_offset, out_grad_data + out_offset, + functor(x_data + x_offset, out_data + out_offset, + x_grad_data + x_offset, out_grad_data + out_offset, mid_count); } } From 1295e5ef5467a0a068179da243c20bc05e61f921 Mon Sep 17 00:00:00 2001 From: guosheng Date: Sun, 24 Sep 2017 16:07:14 +0800 Subject: [PATCH 5/7] Refine reduce_op unit test and add newline at end of file --- paddle/operators/reduce_op.cu | 2 +- python/paddle/v2/framework/tests/test_reduce_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/reduce_op.cu b/paddle/operators/reduce_op.cu index 2dffea3a3a..595127b858 100644 --- a/paddle/operators/reduce_op.cu +++ b/paddle/operators/reduce_op.cu @@ -43,4 +43,4 @@ REGISTER_OP_GPU_KERNEL( ops::ReduceKernel); REGISTER_OP_GPU_KERNEL(reduce_min_grad, ops::ReduceGradKernel); \ No newline at end of file + ops::MaxOrMinGradFunctor>); diff --git a/python/paddle/v2/framework/tests/test_reduce_op.py b/python/paddle/v2/framework/tests/test_reduce_op.py index 58951f2902..70359d60cb 100644 --- a/python/paddle/v2/framework/tests/test_reduce_op.py +++ b/python/paddle/v2/framework/tests/test_reduce_op.py @@ -60,7 +60,7 @@ class TestKeepDimReduce(OpTest): def setUp(self): self.op_type = "reduce_sum" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.attrs = {'dim': -2, 'keep_dim': 1} + self.attrs = {'dim': -2, 'keep_dim': True} self.outputs = { 'Out': self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True) } From 477a6a0978063501051d171038d7993d3d27022a Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 25 Sep 2017 16:07:25 +0800 Subject: [PATCH 6/7] Refine reduce_op, follow comments and remove ReduceGradEigenFreeKernel --- paddle/operators/reduce_op.cc | 16 ++++-- paddle/operators/reduce_op.h | 102 +++++++++------------------------- 2 files changed, 38 insertions(+), 80 deletions(-) diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index 89f54fe74b..61b33d4bbd 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -18,7 +18,6 @@ namespace paddle { namespace operators { using framework::Tensor; -using framework::LoDTensor; class ReduceOp : public framework::OperatorWithKernel { public: @@ -46,7 +45,11 @@ class ReduceOp : public framework::OperatorWithKernel { dims_vector.erase(dims_vector.begin() + dim); } auto out_dims = framework::make_ddim(dims_vector); - ctx.Output("Out")->Resize(out_dims); + ctx.Output("Out")->Resize(out_dims); + if (dim != 0) { + // Only pass LoD when not reducing on the first dim + ctx.ShareLoD("X", /*->*/ "Out"); + } } }; @@ -81,9 +84,12 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { "X", "(Tensor) The input tensor. Tensors with rank at most 6 are supported"); AddOutput("Out", "(Tensor) The result tensor."); - AddAttr("dim", - "(int, default 0) The dimension to reduce. " - "Must be in the range [-rank(input), rank(input))") + AddAttr( + "dim", + "(int, default 1) The dimension to reduce. " + "Must be in the range [-rank(input), rank(input)). " + "If `dim < 0`, the dim to reduce is `rank + dim`. " + "Noting that reducing on the first dim will make the LoD info lost.") .SetDefault(0); AddAttr("keep_dim", "(bool, default false) " diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h index 972bd7bd46..2fbf94e34f 100644 --- a/paddle/operators/reduce_op.h +++ b/paddle/operators/reduce_op.h @@ -80,6 +80,8 @@ struct MaxOrMinGradFunctor { auto equals = x == y.broadcast(dim); auto ones = dx.constant(1); auto zeros = dx.constant(0); + // If there are multiple minimum or maximum elements, the subgradient of + // each is the set [0, 1], and we pass gradient to all of them here. dx.device(place) = dy.broadcast(dim) * equals.select(ones, zeros); } }; @@ -145,102 +147,52 @@ class ReduceGradKernel : public framework::OpKernel { int rank = context.Input("X")->dims().size(); switch (rank) { case 1: - ReduceCompute<1>(context); + ReduceGradCompute<1>(context); break; case 2: - ReduceCompute<2>(context); + ReduceGradCompute<2>(context); break; case 3: - ReduceCompute<3>(context); + ReduceGradCompute<3>(context); break; case 4: - ReduceCompute<4>(context); + ReduceGradCompute<4>(context); break; case 5: - ReduceCompute<5>(context); + ReduceGradCompute<5>(context); break; case 6: - ReduceCompute<6>(context); + ReduceGradCompute<6>(context); break; } } private: template - void ReduceCompute(const framework::ExecutionContext& context) const { + void ReduceGradCompute(const framework::ExecutionContext& context) const { auto* input0 = context.Input("X"); auto* input1 = context.Input("Out"); auto* input2 = context.Input(framework::GradVarName("Out")); auto* output = context.Output(framework::GradVarName("X")); - if (output != nullptr) { - output->mutable_data(context.GetPlace()); - auto x = EigenTensor::From(*input0); - auto x_grad = EigenTensor::From(*output); - auto x_rank = static_cast(x.dimensions().size()); - int dim = static_cast(context.Attr("dim")); - if (dim < 0) dim = x_rank + dim; - DDim dims = input0->dims(); - dims[dim] = 1; - auto x_reduce = EigenTensor::From(*input1, dims); - auto x_reduce_grad = EigenTensor::From(*input2, dims); - - Eigen::array braodcast_dim; - for (size_t i = 0; i < D; ++i) braodcast_dim[i] = 1; - braodcast_dim[dim] = input0->dims()[dim]; - auto& place = context.GetEigenDevice(); - Functor functor; - functor(place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim, - braodcast_dim[dim]); - } - } -}; - -// For EigenTensor unsupported reduce -template -class ReduceGradEigenFreeKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* out = context.Input("Out"); - auto* x_grad = context.Output(framework::GradVarName("X")); - auto* out_grad = context.Input(framework::GradVarName("Out")); - if (x_grad != nullptr) { - DDim dims = x->dims(); - int rank = dims.size(); - int dim = static_cast(context.Attr("dim")); - if (dim < 0) dim = rank + dim; - - auto* x_data = x->data(); - auto* x_grad_data = x_grad->mutable_data(context.GetPlace()); - auto* out_data = out->data(); - auto* out_grad_data = out_grad->data(); - - int outer_count = 1; - int inner_count = 1; - int mid_count = dims[dim]; - for (int i = 0; i < dim; ++i) { - outer_count *= dims[i]; - } - for (int i = dim + 1; i < rank; ++i) { - inner_count *= dims[i]; - } - - int x_offset = 0; // offset on raw data - int out_offset = 0; // offset on reduced data - Functor functor; - for (int i = 0; i < outer_count; ++i) { - for (int j = 0; j < inner_count; ++j) { - out_offset = inner_count * i + j; - for (int k = 0; k < mid_count; ++k) { - x_offset = (inner_count * mid_count) * i + inner_count * k + j; - functor(x_data + x_offset, out_data + out_offset, - x_grad_data + x_offset, out_grad_data + out_offset, - mid_count); - } - } - } - } + output->mutable_data(context.GetPlace()); + auto x = EigenTensor::From(*input0); + auto x_grad = EigenTensor::From(*output); + auto x_rank = static_cast(x.dimensions().size()); + int dim = static_cast(context.Attr("dim")); + if (dim < 0) dim = x_rank + dim; + DDim dims = input0->dims(); + dims[dim] = 1; + auto x_reduce = EigenTensor::From(*input1, dims); + auto x_reduce_grad = EigenTensor::From(*input2, dims); + + Eigen::array braodcast_dim; + for (size_t i = 0; i < D; ++i) braodcast_dim[i] = 1; + braodcast_dim[dim] = input0->dims()[dim]; + auto& place = context.GetEigenDevice(); + Functor functor; + functor(place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim, + braodcast_dim[dim]); } }; From e33b411221577e053ad461350189198284f0628f Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 28 Sep 2017 10:53:20 +0800 Subject: [PATCH 7/7] Adapt reduce_op according to up-to-date dev --- paddle/operators/reduce_op.cc | 41 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index 61b33d4bbd..3ef443d1c7 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -24,20 +24,20 @@ class ReduceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of ReduceOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ReduceOp should not be null."); - auto x_dims = ctx.Input("X")->dims(); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ReduceOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ReduceOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); - int dim = ctx.Attr("dim"); + int dim = ctx->Attrs().Get("dim"); if (dim < 0) dim = x_rank + dim; PADDLE_ENFORCE_LT( dim, x_rank, "The dim should be in the range [-rank(input), rank(input))."); - bool keep_dim = ctx.Attr("keep_dim"); + bool keep_dim = ctx->Attrs().Get("keep_dim"); auto dims_vector = vectorize(x_dims); if (keep_dim || x_rank == 1) { dims_vector[dim] = 1; @@ -45,10 +45,10 @@ class ReduceOp : public framework::OperatorWithKernel { dims_vector.erase(dims_vector.begin() + dim); } auto out_dims = framework::make_ddim(dims_vector); - ctx.Output("Out")->Resize(out_dims); + ctx->SetOutputDim("Out", out_dims); if (dim != 0) { - // Only pass LoD when not reducing on the first dim - ctx.ShareLoD("X", /*->*/ "Out"); + // Only pass LoD when not reducing on the first dim. + ctx->ShareLoD("X", /*->*/ "Out"); } } }; @@ -58,21 +58,22 @@ class ReduceGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null."); - auto x_dims = ctx.Input("X")->dims(); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); - int dim = ctx.Attr("dim"); + int dim = ctx->Attrs().Get("dim"); if (dim < 0) dim = x_rank + dim; PADDLE_ENFORCE_LT( dim, x_rank, "The dim should be in the range [-rank(input), rank(input))."); - auto *x_grad = - ctx.Output(framework::GradVarName("X")); - if (x_grad) x_grad->Resize(x_dims); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } } };