@@ -26,10 +26,12 @@ using DDim = framework::DDim;
 template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 
 struct SumFunctor {
   template <typename DeviceContext, typename X, typename Y, typename Dim>
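> Note (commentary, not part of the patch): the aliases above wrap Paddle tensors as Eigen `TensorMap` views, which is what lets a functor such as `SumFunctor` reduce `x` into `y` along `dim`. A minimal standalone sketch of the same idea, using only Eigen's unsupported Tensor module (shapes and names here are illustrative):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  float data[6] = {1, 2, 3, 4, 5, 6};
  // Roughly what EigenVector<T>::Flatten yields: a rank-1 view of the buffer.
  Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor, Eigen::DenseIndex>>
      x(data, 6);
  // Roughly what EigenScalar<T>::From yields: a rank-0 view of one element.
  float result = 0.f;
  Eigen::TensorMap<Eigen::Tensor<float, 0, Eigen::RowMajor, Eigen::DenseIndex>>
      out(&result);
  // The reduction itself, evaluated on Eigen's default device.
  out = x.sum(Eigen::array<int, 1>({{0}}));
  std::cout << result << std::endl;  // prints 21
}
```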
@@ -95,26 +97,41 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int rank = context.Input<Tensor>("X")->dims().size();
-    switch (rank) {
-      case 1:
-        ReduceCompute<1>(context);
-        break;
-      case 2:
-        ReduceCompute<2>(context);
-        break;
-      case 3:
-        ReduceCompute<3>(context);
-        break;
-      case 4:
-        ReduceCompute<4>(context);
-        break;
-      case 5:
-        ReduceCompute<5>(context);
-        break;
-      case 6:
-        ReduceCompute<6>(context);
-        break;
-    }
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    if (reduce_all) {
+      // Flatten and reduce 1-D tensor
+      auto* input = context.Input<Tensor>("X");
+      auto* output = context.Output<Tensor>("Out");
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenVector<T>::Flatten(*input);
+      auto out = EigenScalar<T>::From(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      Functor functor;
+      functor(place, x, out, reduce_dim);
+    } else {
+      int rank = context.Input<Tensor>("X")->dims().size();
+      switch (rank) {
+        case 1:
+          ReduceCompute<1>(context);
+          break;
+        case 2:
+          ReduceCompute<2>(context);
+          break;
+        case 3:
+          ReduceCompute<3>(context);
+          break;
+        case 4:
+          ReduceCompute<4>(context);
+          break;
+        case 5:
+          ReduceCompute<5>(context);
+          break;
+        case 6:
+          ReduceCompute<6>(context);
+          break;
+      }
+    }
   }
 
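> Note (commentary, not part of the patch): with `reduce_all` set, the forward kernel no longer dispatches on rank at all; it views the input as one flat vector and reduces its only dimension into a scalar. A self-contained Eigen sketch of that data flow, with made-up shapes:

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // Stand-in for a rank-3 input; any rank flattens the same way.
  Eigen::Tensor<float, 3, Eigen::RowMajor> t(2, 3, 4);
  t.setConstant(1.0f);
  // Equivalent of EigenVector<T>::Flatten(*input).
  Eigen::Tensor<float, 1, Eigen::RowMajor> x =
      t.reshape(Eigen::array<Eigen::DenseIndex, 1>{{t.size()}});
  // Equivalent of functor(place, x, out, reduce_dim) with reduce_dim = {{0}}.
  Eigen::Tensor<float, 0, Eigen::RowMajor> out =
      x.sum(Eigen::array<int, 1>({{0}}));
  std::cout << out() << std::endl;  // prints 24
}
```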
@@ -157,26 +174,46 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int rank = context.Input<Tensor>("X")->dims().size();
-    switch (rank) {
-      case 1:
-        ReduceGradCompute<1>(context);
-        break;
-      case 2:
-        ReduceGradCompute<2>(context);
-        break;
-      case 3:
-        ReduceGradCompute<3>(context);
-        break;
-      case 4:
-        ReduceGradCompute<4>(context);
-        break;
-      case 5:
-        ReduceGradCompute<5>(context);
-        break;
-      case 6:
-        ReduceGradCompute<6>(context);
-        break;
-    }
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    if (reduce_all) {
+      auto* input0 = context.Input<Tensor>("X");
+      auto* input1 = context.Input<Tensor>("Out");
+      auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* output = context.Output<Tensor>(framework::GradVarName("X"));
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenVector<T>::Flatten(*input0);
+      auto x_reduce = EigenVector<T>::From(*input1);
+      auto x_reduce_grad = EigenVector<T>::From(*input2);
+      auto x_grad = EigenVector<T>::Flatten(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto broadcast_dim =
+          Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
+      Functor functor;
+      functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
+              broadcast_dim[0]);
+    } else {
+      int rank = context.Input<Tensor>("X")->dims().size();
+      switch (rank) {
+        case 1:
+          ReduceGradCompute<1>(context);
+          break;
+        case 2:
+          ReduceGradCompute<2>(context);
+          break;
+        case 3:
+          ReduceGradCompute<3>(context);
+          break;
+        case 4:
+          ReduceGradCompute<4>(context);
+          break;
+        case 5:
+          ReduceGradCompute<5>(context);
+          break;
+        case 6:
+          ReduceGradCompute<6>(context);
+          break;
+      }
+    }
   }
 
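> Note (commentary, not part of the patch): in the `reduce_all` backward path the upstream gradient is a single scalar, so `broadcast_dim` tiles it across all `input0->numel()` elements. For a sum reduction every input element receives that scalar unchanged; a minimal Eigen sketch with illustrative values:

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  const int numel = 6;  // stand-in for input0->numel()
  // Upstream gradient of the scalar output, viewed as a 1-element vector.
  Eigen::Tensor<float, 1, Eigen::RowMajor> dout(1);
  dout.setConstant(2.5f);
  // Equivalent of broadcasting dOut by broadcast_dim for a sum reduction:
  // every input element gets the same upstream gradient.
  Eigen::Tensor<float, 1, Eigen::RowMajor> dx =
      dout.broadcast(Eigen::array<int, 1>({{numel}}));
  std::cout << dx(0) << " " << dx(numel - 1) << std::endl;  // 2.5 2.5
}
```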