|
|
|
@ -43,7 +43,6 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* z = ctx.Output<Tensor>("Out");
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* y_data = y->data<T>();
|
|
|
|
|
T* z_data = z->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
|
|
|
|
@ -92,6 +91,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
_x.ShareDataWith(*x);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
z->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto sum_func = [](T a, T b) -> T { return a + b; };
|
|
|
|
|
|
|
|
|
|
TransformFunctor<decltype(sum_func), T,
|
|
|
|
@ -155,6 +155,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto sum_pd = handler.AcquireSumPrimitiveDescriptor(
|
|
|
|
|
{src_x_memory, src_y_memory}, scales, dst_md);
|
|
|
|
|
|
|
|
|
|
T* z_data = z->mutable_data<T>(ctx.GetPlace(),
|
|
|
|
|
sum_pd->dst_primitive_desc().get_size());
|
|
|
|
|
|
|
|
|
|
auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data);
|
|
|
|
|
|
|
|
|
|
std::vector<primitive::at> inputs({*src_x_memory, *src_y_memory});
|
|
|
|
|