From 8b850474f46eb3f9e58c72f5bf2de3136e40c158 Mon Sep 17 00:00:00 2001 From: lzk Date: Wed, 27 Jan 2021 18:00:15 -0800 Subject: [PATCH] add tensorlist and bias_add fp16 ops --- mindspore/lite/nnacl/fp16/arithmetic_fp16.c | 37 +++++- mindspore/lite/nnacl/fp16/arithmetic_fp16.h | 12 +- mindspore/lite/src/ops/tensorlist_getitem.cc | 7 +- .../src/runtime/kernel/arm/fp16/bias_fp16.cc | 106 ++++++++++++++++++ .../src/runtime/kernel/arm/fp16/bias_fp16.h | 45 ++++++++ .../arm/fp32/tensorlist_fromtensor_fp32.cc | 63 ++++------- .../arm/fp32/tensorlist_fromtensor_fp32.h | 5 +- .../arm/fp32/tensorlist_getitem_fp32.cc | 44 ++------ .../arm/fp32/tensorlist_reserve_fp32.cc | 10 +- .../arm/fp32/tensorlist_setitem_fp32.cc | 10 +- .../kernel/arm/fp32/tensorlist_stack_fp32.cc | 22 +++- 11 files changed, 266 insertions(+), 95 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/bias_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/bias_fp16.h diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c index bb77985601..97b2b5fbdc 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c +++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,14 @@ #include "nnacl/common_func.h" #include "nnacl/nnacl_utils.h" -void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t ndim, int *inShape, int *inStrides, - int *outStrides, int *multiple) { +int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, + float16_t *out, int size, ArithmeticParameter *param) { + TileDimensionsFp16(in0, in1, tile_in0, tile_in1, param); + return ElementAddFp16(tile_in0, tile_in1, out, size); +} + +void TileOneDimensionFp16(const float16_t *inData, float16_t *outData, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple) { int srcDimSize = inShape[dim]; if (dim == ndim - 1) { for (int i = 0; i < multiple[dim]; i++) { @@ -37,7 +43,7 @@ void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t } } -void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, +void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, ArithmeticParameter *param) { CalcMultiplesAndStrides(param); TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, @@ -219,6 +225,12 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int float16x8_t vout = vaddq_f16(vin0, vin1); vst1q_f16(output + index, vout); } + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vst1_f16(output + index, vout); + } #endif for (; index < element_size; index++) { output[index] = input0[index] + input1[index]; @@ -270,6 +282,14 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, vout = vmaxq_f16(vout, zeros); vst1q_f16(output + index, vout); } + float16x4_t zeros1 = vdup_n_f16(0.0f); + for (; index <= element_size - 4; index += C4NUM) { + 
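+ // Drain up to one remaining group of four halfwords with 64-bit NEON
+ // vectors (vld1_f16/vadd_f16/vmax_f16) after the eight-lane loop above;
+ // any final one to three elements fall through to the scalar loop below.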
float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vout = vmax_f16(vout, zeros1); + vst1_f16(output + index, vout); + } #endif for (; index < element_size; index++) { float16_t res = input0[index] + input1[index]; @@ -328,6 +348,15 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); vst1q_f16(output + index, vout); } + float16x4_t zeros1 = vdup_n_f16(0.0); + float16x4_t bounds1 = vdup_n_f16(6.0); + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vout = vmin_f16(vmax_f16(vout, zeros1), bounds1); + vst1_f16(output + index, vout); + } #endif for (; index < element_size; index++) { output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6); diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.h b/mindspore/lite/nnacl/fp16/arithmetic_fp16.h index 840dfd2a85..00d06f3856 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.h +++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,12 @@ #ifdef __cplusplus extern "C" { #endif + +void TileOneDimensionFp16(const float16_t *inData, float16_t *outData, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple); +void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, + ArithmeticParameter *param); + int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, @@ -84,6 +90,8 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, + float16_t *out, int size, ArithmeticParameter *param); int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); @@ -111,8 +119,6 @@ int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); -void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, - ArithmeticParameter *param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/ops/tensorlist_getitem.cc b/mindspore/lite/src/ops/tensorlist_getitem.cc index 5c8a48395d..6299e3c713 100644 --- a/mindspore/lite/src/ops/tensorlist_getitem.cc +++ b/mindspore/lite/src/ops/tensorlist_getitem.cc @@ -125,11 
+125,6 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect MS_ASSERT(inputs_.at(1) != nullptr); MS_ASSERT(inputs_.at(2) != nullptr); auto input0 = reinterpret_cast<TensorList *>(inputs_.at(0)); - if (input0->tensors_data_type() != GetElementDType()) { - MS_LOG(ERROR) << "op dtype: " << GetElementDType() - << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type(); - return RET_ERROR; - } auto get_index = inputs_.at(1); MS_ASSERT(get_index != nullptr); if (get_index->ElementsNum() != 1) { @@ -184,7 +179,7 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect MS_LOG(ERROR) << "element_shape_ is not fullyDefined!"; return RET_ERROR; } - output->set_data_type(GetElementDType()); + output->set_data_type(input0->data_type()); output->set_shape(element_shape_); } output->set_format(input0->GetTensor(index_)->format()); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/bias_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/bias_fp16.cc new file mode 100644 index 0000000000..60d1132dc0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/bias_fp16.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <vector> +#include "include/errorcode.h" +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/fp16/bias_fp16.h" +#include "src/kernel_registry.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BiasAdd; + +namespace mindspore::kernel { + +int BiasCPUFp16Kernel::ReSize() { + auto dims = in_tensors_.at(0)->shape(); + bias_param_->ndim_ = dims.size(); + if (bias_param_->ndim_ < 1 || bias_param_->ndim_ > 5) { + MS_LOG(ERROR) << "input shape is invalid"; + return RET_ERROR; + } + for (size_t i = 0; i < bias_param_->ndim_; i++) { + bias_param_->in_shape0_[i] = dims[i]; + bias_param_->in_shape1_[i] = 1; + bias_param_->out_shape_[i] = dims[i]; + } + bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1]; + return RET_OK; +} + +int BiasCPUFp16Kernel::Run() { + auto in = reinterpret_cast<float16_t *>(in_tensors_.at(0)->MutableData()); + auto out = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData()); + size_t data_size = in_tensors_.at(0)->ElementsNum(); + MS_ASSERT(context_->allocator != nullptr); + auto *tile_in = reinterpret_cast<float16_t *>(context_->allocator->Malloc(data_size * sizeof(float16_t))); + auto *tile_bias = reinterpret_cast<float16_t *>(context_->allocator->Malloc(data_size * sizeof(float16_t))); + if (tile_in == nullptr || tile_bias == nullptr) { + MS_LOG(ERROR) << "Memory allocation failed"; + context_->allocator->Free(tile_in); + context_->allocator->Free(tile_bias); + return RET_NULL_PTR; + } + BroadcastAddFp16(in, bias_data_, tile_in, tile_bias, out, data_size, bias_param_); + context_->allocator->Free(tile_in); + context_->allocator->Free(tile_bias); + return RET_OK; +} +
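+// A worked example of the layout Run() sets up (illustrative comment only):
+// for a [2, 3] fp16 input and a [3] bias, ReSize() stores in_shape0_ = {2, 3},
+// in_shape1_ = {1, 3} and out_shape_ = {2, 3}, so the BroadcastAddFp16() call
+// above (see nnacl/fp16/arithmetic_fp16.c) expands to roughly
+//   TileDimensionsFp16(in, bias_data_, tile_in, tile_bias, bias_param_);
+//   ElementAddFp16(tile_in, tile_bias, out, data_size);
+// i.e. the bias row is tiled across dim 0 into tile_bias before the
+// element-wise NEON add writes the result into out.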
+BiasCPUFp16Kernel::~BiasCPUFp16Kernel() { + if ((bias_data_type_ == kNumberTypeFloat || bias_data_type_ == kNumberTypeFloat32) && bias_data_ != nullptr) { + free(bias_data_); + bias_data_ = nullptr; + } +} + +int BiasCPUFp16Kernel::Init() { + auto bias_tensor = in_tensors_.at(1); + MS_ASSERT(bias_tensor != nullptr); + bias_data_type_ = bias_tensor->data_type(); + if (bias_data_type_ == kNumberTypeFloat || bias_data_type_ == kNumberTypeFloat32) { + bias_data_ = reinterpret_cast<float16_t *>(malloc(bias_tensor->ElementsNum() * sizeof(float16_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "bias_data_ is nullptr"; + return RET_NULL_PTR; + } + auto *bias = reinterpret_cast<float *>(bias_tensor->MutableData()); + if (bias == nullptr) { + MS_LOG(ERROR) << "bias is nullptr!"; + return RET_NULL_PTR; + } + for (int i = 0; i < bias_tensor->ElementsNum(); ++i) { + bias_data_[i] = (float16_t)(bias[i]); + } + } else { + bias_data_ = reinterpret_cast<float16_t *>(bias_tensor->MutableData()); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "bias_data_ is nullptr"; + return RET_NULL_PTR; + } + } + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, LiteKernelCreator<BiasCPUFp16Kernel>) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/bias_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/bias_fp16.h new file mode 100644 index 0000000000..525600551b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/bias_fp16.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_ +#include <vector> +#include "src/lite_kernel.h" +#include "nnacl/fp16/arithmetic_fp16.h" + +namespace mindspore::kernel { +class BiasCPUFp16Kernel : public LiteKernel { + public: + BiasCPUFp16Kernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, + const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + bias_param_ = reinterpret_cast<ArithmeticParameter *>(parameter); + } + ~BiasCPUFp16Kernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + ArithmeticParameter *bias_param_ = nullptr; + float16_t *bias_data_ = nullptr; + TypeId bias_data_type_ = kTypeUnknown; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.cc index bd8cbcaeaf..e018e9bc9e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.cc @@ -53,25 +53,24 @@ int TensorListFromTensorCPUKernel::IsCompatibleShape() { } int TensorListFromTensorCPUKernel::Init() { - input0_ = in_tensors_[0]; // row tensor - input1_ = in_tensors_[1]; // element_shape tensor - output0_ = out_tensors_[0]; - return IsCompatibleShape(); -} - -int TensorListFromTensorCPUKernel::ReSize() { - auto ret = this->Init(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Init kernel failed!"; - return ret; +#ifdef ENABLE_FP16 + if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { + dtype_ = kNumberTypeFloat16; } +#endif return RET_OK; } +int TensorListFromTensorCPUKernel::ReSize() { return RET_OK; } + int TensorListFromTensorCPUKernel::Run() { input0_ = in_tensors_[0]; // row tensor input1_ = in_tensors_[1]; // element_shape tensor output0_ = out_tensors_[0]; + if (IsCompatibleShape() != RET_OK) { + MS_LOG(ERROR) << "IsCompatibleShape() check failed!"; + return RET_ERROR; + } if (input0_->shape().size() == 0) { MS_LOG(ERROR) << "input0_->shape().size():" << input0_->shape().size() << " must be greater than 0"; } @@ -86,7 +85,9 @@ int TensorListFromTensorCPUKernel::Run() { return RET_ERROR; } int devision_dim0 = input0_->ElementsNum() / dim0; - auto in_ptr = reinterpret_cast<float *>(input0_->data_c()); + auto data_offset = devision_dim0 * lite::DataTypeSize(dtype_); + auto in_data = reinterpret_cast<char *>(input0_->data_c()); + MS_ASSERT(in_data != nullptr); // copy data from input0(tensor) to output(tensorlist) vector<*tensor> for (int i = 0; i < dim0; ++i) { auto out_ptr = output0->GetTensor(i); @@ -96,37 +97,17 @@ int TensorListFromTensorCPUKernel::Run() { << " must be euqal to devision_dim0:" << devision_dim0; return RET_ERROR; } - memcpy(reinterpret_cast<float *>(out_ptr->MutableData()), in_ptr, devision_dim0 * sizeof(float)); - in_ptr += devision_dim0; + auto out_data = out_ptr->MutableData(); + MS_ASSERT(out_data != nullptr); + memcpy(out_data, in_data, data_offset); + in_data += data_offset; } return RET_OK; } -kernel::LiteKernel *CpuTensorListFromTensorFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, - const std::vector<lite::Tensor *> &outputs, - OpParameter *op_parameter, const lite::InnerContext *ctx, - const kernel::KernelKey &desc, - const mindspore::lite::PrimitiveC *primitive) { - if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!"; - return nullptr; - } - if (ctx == nullptr) { - MS_LOG(ERROR) << "Input context is nullptr!"; - free(op_parameter); - return nullptr; - } - MS_ASSERT(desc.type == schema::PrimitiveType_TensorListFromTensor); - op_parameter->thread_num_ = ctx->thread_num_; - auto *kernel = new (std::nothrow) TensorListFromTensorCPUKernel(op_parameter, inputs, outputs, ctx, primitive); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new TensorListFromTensorCPUKernel fail!"; - free(op_parameter); - return nullptr; - } - return kernel; -} - -REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListFromTensor, CpuTensorListFromTensorFp32KernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListFromTensor, CpuTensorListFromTensorFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListFromTensor, + LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListFromTensor, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListFromTensor, + LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.h index 4e3a5ece13..94061454b5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.h @@ -21,6 +21,7 @@ #include "src/lite_kernel.h" #include "src/tensorlist.h" #include "schema/model_generated.h" +#include "nnacl/tensorlist_parameter.h" namespace mindspore::kernel { class TensorListFromTensorCPUKernel : public LiteKernel { @@ -28,7 +29,8 @@ class TensorListFromTensorCPUKernel : public LiteKernel { TensorListFromTensorCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive), + dtype_(reinterpret_cast(parameter)->element_dtype_) {} ~TensorListFromTensorCPUKernel() = default; int Init() override; @@ -41,6 +43,7 @@ class TensorListFromTensorCPUKernel : public LiteKernel { lite::Tensor *output0_ = nullptr; lite::Tensor *input0_ = nullptr; lite::Tensor *input1_ = nullptr; + TypeId dtype_ = kTypeUnknown; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc index 6b4618fe76..7f88cb6ee4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc @@ -31,11 +31,11 @@ namespace mindspore::kernel { int TensorListGetItemCPUKernel::Init() { MS_ASSERT(in_tensors_.size() >= 2); MS_ASSERT(in_tensors_.at(0) != nullptr); - auto input0 = reinterpret_cast(in_tensors_.at(0)); - if (dtype_ != input0->tensors_data_type()) { - MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type(); - return RET_ERROR; +#ifdef ENABLE_FP16 + if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { + dtype_ = kNumberTypeFloat16; } +#endif return RET_OK; } @@ -45,6 +45,10 @@ int TensorListGetItemCPUKernel::Run() { MS_ASSERT(in_tensors_.at(1) != nullptr); MS_ASSERT(out_tensors_.at(0) != nullptr); auto input0 = 
reinterpret_cast<lite::TensorList *>(in_tensors_.at(0)); + if (dtype_ != input0->tensors_data_type()) { + MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type(); + return RET_ERROR; + } MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr); index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0]; int dim0 = input0->ElementsNum() - 1; @@ -66,8 +70,7 @@ int TensorListGetItemCPUKernel::Run() { return RET_ERROR; } } else { - // reset 0 and dtype = dtype_ - // TODO(DT_VARIANT): dtype = DT_VARIANT is not handle + // reset the data buffer to zero auto out_data = out_tensors_[0]->data_c(); if (out_data == nullptr) { MS_LOG(ERROR) << "data of out_tensors_[0] is nullptr"; @@ -80,30 +83,7 @@ int TensorListGetItemCPUKernel::Run() { int TensorListGetItemCPUKernel::ReSize() { return RET_OK; } -kernel::LiteKernel *CpuTensorListGetItemFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, - const std::vector<lite::Tensor *> &outputs, - OpParameter *op_parameter, const lite::InnerContext *ctx, - const kernel::KernelKey &desc, - const mindspore::lite::PrimitiveC *primitive) { - if (op_parameter == nullptr) { - MS_LOG(ERROR) << "Input op_parameter is nullptr!"; - return nullptr; - } - if (ctx == nullptr) { - MS_LOG(ERROR) << "Input context is nullptr!"; - free(op_parameter); - return nullptr; - } - MS_ASSERT(desc.type == schema::PrimitiveType_TensorListGetItem); - auto *kernel = new (std::nothrow) TensorListGetItemCPUKernel(op_parameter, inputs, outputs, ctx, primitive); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new TensorListGetItemCPUKernel fail!"; - free(op_parameter); - return nullptr; - } - return kernel; -} - -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListGetItem, CpuTensorListGetItemFp32KernelCreator) -REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListGetItem, CpuTensorListGetItemFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.cc index d32263dad8..48b4802b82 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.cc @@ -27,7 +27,14 @@ using mindspore::schema::PrimitiveType_TensorListReserve; namespace mindspore::kernel { -int TensorListReserveCPUKernel::Init() { return RET_OK; } +int TensorListReserveCPUKernel::Init() { +#ifdef ENABLE_FP16 + if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && element_dtype_ == kNumberTypeFloat32) { + element_dtype_ = kNumberTypeFloat16; + } +#endif + return RET_OK; +} int TensorListReserveCPUKernel::Run() { auto input0 = in_tensors_.at(0); @@ -48,5 +55,6 @@ int TensorListReserveCPUKernel::Run() { int TensorListReserveCPUKernel::ReSize() { return RET_OK; } REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc index
9ce67e1dd5..83a481380a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc @@ -28,7 +28,14 @@ using mindspore::schema::PrimitiveType_TensorListSetItem; namespace mindspore::kernel { -int TensorListSetItemCPUKernel::Init() { return RET_OK; } +int TensorListSetItemCPUKernel::Init() { +#ifdef ENABLE_FP16 + if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { + dtype_ = kNumberTypeFloat16; + } +#endif + return RET_OK; +} int TensorListSetItemCPUKernel::CheckParam() { if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) { @@ -143,5 +150,6 @@ int TensorListSetItemCPUKernel::Run() { int TensorListSetItemCPUKernel::ReSize() { return RET_OK; } REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc index f141a2409d..fe673e3dae 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc @@ -60,6 +60,11 @@ int TensorListStackCPUKernel::Init() { MS_ASSERT(input0_ != nullptr); output0_ = out_tensors_[0]; MS_ASSERT(output0_ != nullptr); +#ifdef ENABLE_FP16 + if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { + dtype_ = kNumberTypeFloat16; + } +#endif return RET_OK; } @@ -159,17 +164,21 @@ int TensorListStackCPUKernel::Run() { MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num << "must be equal to in_ele_num:" << in_ele_num; return RET_ERROR; } - auto out_ptr = reinterpret_cast<float *>(output0_->MutableData()); + auto out_data = reinterpret_cast<char *>(output0_->MutableData()); + auto unknown_type_offset = TypeUnknownSize * lite::DataTypeSize(dtype_); + MS_ASSERT(out_data != nullptr); for (int i = 0; i < num_element_; ++i) { auto in_ptr = input0_->GetTensor(i); MS_ASSERT(in_ptr != nullptr); if (in_ptr->data_type() != kTypeUnknown) { - int in_size = in_ptr->ElementsNum(); - memcpy(out_ptr, in_ptr->data_c(), in_size * sizeof(float)); - out_ptr += in_size; + int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_); + auto in_data = in_ptr->data_c(); + MS_ASSERT(in_data != nullptr); + memcpy(out_data, in_data, data_size); + out_data += data_size; } else { - memset(out_ptr, 0, TypeUnknownSize * sizeof(float)); - out_ptr += TypeUnknownSize; + memset(out_data, 0, unknown_type_offset); + out_data += unknown_type_offset; } } return RET_OK; @@ -178,5 +187,6 @@ int TensorListStackCPUKernel::Run() { int TensorListStackCPUKernel::ReSize() { return RET_OK; } REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>) } // namespace mindspore::kernel
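
For reference, a minimal standalone sketch of driving the new BroadcastAddFp16 entry point the way BiasCPUFp16Kernel::Run() does (a sketch only, assuming an ARM fp16 build where float16_t is available; the function name, buffers, and shapes below are hypothetical and not part of the patch):

#include <string.h>
#include "nnacl/fp16/arithmetic_fp16.h"

void BiasAddFp16Sketch(void) {
  /* hypothetical [2, 4] input plus a [4] bias; the shapes mirror what
   * BiasCPUFp16Kernel::ReSize() writes into the ArithmeticParameter
   * (bias broadcast along the last axis). */
  float16_t in[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float16_t bias[4] = {10, 20, 30, 40};
  float16_t tile_in[8], tile_bias[8], out[8];
  ArithmeticParameter param;
  memset(&param, 0, sizeof(param));
  param.ndim_ = 2;
  param.in_shape0_[0] = 2;
  param.in_shape0_[1] = 4;
  param.in_shape1_[0] = 1;
  param.in_shape1_[1] = 4;
  param.out_shape_[0] = 2;
  param.out_shape_[1] = 4;
  /* tiles the bias row to [2, 4] in tile_bias, then adds elementwise:
   * out = {11, 22, 33, 44, 15, 26, 37, 48} */
  BroadcastAddFp16(in, bias, tile_in, tile_bias, out, 8, &param);
}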