fix bug of tensorlist

pull/11716/head
mengyuanli 4 years ago
parent b2cd022c5f
commit 394555853f

@@ -131,8 +131,13 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
   }
   output0->set_max_elements_num(input0->max_elements_num());
-  output0->set_element_shape(input0->element_shape());
+  if (input0->tensors().empty() && input0->element_shape().empty() && index == 0) {
+    input0->set_element_shape(value_tensor->shape());
+    output0->set_element_shape(value_tensor->shape());
+  } else {
+    output0->set_element_shape(input0->element_shape());
+  }
   std::vector<std::vector<int> > out_shape;
   if (index == 0 && input0->tensors().size() == 0) {  // uninitialized tensorlist
     out_shape.push_back(value_tensor->shape());
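
A note on the InferShape change above: an uninitialized TensorList (no stored tensors and no element shape) now adopts the shape of the first value written at index 0, and only otherwise propagates its existing element shape to the output. A minimal standalone sketch of that rule, with simplified types rather than the real lite::TensorList API:

#include <cassert>
#include <vector>

// Simplified stand-in for the element-shape bookkeeping of a tensor list.
struct ShapeState {
  std::vector<int> element_shape;  // shape shared by all elements; empty means undecided
  bool has_tensors = false;        // whether any element has been stored yet
};

// Decide the element shape when a value of shape `value_shape` is set at `index`.
// Mirrors the rule in the hunk: an empty, uninitialized list adopts the shape of
// the first value written at index 0; otherwise the current shape is kept.
std::vector<int> InferElementShape(ShapeState *list, const std::vector<int> &value_shape, int index) {
  if (!list->has_tensors && list->element_shape.empty() && index == 0) {
    list->element_shape = value_shape;  // adopt the first value's shape
  }
  return list->element_shape;
}

int main() {
  ShapeState list;
  // Uninitialized list: the first write at index 0 fixes the element shape.
  assert((InferElementShape(&list, {2, 3}, 0) == std::vector<int>{2, 3}));
  // Later calls keep the already-decided shape.
  assert((InferElementShape(&list, {4, 5}, 1) == std::vector<int>{2, 3}));
  return 0;
}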

@@ -89,7 +89,9 @@ int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::Tens
     MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
     return RET_ERROR;
   }
-  if (dst_tensor->element_shape() != src_tensor->element_shape()) {
+  if (dst_tensor->element_shape().empty()) {
+    dst_tensor->set_element_shape(src_tensor->element_shape());
+  } else if (dst_tensor->element_shape() != src_tensor->element_shape()) {
     MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible";
     return RET_ERROR;
   }
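
MoveTensorLiteData used to reject any difference between the destination and source element shapes, which also rejected the legitimate case of a destination whose element shape had not been decided yet. With the fix, an empty destination shape simply adopts the source shape, and only a real conflict is an error. A standalone sketch of that merge rule (a hypothetical helper, not the kernel itself):

#include <iostream>
#include <vector>

enum Status { kOk = 0, kError = 1 };

// Merge the source element shape into the destination, in the spirit of the
// CarryDataKernel fix: adopt it when the destination is still undecided,
// otherwise require an exact match.
Status MergeElementShape(std::vector<int> *dst, const std::vector<int> &src) {
  if (dst->empty()) {
    *dst = src;  // destination had no element shape yet: take the source's
    return kOk;
  }
  if (*dst != src) {
    return kError;  // both sides are decided but disagree: incompatible
  }
  return kOk;  // identical shapes: nothing to do
}

int main() {
  std::vector<int> dst;  // undecided
  std::cout << MergeElementShape(&dst, {1, 4}) << "\n";  // 0: adopted {1, 4}
  std::cout << MergeElementShape(&dst, {1, 4}) << "\n";  // 0: already matches
  std::cout << MergeElementShape(&dst, {2, 4}) << "\n";  // 1: conflict
  return 0;
}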

@@ -68,7 +68,7 @@ int TensorListGetItemCPUKernel::Run() {
   } else {
     // reset 0 and dtype = dtype_
     // TODO(DT_VARIANT): dtype = DT_VARIANT is not handle
-    auto out_data = out_tensors_[0]->MutableData();
+    auto out_data = out_tensors_[0]->data_c();
     if (out_data == nullptr) {
       MS_LOG(ERROR) << "data of out_tensors_[0] is nullptr";
       return RET_ERROR;
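
One plausible reading of the one-line switch from MutableData() to data_c(): an accessor that allocates on demand can never return nullptr, so the error branch below it was unreachable, while a raw accessor makes an unallocated output detectable. Whether that is the exact motivation here is an assumption; the sketch below only illustrates the pattern with a toy tensor type, not the lite::Tensor API:

#include <cstdlib>
#include <iostream>

// Toy tensor with the two accessor flavors discussed above.
struct ToyTensor {
  void *data = nullptr;
  std::size_t size = 16;

  // Allocates lazily, so it never reports "no buffer yet".
  void *MutableData() {
    if (data == nullptr) data = std::malloc(size);
    return data;
  }
  // Returns whatever is currently there, possibly nullptr.
  void *data_c() const { return data; }
};

int main() {
  ToyTensor t;
  std::cout << (t.MutableData() == nullptr) << "\n";  // always 0: the check can't fire
  ToyTensor u;
  if (u.data_c() == nullptr) {                        // raw access: nullptr is observable
    std::cout << "data of tensor is nullptr\n";
  }
  std::free(t.data);
  return 0;
}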

@@ -30,6 +30,22 @@ namespace mindspore::kernel {
 int TensorListSetItemCPUKernel::Init() { return RET_OK; }
 
+int TensorListSetItemCPUKernel::CheckParam() {
+  if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) {
+    MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
+    return RET_ERROR;
+  }
+  if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
+    MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
+    return RET_ERROR;
+  }
+  if (in_tensors_[1]->ElementsNum() != 1) {
+    MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) {
   output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
   int new_tensors_size = origin_size + 1;
@@ -46,19 +62,13 @@ int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) {
 int TensorListSetItemCPUKernel::Run() {
   input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
-  if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) {
-    MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
+  if (CheckParam() != RET_OK) {
+    MS_LOG(ERROR) << "check param failed.";
     return RET_ERROR;
   }
   int dim0 = input0_->ElementsNum() - 1;
-  if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
-    MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
-    return RET_ERROR;
-  }
-  if (in_tensors_[1]->ElementsNum() != 1) {
-    MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
-    return RET_ERROR;
-  }
   index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0];
   if (index_ < 0 || index_ > dim0) {
     if (IncrementOutputSize(output0_->shape()[0]) != RET_OK) {
@@ -81,6 +91,10 @@ int TensorListSetItemCPUKernel::Run() {
     }
   }
   // copy each tensor in tensors_
+  if (input0_->tensors().empty() && index_ == 0) {
+    input0_->set_element_shape(input2_->shape());
+    output0_->set_element_shape(input2_->shape());
+  }
   for (int i = 0; i < output0_->ElementsNum(); ++i) {
     if (i == index_) {
       auto dst = output0_->GetTensor(i);
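
Taken together, the kernel hunks above restructure TensorListSetItemCPUKernel: the dtype and index-tensor checks that used to sit inline in Run() move into a private CheckParam() helper (declared in the header hunk below), and Run() additionally lets an empty tensor list being written at index 0 adopt the shape of the incoming value tensor (input2_), matching the InferShape fix. A compact standalone sketch of that structure, with simplified types rather than the real kernel classes:

#include <iostream>
#include <vector>

// Simplified stand-in for the kernel's state: element shapes stored in the
// "tensor list", the shape of the value being written, and the target index.
struct ToySetItemKernel {
  std::vector<std::vector<int>> elements;
  std::vector<int> element_shape;  // shared element shape, empty if undecided
  std::vector<int> value_shape;
  int index = 0;

  // Validation pulled out of Run(), in the spirit of CheckParam().
  int CheckParam() const {
    if (index < 0) {
      std::cerr << "index must be non-negative\n";
      return 1;
    }
    return 0;
  }

  int Run() {
    if (CheckParam() != 0) {
      std::cerr << "check param failed.\n";
      return 1;
    }
    // An empty list written at index 0 adopts the value's shape.
    if (elements.empty() && index == 0) {
      element_shape = value_shape;
    }
    if (static_cast<int>(elements.size()) <= index) {
      elements.resize(index + 1);  // grow the list, analogous to IncrementOutputSize
    }
    elements[index] = value_shape;
    return 0;
  }
};

int main() {
  ToySetItemKernel k;
  k.value_shape = {2, 3};
  std::cout << k.Run() << " " << k.element_shape.size() << "\n";  // prints "0 2"
  return 0;
}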

@@ -39,6 +39,7 @@ class TensorListSetItemCPUKernel : public LiteKernel {
   int IncrementOutputSize(int origin_size);
 
  private:
+  int CheckParam();
   lite::TensorList *input0_ = nullptr;
   lite::Tensor *input2_ = nullptr;
   lite::TensorList *output0_ = nullptr;

@@ -240,6 +240,9 @@ Tensor *TensorList::GetTensor(int index) {
 }
 
 bool TensorList::IsCompatibleShape(const std::vector<int> &shape) {
+  if (this->tensors_.empty() && this->element_shape_.empty()) {
+    return true;
+  }
   if (shape.size() != this->element_shape_.size()) {
     return false;
   }
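
With the added early return, IsCompatibleShape treats a tensor list that holds no tensors and has no element shape yet as compatible with any candidate shape; the rank and dimension comparison only applies once a shape has been fixed. A minimal standalone version of the predicate, assuming the usual -1 wildcard convention for the per-dimension checks that follow the lines shown above:

#include <cassert>
#include <cstddef>
#include <vector>

// Standalone sketch of the compatibility predicate: an uninitialized list
// (no tensors, no element shape) accepts anything; otherwise the rank and
// every fixed (non -1) dimension must agree.
bool IsCompatibleShape(const std::vector<std::vector<int>> &tensors,
                       const std::vector<int> &element_shape,
                       const std::vector<int> &shape) {
  if (tensors.empty() && element_shape.empty()) {
    return true;  // nothing decided yet: any shape is acceptable
  }
  if (shape.size() != element_shape.size()) {
    return false;  // rank mismatch
  }
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (element_shape[i] != -1 && element_shape[i] != shape[i]) {
      return false;  // a fixed dimension disagrees
    }
  }
  return true;
}

int main() {
  assert(IsCompatibleShape({}, {}, {2, 3}));       // uninitialized: compatible
  assert(IsCompatibleShape({}, {-1, 3}, {2, 3}));  // wildcard dimension matches
  assert(!IsCompatibleShape({}, {2, 3}, {2, 4}));  // fixed dimension mismatch
  return 0;
}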

@@ -439,6 +439,9 @@ STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncG
       MS_LOG(ERROR) << "memcpy error: " << ret;
       return RET_ERROR;
     }
+    copy_param_value->set_tensor_shape(param_value->tensor_shape());
+    copy_param_value->set_format(param_value->format());
+    copy_param_value->set_tensor_type(param_value->tensor_type());
     copy_param_value->SetTensorData(copy_data, param_value->tensor_size());
     ext_subgraph_input->set_default_param(copy_param_value);
   } else {
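
The parser hunk closes out the fix: when a default parameter is duplicated for an external subgraph input, the copy previously received only the raw bytes, so its shape, format, and tensor type were presumably left at their defaults; the added calls carry that metadata over before SetTensorData. A generic sketch of the same pitfall with a hypothetical value holder (not the converter's real ParamValue API):

#include <iostream>
#include <vector>

// Hypothetical holder for a constant tensor: raw bytes plus the metadata
// that gives those bytes meaning.
struct ToyParamValue {
  std::vector<char> data;
  std::vector<int> shape;
  int format = 0;
  int type = 0;
};

// Duplicate a parameter for another graph input. Copying only `data` (the
// original bug pattern) would leave shape/format/type at their defaults;
// the fixed version carries all of them over.
ToyParamValue CopyParamValue(const ToyParamValue &src) {
  ToyParamValue dst;
  dst.data = src.data;    // the bytes themselves
  dst.shape = src.shape;  // ...and everything needed to interpret them
  dst.format = src.format;
  dst.type = src.type;
  return dst;
}

int main() {
  ToyParamValue src;
  src.data = {1, 0, 0, 0};
  src.shape = {1};
  src.format = 2;
  src.type = 34;
  ToyParamValue dst = CopyParamValue(src);
  std::cout << (dst.shape == src.shape && dst.format == src.format && dst.type == src.type) << "\n";  // 1
  return 0;
}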
