|
|
|
@ -42,19 +42,23 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (shape[i] == -1) {
|
|
|
|
|
neg_dims_idx.push_back(i);
|
|
|
|
|
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
|
|
|
|
|
"Only one dimension of Attr(shape) can be -1.");
|
|
|
|
|
"Only one dimension of Attr(shape) can be unknown.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// capacity check
|
|
|
|
|
int64_t capacity =
|
|
|
|
|
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
|
|
|
|
|
int64_t in_size = framework::product(x_dims);
|
|
|
|
|
if (neg_dims_idx.size() == 1) {
|
|
|
|
|
shape[neg_dims_idx[0]] = in_size / (-capacity);
|
|
|
|
|
PADDLE_ENFORCE(shape[neg_dims_idx[0]] > 0,
|
|
|
|
|
"The size of Input(X) mismatches with Attr(shape).");
|
|
|
|
|
// dim infer
|
|
|
|
|
shape[neg_dims_idx[0]] = in_size / (-capacity);
|
|
|
|
|
// recalculate capacity
|
|
|
|
|
capacity = std::accumulate(shape.begin(), shape.end(), 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
}
|
|
|
|
|
// capacity check
|
|
|
|
|
PADDLE_ENFORCE(capacity == in_size,
|
|
|
|
|
"The size of Input(X) mismatches with Attr(shape).");
|
|
|
|
|
// resize output
|
|
|
|
|
std::vector<int64_t> shape_int64(shape.size(), 0);
|
|
|
|
|
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
|
|
|
|
|