|
|
|
@ -30,7 +30,22 @@ using mindspore::lite::RET_OK;
|
|
|
|
|
using mindspore::schema::PrimitiveType_Gather;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
GatherFp16CPUKernel::~GatherFp16CPUKernel() {
|
|
|
|
|
if (input_data_) {
|
|
|
|
|
context_->allocator->Free(input_data_);
|
|
|
|
|
input_data_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int GatherFp16CPUKernel::Init() {
|
|
|
|
|
auto input_tensor = in_tensors_.at(0);
|
|
|
|
|
if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data_c() != nullptr) {
|
|
|
|
|
const_input_ = true;
|
|
|
|
|
input_data_ =
|
|
|
|
|
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
|
|
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -128,11 +143,13 @@ int GatherFp16CPUKernel::Run() {
|
|
|
|
|
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
auto input_tensor = in_tensors_.at(0);
|
|
|
|
|
if (input_tensor->data_type() == kNumberTypeFloat32) {
|
|
|
|
|
input_data_ =
|
|
|
|
|
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
|
|
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
|
|
|
|
|
if (!const_input_) {
|
|
|
|
|
auto input_tensor = in_tensors_.at(0);
|
|
|
|
|
if (input_tensor->data_type() == kNumberTypeFloat32) {
|
|
|
|
|
input_data_ =
|
|
|
|
|
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
|
|
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ret = ParallelLaunch(this->context_->thread_pool_, GatherRunFp16, this, op_parameter_->thread_num_);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
@ -142,7 +159,7 @@ int GatherFp16CPUKernel::Run() {
|
|
|
|
|
context_->allocator->Free(indices_data_);
|
|
|
|
|
indices_data_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (input_data_) {
|
|
|
|
|
if (!const_input_ && input_data_) {
|
|
|
|
|
context_->allocator->Free(input_data_);
|
|
|
|
|
input_data_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|