@@ -26,10 +26,12 @@ using DDim = framework::DDim;
 template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
-
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
+using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

 struct SumFunctor {
   template <typename DeviceContext, typename X, typename Y, typename Dim>
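
Note on this hunk: the new `EigenScalar` alias is what lets the `reduce_all` path in the next hunk write a full reduction through a rank-0 Eigen view. A minimal standalone sketch of that pattern, using the raw Eigen Tensor module rather than the `framework::Eigen*` wrappers (values and variable names here are illustrative, not Paddle API):

```cpp
// Flatten to rank 1, reduce over axis 0, store through a rank-0 view.
#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  Eigen::Tensor<float, 1> x(6);  // stand-in for EigenVector<T>::Flatten(*input)
  x.setValues({1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
  Eigen::Tensor<float, 0> out;   // stand-in for EigenScalar<T>::From(*output)
  Eigen::array<int, 1> reduce_dim{{0}};  // same shape as in the next hunk
  out = x.sum(reduce_dim);
  std::cout << out() << "\n";  // prints 21
  return 0;
}
```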
@@ -95,6 +97,20 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    if (reduce_all) {
+      // Flatten and reduce 1-D tensor
+      auto* input = context.Input<Tensor>("X");
+      auto* output = context.Output<Tensor>("Out");
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenVector<T>::Flatten(*input);
+      auto out = EigenScalar<T>::From(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      Functor functor;
+      functor(place, x, out, reduce_dim);
+    } else {
     int rank = context.Input<Tensor>("X")->dims().size();
     switch (rank) {
       case 1:
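
The `reduce_all` fast path flattens the input to 1-D, reduces over the single axis 0, and writes the result through the rank-0 `EigenScalar` view, skipping the rank dispatch entirely. The context above truncates at `case 1:`; the elided switch body presumably dispatches on the runtime rank to the rank-templated private member declared under `private:` in the next hunk (`ReduceCompute` in the Paddle source of this era), roughly as follows. This is a sketch, not part of this diff:

```cpp
// Assumed shape of the elided rank dispatch; one case per supported rank.
switch (rank) {
  case 1:
    ReduceCompute<1>(context);
    break;
  case 2:
    ReduceCompute<2>(context);
    break;
  // ... and so on up to the maximum rank the op supports (6 in Paddle)
  case 6:
    ReduceCompute<6>(context);
    break;
}
```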
@@ -117,6 +133,7 @@ class ReduceKernel : public framework::OpKernel<T> {
          break;
      }
+    }
   }

  private:
   template <size_t D>
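
Both branches of `Compute` hand their Eigen views to the same `Functor`, which is why no per-functor change is needed: the functor only has to compile for whatever rank it is given. Judging from the call site `functor(place, x, out, reduce_dim)`, the `SumFunctor` opened in the first hunk plausibly reads as below (its body is elided by the diff; this is sketched, not quoted):

```cpp
struct SumFunctor {
  // x and y are Eigen tensor views; dim is the Eigen::array of axes to
  // reduce. The same body serves the flattened rank-1 path and the
  // rank-N path, since Eigen resolves the types at compile time.
  template <typename DeviceContext, typename X, typename Y, typename Dim>
  void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
    y.device(place) = x.sum(dim);
  }
};
```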
@@ -157,6 +174,25 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    if (reduce_all) {
+      auto* input0 = context.Input<Tensor>("X");
+      auto* input1 = context.Input<Tensor>("Out");
+      auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* output = context.Output<Tensor>(framework::GradVarName("X"));
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenVector<T>::Flatten(*input0);
+      auto x_reduce = EigenVector<T>::From(*input1);
+      auto x_reduce_grad = EigenVector<T>::From(*input2);
+      auto x_grad = EigenVector<T>::Flatten(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto broadcast_dim =
+          Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
+      Functor functor;
+      functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
+              broadcast_dim[0]);
+    } else {
     int rank = context.Input<Tensor>("X")->dims().size();
     switch (rank) {
       case 1:
@@ -179,6 +215,7 @@ class ReduceGradKernel : public framework::OpKernel<T> {
          break;
      }
+    }
   }

  private:
   template <size_t D>
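
The math behind the backward fast path: for `y = sum(x)`, every element has gradient `dy/dx_i = 1`, so `x_grad` is just the scalar upstream gradient `x_reduce_grad` replicated `input0->numel()` times, which is exactly what `broadcast_dim` encodes. A standalone sketch of that broadcast with illustrative values (raw Eigen, hypothetical numbers):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  const int numel = 6;            // stand-in for input0->numel()
  Eigen::Tensor<float, 1> dy(1);  // upstream gradient of the scalar Out
  dy.setValues({2.f});
  // Mirrors Eigen::array<int, 1>({{static_cast<int>(input0->numel())}}):
  Eigen::array<int, 1> broadcast_dim{{numel}};
  // Replicate the single gradient value across every input element.
  Eigen::Tensor<float, 1> dx = dy.broadcast(broadcast_dim);
  for (int i = 0; i < numel; ++i) {
    std::cout << dx(i) << " ";  // prints: 2 2 2 2 2 2
  }
  std::cout << "\n";
  return 0;
}
```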