|
|
@ -39,7 +39,7 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
lite::STATUS ret;
|
|
|
|
lite::STATUS ret;
|
|
|
|
if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
|
|
|
|
if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
|
|
|
|
ret = MoveTensorLiteData(reinterpret_cast<lite::TensorList *>(dst_tensor),
|
|
|
|
ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
|
|
|
|
reinterpret_cast<lite::TensorList *>(src_tensor));
|
|
|
|
reinterpret_cast<lite::TensorList *>(src_tensor));
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
ret = MoveTensorData(dst_tensor, src_tensor);
|
|
|
|
ret = MoveTensorData(dst_tensor, src_tensor);
|
|
|
@ -55,7 +55,13 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
|
|
|
|
int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
|
|
|
|
int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
|
|
|
|
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
|
|
|
|
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
|
|
|
|
!(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) {
|
|
|
|
!(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) {
|
|
|
|
MS_LOG(ERROR) << "input tensor and output tensor is incompatible";
|
|
|
|
MS_LOG(ERROR) << "input tensor and output tensor is incompatible.";
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
|
|
|
|
|
|
|
|
<< "output tensor data_type: " << dst_tensor->data_type()
|
|
|
|
|
|
|
|
<< "input tensor format: " << src_tensor->format() << " vs "
|
|
|
|
|
|
|
|
<< "output tensor format: " << dst_tensor->format() << "input tensor shape: " << src_tensor->shape()
|
|
|
|
|
|
|
|
<< " vs "
|
|
|
|
|
|
|
|
<< "output tensor shape: " << dst_tensor->shape();
|
|
|
|
return RET_ERROR;
|
|
|
|
return RET_ERROR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (src_tensor->root_tensor() == nullptr) {
|
|
|
|
if (src_tensor->root_tensor() == nullptr) {
|
|
|
@ -83,18 +89,19 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_
|
|
|
|
return RET_OK;
|
|
|
|
return RET_OK;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
|
|
|
|
int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
|
|
|
|
// shape may change, because tensors.size() can be change in RunGraph
|
|
|
|
// shape may change, because tensors.size() can be change in RunGraph
|
|
|
|
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) {
|
|
|
|
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) {
|
|
|
|
MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
|
|
|
|
MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
|
|
|
|
|
|
|
|
<< "output tensor data_type: " << dst_tensor->data_type()
|
|
|
|
|
|
|
|
<< "input tensor format: " << src_tensor->format() << " vs "
|
|
|
|
|
|
|
|
<< "output tensor format: " << dst_tensor->format();
|
|
|
|
return RET_ERROR;
|
|
|
|
return RET_ERROR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (dst_tensor->element_shape().empty()) {
|
|
|
|
// when tensorlist malloc is done. this need to check element_shape compatibility
|
|
|
|
dst_tensor->set_element_shape(src_tensor->element_shape());
|
|
|
|
dst_tensor->set_element_shape(src_tensor->element_shape());
|
|
|
|
} else if (dst_tensor->element_shape() != src_tensor->element_shape()) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible";
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto update_data_type = kTypeUnknown;
|
|
|
|
auto update_data_type = kTypeUnknown;
|
|
|
|
auto dst_tensor_data_type = dst_tensor->tensors_data_type();
|
|
|
|
auto dst_tensor_data_type = dst_tensor->tensors_data_type();
|
|
|
|
auto src_tensor_data_type = src_tensor->tensors_data_type();
|
|
|
|
auto src_tensor_data_type = src_tensor->tensors_data_type();
|
|
|
|