|
|
|
@ -29,35 +29,38 @@ using mindspore::schema::PrimitiveType_TensorListGetItem;
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
|
|
|
|
|
int TensorListGetItemCPUKernel::Init() {
|
|
|
|
|
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
|
|
|
|
|
MS_ASSERT(in_tensors_.size() >= 2);
|
|
|
|
|
MS_ASSERT(in_tensors_.at(0) != nullptr);
|
|
|
|
|
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
|
|
|
|
|
if (dtype_ != input0->tensors_data_type()) {
|
|
|
|
|
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0->tensors_data_type();
|
|
|
|
|
MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (in_tensors_[1]->ElementsNum() != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0];
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int TensorListGetItemCPUKernel::Run() {
|
|
|
|
|
MS_ASSERT(in_tensors_.size() >= 2);
|
|
|
|
|
MS_ASSERT(in_tensors_.at(0) != nullptr);
|
|
|
|
|
MS_ASSERT(in_tensors_.at(1) != nullptr);
|
|
|
|
|
MS_ASSERT(out_tensors_.at(0) != nullptr);
|
|
|
|
|
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
|
|
|
|
|
MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr);
|
|
|
|
|
index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0];
|
|
|
|
|
int dim0 = input0->ElementsNum() - 1;
|
|
|
|
|
if (index_ < 0 || index_ > dim0) {
|
|
|
|
|
MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be in [0, " << dim0 << "]!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int TensorListGetItemCPUKernel::Run() {
|
|
|
|
|
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
|
|
|
|
|
auto src_ptr = input0->GetTensor(index_);
|
|
|
|
|
MS_ASSERT(src_ptr != nullptr);
|
|
|
|
|
if (src_ptr->data_type() != kTypeUnknown) {
|
|
|
|
|
if (src_ptr->ElementsNum() != out_tensors_[0]->ElementsNum()) {
|
|
|
|
|
if (src_ptr->ElementsNum() != out_tensors_.at(0)->ElementsNum()) {
|
|
|
|
|
MS_LOG(ERROR) << "src_ptr->ElementsNum():" << src_ptr->ElementsNum()
|
|
|
|
|
<< " must be equal to out_tensors_[0]->ElementsNum():" << out_tensors_[0]->ElementsNum();
|
|
|
|
|
<< " must be equal to out_tensors_[0]->ElementsNum():" << out_tensors_.at(0)->ElementsNum();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto status = lite::Tensor::CopyTensorData(*src_ptr, out_tensors_[0]);
|
|
|
|
|
auto status = lite::Tensor::CopyTensorData(*src_ptr, out_tensors_.at(0));
|
|
|
|
|
if (status == RET_ERROR) {
|
|
|
|
|
MS_LOG(ERROR) << "copy tensor data failed!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
@ -65,19 +68,17 @@ int TensorListGetItemCPUKernel::Run() {
|
|
|
|
|
} else {
|
|
|
|
|
// reset 0 and dtype = dtype_
|
|
|
|
|
// TODO(DT_VARIANT): dtype = DT_VARIANT is not handle
|
|
|
|
|
memset(out_tensors_[0]->MutableData(), 0, out_tensors_[0]->Size());
|
|
|
|
|
auto out_data = out_tensors_[0]->MutableData();
|
|
|
|
|
if (out_data == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "data of out_tensors_[0] is nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
memset(out_data, 0, out_tensors_[0]->Size());
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int TensorListGetItemCPUKernel::ReSize() {
|
|
|
|
|
auto ret = this->Init();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Init kernel failed!";
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
int TensorListGetItemCPUKernel::ReSize() { return RET_OK; }
|
|
|
|
|
|
|
|
|
|
kernel::LiteKernel *CpuTensorListGetItemFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
|
|
|
|
const std::vector<lite::Tensor *> &outputs,
|
|
|
|
|