|
|
|
@ -56,6 +56,9 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
static framework::DDim ValidateShape(const std::vector<int> shape,
|
|
|
|
|
const framework::DDim &in_dims) {
|
|
|
|
|
const int64_t in_size = framework::product(in_dims);
|
|
|
|
|
auto in_dims_vec = framework::vectorize(in_dims);
|
|
|
|
|
bool all_positive = std::all_of(in_dims_vec.cbegin(), in_dims_vec.cend(),
|
|
|
|
|
[](int64_t i) { return i > 0; });
|
|
|
|
|
// only one dimension can be set to -1, whose size will be automatically
|
|
|
|
|
// infered.
|
|
|
|
|
const int64_t unk_dim_val = -1;
|
|
|
|
@ -88,7 +91,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (unk_dim_idx != -1) {
|
|
|
|
|
if (in_size > 0) {
|
|
|
|
|
if (all_positive) {
|
|
|
|
|
// in_size < 0 and is un-determinate in compile time, skip the check,
|
|
|
|
|
// for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
|
|
|
|
|
// capacity = -24, in_size = -8, output_shape[0] = 0
|
|
|
|
|