add_depthwiseConv_op_gpu
Yang Yu 7 years ago
parent 2b9b6c3d32
commit 2024489bb8

@ -54,7 +54,7 @@ class CompareOpKernel
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
ElementwiseComputeEx<Functor, DeviceContext, T>(context);
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context);
}
};

@ -373,7 +373,7 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
z->mutable_data<OutType>(ctx.GetPlace());
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), Functor());

Loading…
Cancel
Save