From a373aa76451b83a9eeb7617f92bdc9c1ea6e9ff6 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Wed, 24 Feb 2021 18:24:22 +0800 Subject: [PATCH] fix the bug in expand_v2 op (#30984) * update, test=develop --- paddle/fluid/operators/expand_v2_op.cc | 5 ++++- python/paddle/tensor/manipulation.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index a1ee47b7f9..05ab0f6c8d 100644 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -66,6 +66,9 @@ class ExpandV2Op : public framework::OperatorWithKernel { out_shape[i] = -1; } else if (expand_shape[i] == -1) { out_shape[i] = x_dims[i]; + } else if (expand_shape[i] == -2) { + // We use -2 to represent the element in expand_shape is a var. + out_shape[i] = -1; } else { PADDLE_ENFORCE_GT( expand_shape[i], 0, @@ -174,7 +177,7 @@ class ExpandV2GradOp : public framework::OperatorWithKernel { x_dim_vec.insert(x_dim_vec.begin(), diff, -1); for (size_t i = 0; i < expand_shape.size(); ++i) { - if (expand_shape[i] == -1 || x_dim_vec[i] == -1) { + if (expand_shape[i] < 0 || x_dim_vec[i] == -1) { continue; } else { if (ctx->IsRuntime()) { diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 2583c4b95d..9bcda74d11 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1448,7 +1448,7 @@ def expand(x, shape, name=None): attrs_expand_shape = [] for idx, shape in enumerate(list_expand_shape): if isinstance(shape, Variable): - attrs_expand_shape.append(-1) + attrs_expand_shape.append(-2) else: attrs_expand_shape.append(shape) assert shape > 0 or shape == -1, (