|
|
|
@ -71,7 +71,9 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx,
|
|
|
|
|
const framework::Tensor* x,
|
|
|
|
|
const framework::Tensor* y, framework::Tensor* z) {
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
if (x->numel() >= y->numel()) {
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims = y->dims();
|
|
|
|
|
if (x_dims.size() >= y_dims.size()) {
|
|
|
|
|
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
|
|
|
|
|
MulFunctor<T>(), z);
|
|
|
|
|
} else {
|
|
|
|
@ -118,7 +120,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
z->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
if (x.numel() == y->numel()) {
|
|
|
|
|
auto dims_equal = x.dims() == y->dims();
|
|
|
|
|
if (dims_equal) {
|
|
|
|
|
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
|
|
|
|
|
same_dims_mul(ctx, &x, y, z);
|
|
|
|
|
} else {
|
|
|
|
|