fix broadcast bug;test=develop (#21898)

release/1.7
danleifeng 5 years ago committed by Yi Liu
parent e0d8b8f5c0
commit b7697f6218

@ -25,7 +25,9 @@ void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) {
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z);
} else {

@ -31,7 +31,9 @@ void default_elementwise_div(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) {
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
DivFunctor<T>(), z);
} else {

@ -71,7 +71,9 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) {
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z);
} else {
@ -118,7 +120,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
}
z->mutable_data<T>(ctx.GetPlace());
if (x.numel() == y->numel()) {
auto dims_equal = x.dims() == y->dims();
if (dims_equal) {
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
same_dims_mul(ctx, &x, y, z);
} else {

@ -26,7 +26,9 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) {
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} else {

Loading…
Cancel
Save