|
|
|
@ -85,14 +85,17 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
|
|
|
template <typename Device, typename Dim, typename X, typename Out>
|
|
|
|
|
void ComputeImp(Device d, const Dim& dims, X x, Out out, int axis,
|
|
|
|
|
bool reverse, bool exclusive) const {
|
|
|
|
|
Functor func();
|
|
|
|
|
if (!reverse) {
|
|
|
|
|
out.reshape(dims).device(d) = Functor()(x.reshape(dims), axis, exclusive);
|
|
|
|
|
out.reshape(dims).device(d) =
|
|
|
|
|
func.apply(x.reshape(dims), axis, exclusive);
|
|
|
|
|
} else {
|
|
|
|
|
std::array<bool, Dim::count> rev;
|
|
|
|
|
rev.fill(false);
|
|
|
|
|
rev[axis] = reverse;
|
|
|
|
|
out.reshape(dims).device(d) =
|
|
|
|
|
Functor()(x.reshape(dims).reverse(rev), axis, exclusive).reverse(rev);
|
|
|
|
|
func.apply(x.reshape(dims).reverse(rev), axis, exclusive)
|
|
|
|
|
.reverse(rev);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -101,8 +104,7 @@ template <typename T>
|
|
|
|
|
struct CumsumFunctor {
|
|
|
|
|
using ELEMENT_TYPE = T;
|
|
|
|
|
template <typename X>
|
|
|
|
|
const typename X::TensorScanSumOp operator()(X x, int axis,
|
|
|
|
|
bool exclusive) const {
|
|
|
|
|
const typename X::TensorScanSumOp apply(X x, int axis, bool exclusive) const {
|
|
|
|
|
return x.cumsum(axis, exclusive);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|