Fix compile

tonyyang-svail-feed-op-desgin
Yu Yang 8 years ago
parent fb6a48c62d
commit b9c8637238

@ -44,7 +44,7 @@ class ConcatKernel : public framework::OpKernel<T> {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ConcatGradKernel : public framework::OpKernel { class ConcatGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

@ -87,7 +87,7 @@ struct MaxOrMinGradFunctor {
}; };
template <typename Place, typename T, typename Functor> template <typename Place, typename T, typename Functor>
class ReduceKernel : public framework::OpKernel { class ReduceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size(); int rank = context.Input<Tensor>("X")->dims().size();
@ -141,7 +141,7 @@ class ReduceKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename Functor> template <typename Place, typename T, typename Functor>
class ReduceGradKernel : public framework::OpKernel { class ReduceGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size(); int rank = context.Input<Tensor>("X")->dims().size();

Loading…
Cancel
Save