|
|
|
@ -66,6 +66,9 @@ class ExpandV2Op : public framework::OperatorWithKernel {
|
|
|
|
|
out_shape[i] = -1;
|
|
|
|
|
} else if (expand_shape[i] == -1) {
|
|
|
|
|
out_shape[i] = x_dims[i];
|
|
|
|
|
} else if (expand_shape[i] == -2) {
|
|
|
|
|
// We use -2 to represent the element in expand_shape is a var.
|
|
|
|
|
out_shape[i] = -1;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
expand_shape[i], 0,
|
|
|
|
@ -174,7 +177,7 @@ class ExpandV2GradOp : public framework::OperatorWithKernel {
|
|
|
|
|
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < expand_shape.size(); ++i) {
|
|
|
|
|
if (expand_shape[i] == -1 || x_dim_vec[i] == -1) {
|
|
|
|
|
if (expand_shape[i] < 0 || x_dim_vec[i] == -1) {
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|