|
|
|
@ -18,6 +18,8 @@ limitations under the License. */
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_type_transform.h"
|
|
|
|
|
#include "paddle/fluid/operators/cast_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -25,27 +27,111 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
#define HANDLE_DIM(NDIM, RDIM) \
|
|
|
|
|
if (ndim == NDIM && rdim == RDIM) { \
|
|
|
|
|
ReduceFunctor<DeviceContext, T, NDIM, RDIM, Functor>( \
|
|
|
|
|
ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, Functor>( \
|
|
|
|
|
context.template device_context<DeviceContext>(), *input, output, \
|
|
|
|
|
dims, keep_dim); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T, typename Functor>
|
|
|
|
|
struct ReduceKernelFunctor {
|
|
|
|
|
const Tensor* input;
|
|
|
|
|
Tensor* output;
|
|
|
|
|
std::vector<int> dims;
|
|
|
|
|
bool keep_dim;
|
|
|
|
|
bool reduce_all;
|
|
|
|
|
const framework::ExecutionContext& context;
|
|
|
|
|
ReduceKernelFunctor(const Tensor* input, Tensor* output,
|
|
|
|
|
const std::vector<int>& dims, bool keep_dim,
|
|
|
|
|
bool reduce_all,
|
|
|
|
|
const framework::ExecutionContext& context)
|
|
|
|
|
: input(input),
|
|
|
|
|
output(output),
|
|
|
|
|
dims(dims),
|
|
|
|
|
keep_dim(keep_dim),
|
|
|
|
|
reduce_all(reduce_all),
|
|
|
|
|
context(context) {}
|
|
|
|
|
|
|
|
|
|
template <typename OutT>
|
|
|
|
|
void apply() const {
|
|
|
|
|
output->mutable_data<OutT>(context.GetPlace());
|
|
|
|
|
if (reduce_all) {
|
|
|
|
|
// Flatten and reduce 1-D tensor
|
|
|
|
|
auto x = EigenVector<OutT>::Flatten(*input);
|
|
|
|
|
auto out = EigenScalar<OutT>::From(*output);
|
|
|
|
|
auto& place =
|
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
auto reduce_dim = Eigen::array<int, 1>({{0}});
|
|
|
|
|
Functor functor;
|
|
|
|
|
functor(place, &x, &out, reduce_dim);
|
|
|
|
|
} else {
|
|
|
|
|
int ndim = input->dims().size();
|
|
|
|
|
int rdim = dims.size();
|
|
|
|
|
HANDLE_DIM(4, 3);
|
|
|
|
|
HANDLE_DIM(4, 2);
|
|
|
|
|
HANDLE_DIM(4, 1);
|
|
|
|
|
HANDLE_DIM(3, 2);
|
|
|
|
|
HANDLE_DIM(3, 1);
|
|
|
|
|
HANDLE_DIM(2, 1);
|
|
|
|
|
HANDLE_DIM(1, 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <typename DeviceContext, typename T, typename Functor>
|
|
|
|
|
class ReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
bool reduce_all = context.Attr<bool>("reduce_all");
|
|
|
|
|
auto* output = context.Output<Tensor>("Out");
|
|
|
|
|
auto dims = context.Attr<std::vector<int>>("dim");
|
|
|
|
|
bool keep_dim = context.Attr<bool>("keep_dim");
|
|
|
|
|
int out_dtype = context.Attr<int>("out_dtype");
|
|
|
|
|
framework::proto::VarType::Type cast_out_dtype;
|
|
|
|
|
|
|
|
|
|
if (out_dtype < 0) {
|
|
|
|
|
auto* cast_input = context.Input<Tensor>("X");
|
|
|
|
|
cast_out_dtype =
|
|
|
|
|
static_cast<framework::proto::VarType::Type>(cast_input->type());
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
cast_out_dtype,
|
|
|
|
|
ReduceKernelFunctor<DeviceContext, T, Functor>(
|
|
|
|
|
cast_input, output, dims, keep_dim, reduce_all, context));
|
|
|
|
|
} else {
|
|
|
|
|
Tensor tmp_tensor;
|
|
|
|
|
cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
|
|
|
|
|
auto* input = context.Input<Tensor>("X");
|
|
|
|
|
|
|
|
|
|
tmp_tensor.Resize(input->dims());
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
cast_out_dtype,
|
|
|
|
|
CastOpFunctor<DeviceContext, T>(
|
|
|
|
|
input, &tmp_tensor,
|
|
|
|
|
context.template device_context<DeviceContext>()));
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
cast_out_dtype,
|
|
|
|
|
ReduceKernelFunctor<DeviceContext, T, Functor>(
|
|
|
|
|
&tmp_tensor, output, dims, keep_dim, reduce_all, context));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename OutT, typename Functor>
|
|
|
|
|
class BoolReduceKernel : public framework::OpKernel<OutT> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
bool reduce_all = context.Attr<bool>("reduce_all");
|
|
|
|
|
auto* input = context.Input<Tensor>("X");
|
|
|
|
|
auto* output = context.Output<Tensor>("Out");
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
output->mutable_data<OutT>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto dims = context.Attr<std::vector<int>>("dim");
|
|
|
|
|
bool keep_dim = context.Attr<bool>("keep_dim");
|
|
|
|
|
|
|
|
|
|
if (reduce_all) {
|
|
|
|
|
// Flatten and reduce 1-D tensor
|
|
|
|
|
auto x = EigenVector<T>::Flatten(*input);
|
|
|
|
|
auto out = EigenScalar<T>::From(*output);
|
|
|
|
|
auto x = EigenVector<OutT>::Flatten(*input);
|
|
|
|
|
auto out = EigenScalar<OutT>::From(*output);
|
|
|
|
|
auto& place =
|
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
auto reduce_dim = Eigen::array<int, 1>({{0}});
|
|
|
|
@ -74,18 +160,17 @@ class ReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T, typename Functor,
|
|
|
|
|
bool kNoNeedBufferX = false, bool kNoNeedBufferY = false>
|
|
|
|
|
class ReduceGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
void ComputeFromInput(const Tensor* input2,
|
|
|
|
|
const framework::ExecutionContext& context) const {
|
|
|
|
|
bool reduce_all = context.Attr<bool>("reduce_all");
|
|
|
|
|
auto dims = context.Attr<std::vector<int>>("dim");
|
|
|
|
|
|
|
|
|
|
auto* input0 = context.Input<Tensor>("X");
|
|
|
|
|
auto* input1 = context.Input<Tensor>("Out");
|
|
|
|
|
auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
@ -152,6 +237,26 @@ class ReduceGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
int in_dtype = context.Attr<int>("in_dtype");
|
|
|
|
|
if (in_dtype >= 0) {
|
|
|
|
|
Tensor tmp_tensor;
|
|
|
|
|
auto* pre_input = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto in_kernel_type =
|
|
|
|
|
framework::OpKernelType(pre_input->type(), context.GetPlace());
|
|
|
|
|
auto out_kernel_type = framework::OpKernelType(
|
|
|
|
|
static_cast<framework::proto::VarType::Type>(in_dtype),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
framework::TransDataType(in_kernel_type, out_kernel_type, *pre_input,
|
|
|
|
|
&tmp_tensor);
|
|
|
|
|
ComputeFromInput(&tmp_tensor, context);
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
ComputeFromInput(input2, context);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ReduceOp : public framework::OperatorWithKernel {
|
|
|
|
@ -267,6 +372,12 @@ class ReduceGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
int in_dtype = ctx.Attr<int>("in_dtype");
|
|
|
|
|
if (in_dtype >= 0) {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
static_cast<framework::proto::VarType::Type>(in_dtype),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
|
|
|
|
ctx, framework::GradVarName("Out")),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
@ -295,6 +406,16 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"(bool, default false) "
|
|
|
|
|
"If true, output a scalar reduced along all dimensions.")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<int>("in_dtype",
|
|
|
|
|
"(int, default -1)"
|
|
|
|
|
"The dtype of input, default value is -1, the user could not "
|
|
|
|
|
"set this value.")
|
|
|
|
|
.SetDefault(-1);
|
|
|
|
|
AddAttr<int>(
|
|
|
|
|
"out_dtype",
|
|
|
|
|
"(int, default -1)"
|
|
|
|
|
"The dtype of output, default value is -1, the dtype is same as intput")
|
|
|
|
|
.SetDefault(-1);
|
|
|
|
|
AddComment(string::Sprintf(R"DOC(
|
|
|
|
|
%s Operator.
|
|
|
|
|
|
|
|
|
|