|
|
|
@ -87,7 +87,7 @@ struct MaxOrMinGradFunctor {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename Functor>
|
|
|
|
|
class ReduceKernel : public framework::OpKernel {
|
|
|
|
|
class ReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
int rank = context.Input<Tensor>("X")->dims().size();
|
|
|
|
@ -141,7 +141,7 @@ class ReduceKernel : public framework::OpKernel {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename Functor>
|
|
|
|
|
class ReduceGradKernel : public framework::OpKernel {
|
|
|
|
|
class ReduceGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
int rank = context.Input<Tensor>("X")->dims().size();
|
|
|
|
|