diff --git a/mindspore/lite/nnacl/fp16/batchnorm_fp16.c b/mindspore/lite/nnacl/fp16/batchnorm_fp16.c
index baa9e6dfdf..facae90c3f 100644
--- a/mindspore/lite/nnacl/fp16/batchnorm_fp16.c
+++ b/mindspore/lite/nnacl/fp16/batchnorm_fp16.c
@@ -17,8 +17,8 @@
 #include "nnacl/fp16/batchnorm_fp16.h"
 #include <math.h>
 
-void BatchNormFp16(const void *input, const void *mean, const void *variance,
-                   BatchNormParameter *param, int task_id, void *output) {
+void BatchNormFp16(const float16_t *input, const void *mean, const void *variance,
+                   BatchNormParameter *param, int task_id, float16_t *output) {
   int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
   int completed_units = task_id * units_per_thread;
   int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
@@ -27,8 +27,9 @@ void BatchNormFp16(const void *input, const void *mean, const void *variance,
   for (int i = 0; i < cur_unit; i++) {
     for (int c = 0; c < param->channel_; c++) {
       float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
-      ((float16_t *)output)[cur_offset + c] =
-        (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
+      if (variance_sqrt != 0) {
+        output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
+      }
     }
     cur_offset += param->channel_;
   }
@@ -44,8 +45,12 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset
   for (int i = 0; i < cur_unit; i++) {
     for (int c = 0; c < param->channel_; c++) {
       float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
-      float16_t norm_val = (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
-      ((float16_t *)output)[cur_offset + c] = norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
+      if (variance_sqrt != 0) {
+        float16_t norm_val =
+          (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
+        ((float16_t *)output)[cur_offset + c] =
+          norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
+      }
     }
     cur_offset += param->channel_;
   }
diff --git a/mindspore/lite/nnacl/fp16/batchnorm_fp16.h b/mindspore/lite/nnacl/fp16/batchnorm_fp16.h
index 673bcd46fa..8f6d6aa485 100644
--- a/mindspore/lite/nnacl/fp16/batchnorm_fp16.h
+++ b/mindspore/lite/nnacl/fp16/batchnorm_fp16.h
@@ -25,8 +25,8 @@
 extern "C" {
 #endif
 
-void BatchNormFp16(const void *input, const void *mean, const void *variance, BatchNormParameter *param, int task_id,
-                   void *output);
+void BatchNormFp16(const float16_t *input, const void *mean, const void *variance, BatchNormParameter *param,
+                   int task_id, float16_t *output);
 
 void FusedBatchNormFp16(const void *input, const void *scale, const void *offset, const void *mean,
                         const void *variance, BatchNormParameter *param, int task_id, void *output);
diff --git a/mindspore/lite/nnacl/fp32/batchnorm.c b/mindspore/lite/nnacl/fp32/batchnorm.c
index 5efde546ce..49926d1c4a 100644
--- a/mindspore/lite/nnacl/fp32/batchnorm.c
+++ b/mindspore/lite/nnacl/fp32/batchnorm.c
@@ -15,7 +15,6 @@
  */
 
 #include "nnacl/fp32/batchnorm.h"
-#include "nnacl/fp16/batchnorm_fp16.h"
 #include <math.h>
 #include "nnacl/batchnorm_parameter.h"
 #include "nnacl/op_base.h"
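Note on the new `variance_sqrt != 0` guard: adding `epsilon_` does not by itself make the divisor safe at half precision, since `variance + epsilon` can still round to zero in fp16, so the kernel now skips such channels instead of dividing by zero. One consequence is that guarded-out elements of `output` are never written, so callers should not assume every element is produced. A self-contained sketch of the guarded loop, using `float` in place of `float16_t` so it compiles off ARM (the function name and the fp32 stand-in are illustrative, not part of the patch):

#include <math.h>
#include <stdio.h>

// Mirrors the per-channel normalization above, including the divide-by-zero guard.
static void BatchNormGuarded(const float *input, const float *mean, const float *variance, float epsilon,
                             int units, int channels, float *output) {
  for (int i = 0; i < units; i++) {
    for (int c = 0; c < channels; c++) {
      float variance_sqrt = sqrtf(variance[c] + epsilon);
      if (variance_sqrt != 0) {  // guard: skip the channel rather than divide by zero
        output[i * channels + c] = (input[i * channels + c] - mean[c]) / variance_sqrt;
      }
    }
  }
}

int main(void) {
  const float input[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float mean[2] = {2.0f, 3.0f};
  const float variance[2] = {1.0f, 0.0f};  // channel 1 exercises the guard (epsilon == 0 here)
  float output[4] = {0.0f};                // pre-initialize: guarded channels stay untouched
  BatchNormGuarded(input, mean, variance, 0.0f, 2, 2, output);
  for (int i = 0; i < 4; i++) printf("output[%d] = %g\n", i, output[i]);
  return 0;
}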
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.cc
index 8805e384a8..c3735673d3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.cc
@@ -15,6 +15,7 @@
  */
 
 #include "src/runtime/kernel/arm/fp16/batchnorm_fp16.h"
+#include "src/runtime/kernel/arm/fp16/common_fp16.h"
 #include "nnacl/fp16/batchnorm_fp16.h"
 #include "nnacl/fp16/cast_fp16.h"
 #include "src/kernel_registry.h"
@@ -24,8 +25,9 @@ using mindspore::schema::PrimitiveType_BatchNorm;
 
 namespace mindspore::kernel {
 int BatchnormFp16CPUKernel::InitConstTensor() {
-  isFloat32Tensor_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
-  if (isFloat32Tensor_) {
+  is_input_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  is_output_fp32_ = out_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  if (is_input_fp32_) {
     auto mean_fp32 = in_tensors_.at(1);
     auto variance_fp32 = in_tensors_.at(2);
     mean_ = malloc(mean_fp32->ElementsNum() * sizeof(float16_t));
@@ -50,30 +52,24 @@ int BatchnormFp16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
     return ret;
   }
 
-  auto input_fp32 = in_tensors_.at(0);
-  auto output_fp32 = out_tensors_.at(0);
-  if (isFloat32Tensor_) {
-    input_ = context_->allocator->Malloc(input_fp32->ElementsNum() * sizeof(float16_t));
-    output_ = context_->allocator->Malloc(output_fp32->ElementsNum() * sizeof(float16_t));
-    if (input_ == nullptr || output_ == nullptr) {
-      FreeInputAndOutput();
-      return RET_ERROR;
-    }
-    Float32ToFloat16(reinterpret_cast<float *>(input_fp32->Data()),
-                     reinterpret_cast<float16_t *>(input_), input_fp32->ElementsNum());
-  } else {
-    input_ = in_tensors_.at(0)->Data();
-    output_ = out_tensors_.at(0)->Data();
+  auto input_tensor = in_tensors_.at(0);
+  auto output_tensor = out_tensors_.at(0);
+  input_ = ConvertInputFp32toFp16(input_tensor, context_);
+  output_ = MallocOutputFp16(output_tensor, context_);
+  if (input_ == nullptr || output_ == nullptr) {
+    FreeInputAndOutput();
+    MS_LOG(ERROR) << "input or output is nullptr";
+    return RET_ERROR;
   }
+
   ret = ParallelLaunch(THREAD_POOL_DEFAULT, BatchNormRun, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]";
   }
-  if (isFloat32Tensor_) {
-    Float16ToFloat32(reinterpret_cast<float16_t *>(output_), reinterpret_cast<float *>(output_fp32->Data()),
-                     output_fp32->ElementsNum());
-    FreeInputAndOutput();
+  if (is_output_fp32_) {
+    Float16ToFloat32(output_, reinterpret_cast<float *>(output_tensor->Data()), output_tensor->ElementsNum());
   }
+  FreeInputAndOutput();
   return ret;
 }
@@ -84,11 +80,11 @@ int BatchnormFp16CPUKernel::DoExecute(int task_id) {
 }
 
 void BatchnormFp16CPUKernel::FreeInputAndOutput() {
-  if (input_ != nullptr) {
+  if (is_input_fp32_) {
     context_->allocator->Free(input_);
     input_ = nullptr;
   }
-  if (output_ != nullptr) {
+  if (is_output_fp32_) {
     context_->allocator->Free(output_);
     output_ = nullptr;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.h
index b886747fa8..eeec184169 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/batchnorm_fp16.h
@@ -35,9 +35,10 @@ class BatchnormFp16CPUKernel : public BatchnormCPUKernel {
 
  private:
   void FreeInputAndOutput();
-  bool isFloat32Tensor_ = false;
-  void *input_ = nullptr;
-  void *output_ = nullptr;
+  bool is_input_fp32_ = false;
+  bool is_output_fp32_ = false;
+  float16_t *input_ = nullptr;
+  float16_t *output_ = nullptr;
 };
 
 }  // namespace mindspore::kernel
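The `Run()` rewrite above replaces the hand-rolled fp32-to-fp16 staging with the shared helpers from `common_fp16.h`, whose bodies are not part of this diff. The sketch below is inferred purely from the call sites — only the helper names and call shapes come from the patch; the bodies, the `lite::tensor::Tensor`/`lite::Context` types, and the `kNumberTypeFloat16` branch are assumptions to be checked against the actual header:

// Hypothetical sketch of the common_fp16.h helpers, inferred from Run()'s usage.
float16_t *ConvertInputFp32toFp16(lite::tensor::Tensor *input, const lite::Context *ctx) {
  if (input->data_type() == kNumberTypeFloat16) {
    return reinterpret_cast<float16_t *>(input->Data());  // aliases tensor memory, not owned
  }
  // fp32 input: allocate a scratch fp16 buffer and convert into it.
  auto *fp16_data =
    reinterpret_cast<float16_t *>(ctx->allocator->Malloc(input->ElementsNum() * sizeof(float16_t)));
  if (fp16_data != nullptr) {
    Float32ToFloat16(reinterpret_cast<float *>(input->Data()), fp16_data, input->ElementsNum());
  }
  return fp16_data;  // owned scratch; must be freed through the same allocator
}

float16_t *MallocOutputFp16(lite::tensor::Tensor *output, const lite::Context *ctx) {
  if (output->data_type() == kNumberTypeFloat16) {
    return reinterpret_cast<float16_t *>(output->Data());  // kernel writes in place
  }
  // fp32 output: compute into scratch; Run() converts back with Float16ToFloat32.
  return reinterpret_cast<float16_t *>(ctx->allocator->Malloc(output->ElementsNum() * sizeof(float16_t)));
}

This split is also why `FreeInputAndOutput()` now keys off `is_input_fp32_`/`is_output_fp32_` instead of null checks: in the pure-fp16 path `input_` and `output_` alias the tensors' own buffers, and the old `!= nullptr` test would have handed tensor memory to the allocator's `Free`.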
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
index f7d2429858..61123741e5 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
@@ -45,7 +45,8 @@ TEST_F(TestTopKFp32, TopK) {
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
   ASSERT_NE(creator, nullptr);
 
-  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), nullptr, desc, nullptr);
+  auto ctx = std::make_shared<lite::Context>();
+  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
   ASSERT_NE(kernel, nullptr);
 
   auto ret = kernel->Run();
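The TopK test previously passed `nullptr` for the kernel's context; any kernel whose `Run()` dereferences `context_` (the fp16 batchnorm above goes through `context_->allocator` for its scratch buffers) would crash under that harness, hence the live context here. `lite::Context` is assumed as the `std::make_shared` argument, since it matches the creator's context parameter; verify the exact type against the branch. The resulting harness pattern:

auto ctx = std::make_shared<lite::Context>();  // context type assumed, see note above
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto ret = kernel->Run();  // ctx must outlive the kernel: context_->allocator is used during Run()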