Add reduce_op

update-doc-pybind
guosheng 8 years ago
parent a2393fc1bd
commit 3994e91a67

@@ -0,0 +1,207 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/reduce_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::DDim;
class ReduceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    auto x_dims = ctx.Input<Tensor>("X")->dims();
    auto x_rank = x_dims.size();
    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported");
    int dim = static_cast<int>(ctx.Attr<int>("dim"));
    if (dim < 0) dim = x_rank + dim;
    PADDLE_ENFORCE_LT(
        dim, x_rank,
        "The dim should be in the range [-rank(input), rank(input))");
    bool keep_dim = ctx.Attr<bool>("keep_dim");
    auto dims_vector = vectorize(x_dims);
    if (keep_dim || x_rank == 1) {
      dims_vector[dim] = 1;
    } else {
      dims_vector.erase(dims_vector.begin() + dim);
    }
    auto out_dims = framework::make_ddim(dims_vector);
    ctx.Output<Tensor>("Out")->Resize(out_dims);
  }
};
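
For reference, the shape rule InferShape implements can be restated in a few lines of Python (an illustrative sketch only; reduced_shape is a hypothetical helper, not part of this patch):

def reduced_shape(x_shape, dim, keep_dim=False):
    # Mirror ReduceOp::InferShape: normalize a negative dim, then either
    # keep the reduced axis with length 1 or drop it entirely.
    rank = len(x_shape)
    if dim < 0:
        dim += rank
    assert 0 <= dim < rank, "after normalization dim must land in [0, rank)"
    dims = list(x_shape)
    if keep_dim or rank == 1:
        dims[dim] = 1
    else:
        dims.pop(dim)
    return tuple(dims)

assert reduced_shape((5, 6, 10), -2) == (5, 10)
assert reduced_shape((5, 6, 10), -2, keep_dim=True) == (5, 1, 10)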
class ReduceGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) should not be null");
    auto x_dims = ctx.Input<Tensor>("X")->dims();
    auto x_rank = x_dims.size();
    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported");
    int dim = static_cast<int>(ctx.Attr<int>("dim"));
    if (dim < 0) dim = x_rank + dim;
    PADDLE_ENFORCE_LT(
        dim, x_rank,
        "The dim should be in the range [-rank(input), rank(input))");
    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    if (x_grad) x_grad->Resize(x_dims);
  }
};
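
ReduceGradOp only restores Input(X)'s shape for X@GRAD; the arithmetic lives in the grad functors. For reduce_sum, for example, the gradient is simply Out@GRAD broadcast back across the reduced axis (a numpy illustration of the math, not the kernel itself):

import numpy as np

x = np.random.random((5, 6, 10)).astype("float32")
d_out = np.ones((5, 10), dtype="float32")  # Out@GRAD after reducing dim=-2

# d(sum)/dx == 1, so X@GRAD is d_out re-expanded and broadcast to x's shape.
d_x = np.broadcast_to(np.expand_dims(d_out, -2), x.shape)
assert d_x.shape == x.shape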
class ReduceSumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReduceSumOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(Tensor) The input tensor. Tensors with rank at most 6 are supported");
    AddOutput("Out", "(Tensor) The result tensor.");
    AddComment(R"DOC(
ReduceSum operator computes the sum of the input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC");
    AddAttr<int>("dim",
                 "(int, default 0) The dimension to reduce. "
                 "Must be in the range [-rank(input), rank(input))")
        .SetDefault(0);
    AddAttr<bool>("keep_dim",
                  "(bool, default false) "
                  "If true, retain the reduced dimension with length 1.")
        .SetDefault(false);
  }
};
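
These attributes surface directly through the Python Operator wrapper used in the tests below; for example (keep_dim is shown on the assumption it is passed through the same way as dim):

from paddle.v2.framework.op import Operator

# dim and keep_dim map onto the AddAttr declarations above.
op = Operator("reduce_sum", X="X", Out="Out", dim=-2, keep_dim=True)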
class ReduceMeanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReduceMeanOpMaker(framework::OpProto *proto,
                    framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(Tensor) The input tensor. Tensors with rank at most 6 are supported");
    AddOutput("Out", "(Tensor) The result tensor.");
    AddComment(R"DOC(
ReduceMean operator computes the mean of the input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC");
    AddAttr<int>("dim",
                 "(int, default 0) The dimension to reduce. "
                 "Must be in the range [-rank(input), rank(input))")
        .SetDefault(0);
    AddAttr<bool>("keep_dim",
                  "(bool, default false) "
                  "If true, retain the reduced dimension with length 1.")
        .SetDefault(false);
  }
};
class ReduceMaxOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReduceMaxOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(Tensor) The input tensor. Tensors with rank at most 6 are supported");
    AddOutput("Out", "(Tensor) The result tensor.");
    AddComment(R"DOC(
ReduceMax operator computes the maximum of the input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC");
    AddAttr<int>("dim",
                 "(int, default 0) The dimension to reduce. "
                 "Must be in the range [-rank(input), rank(input))")
        .SetDefault(0);
    AddAttr<bool>("keep_dim",
                  "(bool, default false) "
                  "If true, retain the reduced dimension with length 1.")
        .SetDefault(false);
  }
};
class ReduceMinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReduceMinOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(Tensor) The input tensor. Tensors with rank at most 6 are supported");
    AddOutput("Out", "(Tensor) The result tensor.");
    AddComment(R"DOC(
ReduceMin operator computes the minimum of the input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC");
    AddAttr<int>("dim",
                 "(int, default 0) The dimension to reduce. "
                 "Must be in the range [-rank(input), rank(input))")
        .SetDefault(0);
    AddAttr<bool>("keep_dim",
                  "(bool, default false) "
                  "If true, retain the reduced dimension with length 1.")
        .SetDefault(false);
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad,
            ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(
    reduce_sum,
    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::SumFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_sum_grad,
                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
                                             ops::SumGradFunctor>);
REGISTER_OP(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
            reduce_mean_grad, ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(
    reduce_mean,
    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MeanFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_mean_grad,
                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
                                             ops::MeanGradFunctor>);
REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad,
            ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(
    reduce_max,
    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MaxFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_max_grad,
                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
                                             ops::MaxOrMinGradFunctor>);
REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad,
            ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(
    reduce_min,
    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MinFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_min_grad,
                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
                                             ops::MaxOrMinGradFunctor>);

@@ -0,0 +1,46 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/reduce_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    reduce_sum,
    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::SumFunctor>);
REGISTER_OP_GPU_KERNEL(reduce_sum_grad,
                       ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
                                             ops::SumGradFunctor>);
REGISTER_OP_GPU_KERNEL(
    reduce_mean,
    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MeanFunctor>);
REGISTER_OP_GPU_KERNEL(reduce_mean_grad,
                       ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
                                             ops::MeanGradFunctor>);
REGISTER_OP_GPU_KERNEL(
    reduce_max,
    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MaxFunctor>);
REGISTER_OP_GPU_KERNEL(reduce_max_grad,
                       ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
                                             ops::MaxOrMinGradFunctor>);
REGISTER_OP_GPU_KERNEL(
    reduce_min,
    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MinFunctor>);
REGISTER_OP_GPU_KERNEL(reduce_min_grad,
                       ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
                                             ops::MaxOrMinGradFunctor>);

File diff suppressed because it is too large
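The suppressed diff is presumably paddle/operators/reduce_op.h, which the .cc file above includes; it would hold the ReduceKernel/ReduceGradKernel templates and the functors named in the registrations. Judging from the operator names and the tests, the functors' forward semantics amount to the following (a numpy sketch of the math, not the Eigen implementation):

import numpy as np

x = np.random.random((5, 6, 10)).astype("float32")
dim = -1

out_sum = x.sum(axis=dim)    # SumFunctor
out_mean = x.mean(axis=dim)  # MeanFunctor
out_max = x.max(axis=dim)    # MaxFunctor
out_min = x.min(axis=dim)    # MinFunctor

# MaxOrMinGradFunctor plausibly routes the gradient only to the extremal
# positions, which is why max and min share one grad functor:
d_out = np.ones_like(out_max)
mask = (x == np.expand_dims(out_max, dim)).astype("float32")
d_x = mask * np.expand_dims(d_out, dim)
assert d_x.shape == x.shape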

@@ -0,0 +1,92 @@
import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
from paddle.v2.framework.op import Operator


class TestSumOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "reduce_sum"
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.attrs = {'dim': -2}
        out = self.inputs['X'].sum(axis=self.attrs['dim'])
        self.outputs = {'Out': out}


class TestSumGradOp(GradientChecker):
    def test_normal(self):
        op = Operator("reduce_sum", X="X", Out="Out", dim=-2)
        # use a small size to decrease the error of the numerical calculation
        inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.check_grad(op, inputs, set(["X"]), "Out")

    def test_1d_tensor(self):
        op = Operator("reduce_sum", X="X", Out="Out", dim=0)
        # use a small size to decrease the error of the numerical calculation
        inputs = {'X': np.random.random(10).astype("float32")}
        self.check_grad(op, inputs, set(["X"]), "Out")


class TestKeepdimSumOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "reduce_sum"
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.attrs = {'dim': -2, 'keep_dim': True}
        out = self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True)
        self.outputs = {'Out': out}


class TestMeanOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "reduce_mean"
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.attrs = {'dim': -1}
        out = self.inputs['X'].mean(axis=self.attrs['dim'])
        self.outputs = {'Out': out}


class TestMeanGradOp(GradientChecker):
    def test_normal(self):
        op = Operator("reduce_mean", X="X", Out="Out", dim=-2)
        # use a small size to decrease the error of the numerical calculation
        inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.check_grad(op, inputs, set(["X"]), "Out")

    def test_1d_tensor(self):
        op = Operator("reduce_mean", X="X", Out="Out", dim=0)
        # use a small size to decrease the error of the numerical calculation
        inputs = {'X': np.random.random(10).astype("float32")}
        self.check_grad(op, inputs, set(["X"]), "Out")


class TestMaxOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "reduce_max"
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.attrs = {'dim': -1}
        out = self.inputs['X'].max(axis=self.attrs['dim'])
        self.outputs = {'Out': out}


class TestMinOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "reduce_min"
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.attrs = {'dim': -2}
        out = self.inputs['X'].min(axis=self.attrs['dim'])
        self.outputs = {'Out': out}


if __name__ == '__main__':
    unittest.main()
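
The mean gradient that TestMeanGradOp checks numerically is the sum gradient scaled by the length of the reduced axis, which is easy to verify by hand (a standalone numpy check, independent of the test harness):

import numpy as np

x = np.random.random((5, 6, 10)).astype("float32")
d_out = np.ones((5, 10), dtype="float32")  # Out@GRAD for dim=-2

n = x.shape[-2]  # length of the reduced axis
d_x = np.broadcast_to(np.expand_dims(d_out, -2), x.shape) / n

# Summing X@GRAD back over the reduced axis recovers Out@GRAD exactly.
np.testing.assert_allclose(d_x.sum(axis=-2), d_out, rtol=1e-5)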