cumsum operator (#8288)
parent 69712ef276
commit 725e64486a
paddle/operators/cum_op.h
@@ -0,0 +1,111 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/detail/safe_ref.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename Functor>
class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
                          "Cannot get input tensor X, variable name = %s",
                          context.op().Input("X"));

    auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
                            "Cannot get output tensor Out, variable name = %s",
                            context.op().Output("Out"));
    int axis = context.Attr<int>("axis");
    bool exclusive = context.Attr<bool>("exclusive");
    bool reverse = context.Attr<bool>("reverse");
    auto x_dims = X.dims();
    if (axis == -1) {
      axis = x_dims.size() - 1;
    }
    PADDLE_ENFORCE_LT(
        axis, x_dims.size(),
        "axis should be less than the dimension of the input tensor");
    Out.mutable_data<T>(context.GetPlace());

    // Collapse the dimensions around `axis` so the scan always runs along
    // the middle dimension of a (pre, mid, post) view.
    int pre = 1;
    int post = 1;
    int mid = x_dims[axis];
    for (int i = 0; i < axis; ++i) {
      pre *= x_dims[i];
    }
    for (int i = axis + 1; i < x_dims.size(); ++i) {
      post *= x_dims[i];
    }

    auto x = framework::EigenVector<T>::Flatten(X);
    auto out = framework::EigenVector<T>::Flatten(Out);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();

    using IndexT = Eigen::DenseIndex;
    if (pre == 1) {
      if (post == 1) {
        ComputeImp(*place, Eigen::DSizes<IndexT, 1>(mid), x, out,
                   /* axis= */ 0, reverse, exclusive);
      } else {
        ComputeImp(*place, Eigen::DSizes<IndexT, 2>(mid, post), x, out,
                   /* axis= */ 0, reverse, exclusive);
      }
    } else {
      if (post == 1) {
        ComputeImp(*place, Eigen::DSizes<IndexT, 2>(pre, mid), x, out,
                   /* axis= */ 1, reverse, exclusive);
      } else {
        ComputeImp(*place, Eigen::DSizes<IndexT, 3>(pre, mid, post), x, out,
                   /* axis= */ 1, reverse, exclusive);
      }
    }
  }

 private:
  template <typename Device, typename Dim, typename X, typename Out>
  void ComputeImp(Device d, const Dim& dims, X x, Out out, int axis,
                  bool reverse, bool exclusive) const {
    if (!reverse) {
      out.reshape(dims).device(d) = Functor()(x.reshape(dims), axis, exclusive);
    } else {
      // A reversed scan is an ordinary scan on the flipped tensor,
      // flipped back afterwards.
      std::array<bool, Dim::count> rev;
      rev.fill(false);
      rev[axis] = reverse;
      out.reshape(dims).device(d) =
          Functor()(x.reshape(dims).reverse(rev), axis, exclusive).reverse(rev);
    }
  }
};

template <typename T>
struct CumsumFunctor {
  using ELEMENT_TYPE = T;
  template <typename X>
  const typename X::TensorScanSumOp operator()(X x, int axis,
                                               bool exclusive) const {
    return x.cumsum(axis, exclusive);
  }
};

} // namespace operators
} // namespace paddle
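
The heart of the kernel above is the (pre, mid, post) collapse: a cumsum along any axis of an N-D tensor reduces to at most a 3-D scan along a fixed middle dimension, with reverse handled by flipping before and after the scan and exclusive delegated to Eigen's cumsum. Below is a minimal NumPy sketch of the same logic, for illustration only (cum_ref is a made-up name, not part of this PR):

import numpy as np

def cum_ref(x, axis=-1, exclusive=False, reverse=False):
    if axis == -1:
        axis = x.ndim - 1
    pre = int(np.prod(x.shape[:axis], dtype=np.int64))       # dims before axis
    mid = x.shape[axis]
    post = int(np.prod(x.shape[axis + 1:], dtype=np.int64))  # dims after axis
    y = x.reshape(pre, mid, post)      # the (pre, mid, post) view
    if reverse:
        y = y[:, ::-1, :]              # mirror of Eigen's .reverse(rev)
    if exclusive:
        z = np.zeros_like(y)           # shift the scan by one, pad with zero
        z[:, 1:, :] = np.cumsum(y[:, :-1, :], axis=1)
    else:
        z = np.cumsum(y, axis=1)
    if reverse:
        z = z[:, ::-1, :]              # flip back
    return z.reshape(x.shape)

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
assert np.allclose(cum_ref(x, axis=1), np.cumsum(x, axis=1))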
paddle/operators/cum_op.cc
@@ -0,0 +1,82 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/cum_op.h"

namespace paddle {
namespace operators {

class CumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CumsumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Cumsum operator");
    AddOutput("Out", "Output of Cumsum operator");
    AddAttr<int>("axis",
                 "(int, default -1). The dimension to accumulate along. "
                 "-1 means the last dimension.")
        .SetDefault(-1)
        .EqualGreaterThan(-1);
    AddAttr<bool>("exclusive",
                  "(bool, default false). Whether to perform exclusive cumsum.")
        .SetDefault(false);
    AddAttr<bool>("reverse",
                  "(bool, default false). If true, the cumsum is performed in "
                  "the reversed direction.")
        .SetDefault(false);
    AddComment(R"DOC(
The cumulative sum of the elements along a given axis.
By default, the first element of the result is the same as the first element of
the input. If exclusive is true, the first element of the result is 0.
)DOC");
  }
};

class CumsumGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    // The gradient of a cumsum is a cumsum of the output gradient in the
    // opposite direction, so reuse the forward op with `reverse` negated.
    auto *grad_op = new framework::OpDesc();
    grad_op->SetType("cumsum");
    grad_op->SetInput("X", OutputGrad("Out"));
    grad_op->SetOutput("Out", InputGrad("X"));
    grad_op->SetAttr("axis", Attr<int>("axis"));
    grad_op->SetAttr("reverse", !Attr<bool>("reverse"));
    grad_op->SetAttr("exclusive", Attr<bool>("exclusive"));
    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker, ops::CumsumGradMaker);
REGISTER_OP_CPU_KERNEL(cumsum, ops::CumKernel<CPU, ops::CumsumFunctor<float>>,
                       ops::CumKernel<CPU, ops::CumsumFunctor<double>>,
                       ops::CumKernel<CPU, ops::CumsumFunctor<int>>);
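
CumsumGradMaker above relies on the adjoint of a scan being a scan in the opposite direction with the same exclusive flag, which is why it reuses the forward op with reverse negated. A self-contained NumPy check of that identity against the explicit Jacobian, for illustration only (cum1d is a made-up helper, not part of this PR):

import numpy as np

def cum1d(v, exclusive=False, reverse=False):
    # 1-D scan: inclusive or exclusive, optionally right-to-left.
    if reverse:
        return cum1d(v[::-1], exclusive)[::-1]
    s = np.cumsum(v)
    return np.concatenate(([0.0], s[:-1])) if exclusive else s

n = 7
g = np.random.random(n)  # upstream gradient dL/dOut
for exclusive in (False, True):
    for reverse in (False, True):
        # Build the Jacobian J[k, i] = dOut[k] / dX[i] column by column.
        jac = np.zeros((n, n))
        for i in range(n):
            e = np.zeros(n)
            e[i] = 1.0
            jac[:, i] = cum1d(e, exclusive, reverse)
        # dL/dX = J^T g, which the grad op computes as a cumsum of g with
        # `reverse` negated and `axis`/`exclusive` unchanged.
        assert np.allclose(jac.T @ g, cum1d(g, exclusive, not reverse))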
paddle/operators/cum_op.cu
@@ -0,0 +1,22 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/cum_op.h"

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

REGISTER_OP_CUDA_KERNEL(cumsum, ops::CumKernel<CUDA, ops::CumsumFunctor<float>>,
                        ops::CumKernel<CUDA, ops::CumsumFunctor<double>>,
                        ops::CumKernel<CUDA, ops::CumsumFunctor<int>>);
test_cumsum_op.py
@@ -0,0 +1,127 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class TestSumOp1(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.attrs = {'axis': 2}
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=2)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestSumOp2(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.attrs = {'axis': -1, 'reverse': True}
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.outputs = {
            'Out': np.flip(
                np.flip(self.inputs['X'], axis=2).cumsum(axis=2), axis=2)
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestSumOp3(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.attrs = {'axis': 1}
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=1)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestSumOp4(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.attrs = {'axis': 0}
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=0)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestSumOp5(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.inputs = {'X': np.random.random((5, 6)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=1)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestSumOp7(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.inputs = {'X': np.random.random((6, )).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=0)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestSumOp8(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.attrs = {'axis': 2, "exclusive": True}
        a = np.random.random((5, 6, 3)).astype("float64")
        self.inputs = {'X': a}
        self.outputs = {
            'Out': np.concatenate(
                (np.zeros((5, 6, 1), dtype=np.float64),
                 a[:, :, :-1].cumsum(axis=2)),
                axis=2)
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


if __name__ == '__main__':
    unittest.main()
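
The tests above exercise reverse and exclusive separately; the two attributes also compose (ComputeImp flips, scans exclusively, then flips back), so each entry of a reverse-exclusive cumsum sums the elements strictly after it. A tiny worked example in NumPy, for illustration only:

import numpy as np

a = np.array([1., 2., 3., 4.])
inclusive = np.cumsum(a)                               # [1, 3, 6, 10]
exclusive = np.concatenate(([0.], np.cumsum(a)[:-1]))  # [0, 1, 3, 6]
reverse = np.cumsum(a[::-1])[::-1]                     # [10, 9, 7, 4]
rev_excl = np.concatenate(([0.], np.cumsum(a[::-1])[:-1]))[::-1]
assert np.allclose(rev_excl, [9., 7., 4., 0.])         # sums of later elements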