|
|
|
@ -31,12 +31,19 @@ static inline framework::DDim ComputeAndCheckShape(
|
|
|
|
|
auto out_dims = inputs_dims[0];
|
|
|
|
|
size_t in_zero_dims_size = out_dims.size();
|
|
|
|
|
for (size_t i = 1; i < n; i++) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of input[0] and input[%d] "
|
|
|
|
|
"is expected to be equal."
|
|
|
|
|
"But received input[0]'s shape = "
|
|
|
|
|
"[%s], input[%d]'s shape = [%s].",
|
|
|
|
|
i, inputs_dims[0], i, inputs_dims[i]));
|
|
|
|
|
for (size_t j = 0; j < in_zero_dims_size; j++) {
|
|
|
|
|
if (j == axis) {
|
|
|
|
|
if (is_runtime) {
|
|
|
|
|
out_dims[axis] += inputs_dims[i][j];
|
|
|
|
|
} else {
|
|
|
|
|
if (inputs_dims[i][j] == -1) {
|
|
|
|
|
if (inputs_dims[i][j] == -1 || out_dims[j] == -1) {
|
|
|
|
|
out_dims[axis] = -1;
|
|
|
|
|
} else {
|
|
|
|
|
out_dims[axis] += inputs_dims[i][j];
|
|
|
|
@ -55,6 +62,9 @@ static inline framework::DDim ComputeAndCheckShape(
|
|
|
|
|
"[%s], input[%d]'s shape = [%s].",
|
|
|
|
|
j, i, inputs_dims[0], i, inputs_dims[i]));
|
|
|
|
|
}
|
|
|
|
|
if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) {
|
|
|
|
|
out_dims[j] = inputs_dims[i][j];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|