|
|
|
@ -181,13 +181,13 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
|
|
|
|
|
return RET_INFER_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int> out_shape;
|
|
|
|
|
out_shape_.clear();
|
|
|
|
|
if (inputs_.size() == kDoubleNum) {
|
|
|
|
|
auto shape_tensor = inputs_.at(1);
|
|
|
|
|
if (shape_tensor->IsConst()) {
|
|
|
|
|
if (shape_tensor->data_c() == nullptr || (shape_tensor->shape().size() == 1 && shape_tensor->shape()[0] == 0)) {
|
|
|
|
|
MS_LOG(DEBUG) << "reshape to a scalar.";
|
|
|
|
|
output->set_shape(out_shape);
|
|
|
|
|
output->set_shape(out_shape_);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -199,23 +199,23 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
|
|
|
|
|
switch (shape_tensor->data_type()) {
|
|
|
|
|
case kNumberTypeInt8: {
|
|
|
|
|
auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<int8_t>(data, inputs_, &out_shape, shape_size);
|
|
|
|
|
CalShape<int8_t>(data, inputs_, &out_shape_, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
case kNumberTypeInt32: {
|
|
|
|
|
auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
|
|
|
|
|
CalShape<int32_t>(data, inputs_, &out_shape_, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
case kNumberTypeInt64: {
|
|
|
|
|
auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<int64_t>(data, inputs_, &out_shape, shape_size);
|
|
|
|
|
CalShape<int64_t>(data, inputs_, &out_shape_, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
case kNumberTypeFloat: {
|
|
|
|
|
auto data = reinterpret_cast<float *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<float>(data, inputs_, &out_shape, shape_size);
|
|
|
|
|
CalShape<float>(data, inputs_, &out_shape_, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
case kNumberTypeUInt32: {
|
|
|
|
|
auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<uint32_t>(data, inputs_, &out_shape, shape_size);
|
|
|
|
|
CalShape<uint32_t>(data, inputs_, &out_shape_, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
default: {
|
|
|
|
|
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
|
|
|
|
@ -224,18 +224,18 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
|
|
|
|
|
}
|
|
|
|
|
} else if (inputs_.size() == kSingleNum) {
|
|
|
|
|
for (size_t i = 0; i < GetShape().size(); ++i) {
|
|
|
|
|
out_shape.push_back(GetShape().at(i));
|
|
|
|
|
out_shape_.push_back(GetShape().at(i));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "inputs tensor size invalid.";
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
auto ret = CalNewShape(inputs_.front(), &out_shape);
|
|
|
|
|
auto ret = CalNewShape(inputs_.front(), &out_shape_);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "CalNewShape error";
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
output->set_shape(out_shape);
|
|
|
|
|
output->set_shape(out_shape_);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
} // namespace lite
|
|
|
|
|