|
|
|
@ -34,21 +34,33 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
|
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
// TODO(qiao) change batch_size
|
|
|
|
|
for (size_t i = 1; i < shape.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE(shape[i] > 0,
|
|
|
|
|
"Each dimension of Attr(shape) "
|
|
|
|
|
"must be positive except the first one.");
|
|
|
|
|
}
|
|
|
|
|
if (shape[0] < 0) {
|
|
|
|
|
shape[0] = x_dims[0];
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> neg_dims_idx;
|
|
|
|
|
// set some dimension to -1 if it is unknown
|
|
|
|
|
const int unknown_size = -1;
|
|
|
|
|
for (size_t i = 0; i < shape.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
|
|
|
|
|
"Each dimension of Attr(shape) must be positive or %d.",
|
|
|
|
|
unknown_size);
|
|
|
|
|
if (shape[i] == unknown_size) {
|
|
|
|
|
neg_dims_idx.push_back(i);
|
|
|
|
|
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
|
|
|
|
|
"Only one dimension of Attr(shape) can be unknown.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// capacity check
|
|
|
|
|
|
|
|
|
|
int64_t capacity =
|
|
|
|
|
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
|
|
|
|
|
int64_t in_size = framework::product(x_dims);
|
|
|
|
|
PADDLE_ENFORCE_EQ(capacity, in_size,
|
|
|
|
|
"The size of Input(X) mismatches with Attr(shape).");
|
|
|
|
|
if (neg_dims_idx.size() == 1) {
|
|
|
|
|
// dim infer
|
|
|
|
|
shape[neg_dims_idx[0]] = in_size / (-capacity);
|
|
|
|
|
// recalculate capacity
|
|
|
|
|
capacity = shape[neg_dims_idx[0]] * (-capacity);
|
|
|
|
|
}
|
|
|
|
|
// capacity check
|
|
|
|
|
PADDLE_ENFORCE(capacity == in_size,
|
|
|
|
|
"The size of Input(X) mismatches with Attr(shape).");
|
|
|
|
|
// resize output
|
|
|
|
|
std::vector<int64_t> shape_int64(shape.size(), 0);
|
|
|
|
|
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
|
|
|
|
@ -88,6 +100,9 @@ the tensor X into a 2-D tensor:
|
|
|
|
|
|
|
|
|
|
[[1, 2, 3, 4]]
|
|
|
|
|
|
|
|
|
|
One dimension in the target shape can be set -1, representing that its
|
|
|
|
|
size is unknown. In this case, the real dimension will be infered from
|
|
|
|
|
the original shape of Input(X) and other dimensions in the target shape.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|