|
|
|
@ -30,6 +30,22 @@ namespace mindspore::kernel {
|
|
|
|
|
|
|
|
|
|
int TensorListSetItemCPUKernel::Init() { return RET_OK; }
|
|
|
|
|
|
|
|
|
|
int TensorListSetItemCPUKernel::CheckParam() {
|
|
|
|
|
if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) {
|
|
|
|
|
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
|
|
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (in_tensors_[1]->ElementsNum() != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) {
|
|
|
|
|
output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
|
|
|
|
|
int new_tensors_size = origin_size + 1;
|
|
|
|
@ -46,19 +62,13 @@ int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) {
|
|
|
|
|
|
|
|
|
|
int TensorListSetItemCPUKernel::Run() {
|
|
|
|
|
input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
|
|
|
|
|
if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) {
|
|
|
|
|
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
|
|
|
|
|
|
|
|
|
|
if (CheckParam() != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "check param failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int dim0 = input0_->ElementsNum() - 1;
|
|
|
|
|
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
|
|
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (in_tensors_[1]->ElementsNum() != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0];
|
|
|
|
|
if (index_ < 0 || index_ > dim0) {
|
|
|
|
|
if (IncrementOutputSize(output0_->shape()[0]) != RET_OK) {
|
|
|
|
@ -81,6 +91,10 @@ int TensorListSetItemCPUKernel::Run() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// copy each tensor in tensors_
|
|
|
|
|
if (input0_->tensors().empty() && index_ == 0) {
|
|
|
|
|
input0_->set_element_shape(input2_->shape());
|
|
|
|
|
output0_->set_element_shape(input2_->shape());
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < output0_->ElementsNum(); ++i) {
|
|
|
|
|
if (i == index_) {
|
|
|
|
|
auto dst = output0_->GetTensor(i);
|
|
|
|
|