@@ -15,6 +15,7 @@
 */

 #include "src/runtime/kernel/arm/fp16/batchnorm_fp16.h"
+#include "src/runtime/kernel/arm/fp16/common_fp16.h"
 #include "nnacl/fp16/batchnorm_fp16.h"
 #include "nnacl/fp16/cast_fp16.h"
 #include "src/kernel_registry.h"
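Note on the new include: the helpers used later in this patch (`ConvertInputFp32toFp16`, `MallocOutputFp16`) presumably come from `common_fp16.h`. Below is a minimal standalone sketch of the convert-or-alias pattern they stand for, with a hypothetical `StageAsFp16` name and assuming an aarch64 toolchain where `float16_t` is available; it is an illustration, not the library's implementation.

```cpp
// Illustrative sketch only, not the MindSpore API.
#include <arm_neon.h>  // provides float16_t on aarch64 (assumption: ARM toolchain)
#include <cstdlib>

// Hypothetical helper: return an fp16 view of `data`. If the source is fp32,
// allocate and fill a new fp16 buffer (caller owns it via *owns_buffer);
// otherwise alias the existing fp16 buffer without copying.
float16_t *StageAsFp16(void *data, int element_num, bool src_is_fp32, bool *owns_buffer) {
  *owns_buffer = false;
  if (!src_is_fp32) {
    return reinterpret_cast<float16_t *>(data);  // already fp16: reuse in place
  }
  auto *dst = static_cast<float16_t *>(malloc(element_num * sizeof(float16_t)));
  if (dst == nullptr) {
    return nullptr;  // caller logs and bails out, as Run() does
  }
  const auto *src = static_cast<const float *>(data);
  for (int i = 0; i < element_num; ++i) {
    dst[i] = static_cast<float16_t>(src[i]);  // element-wise narrowing, like Float32ToFloat16
  }
  *owns_buffer = true;
  return dst;
}
```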
@@ -24,8 +25,9 @@ using mindspore::schema::PrimitiveType_BatchNorm;

 namespace mindspore::kernel {
 int BatchnormFp16CPUKernel::InitConstTensor() {
-  isFloat32Tensor_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
-  if (isFloat32Tensor_) {
+  is_input_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  is_output_fp32_ = out_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  if (is_input_fp32_) {
     auto mean_fp32 = in_tensors_.at(1);
     auto variance_fp32 = in_tensors_.at(2);
     mean_ = malloc(mean_fp32->ElementsNum() * sizeof(float16_t));
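For reference, the fp16 mean/variance buffers allocated here feed the standard batch-norm inference formula. A textbook sketch of that computation follows (not necessarily the exact nnacl kernel this file calls), assuming an ARM toolchain that provides `float16_t`:

```cpp
// Textbook batch-norm inference in fp16 (not necessarily the exact nnacl kernel):
// y = (x - mean) / sqrt(var + eps), applied per channel over an NHWC buffer.
#include <arm_neon.h>  // float16_t (assumption: ARM toolchain)
#include <cmath>

void BatchNormInferFp16(const float16_t *in, const float16_t *mean, const float16_t *var,
                        float eps, int units, int channels, float16_t *out) {
  for (int u = 0; u < units; ++u) {  // units = N * H * W
    for (int c = 0; c < channels; ++c) {
      int i = u * channels + c;
      // widen to fp32 for the arithmetic, store the result back as fp16
      float x = static_cast<float>(in[i]);
      float m = static_cast<float>(mean[c]);
      float v = static_cast<float>(var[c]);
      out[i] = static_cast<float16_t>((x - m) / std::sqrt(v + eps));
    }
  }
}
```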
@@ -50,30 +52,24 @@ int BatchnormFp16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
     return ret;
   }
-  auto input_fp32 = in_tensors_.at(0);
-  auto output_fp32 = out_tensors_.at(0);
-  if (isFloat32Tensor_) {
-    input_ = context_->allocator->Malloc(input_fp32->ElementsNum() * sizeof(float16_t));
-    output_ = context_->allocator->Malloc(output_fp32->ElementsNum() * sizeof(float16_t));
+  auto input_tensor = in_tensors_.at(0);
+  auto output_tensor = out_tensors_.at(0);
+  input_ = ConvertInputFp32toFp16(input_tensor, context_);
+  output_ = MallocOutputFp16(output_tensor, context_);
   if (input_ == nullptr || output_ == nullptr) {
+    FreeInputAndOutput();
     MS_LOG(ERROR) << "input or output is nullptr";
     return RET_ERROR;
   }
-    Float32ToFloat16(reinterpret_cast<float *>(input_fp32->Data()),
-                     reinterpret_cast<float16_t *>(input_), input_fp32->ElementsNum());
-  } else {
-    input_ = in_tensors_.at(0)->Data();
-    output_ = out_tensors_.at(0)->Data();
-  }
+
   ret = ParallelLaunch(THREAD_POOL_DEFAULT, BatchNormRun, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]";
   }
-  if (isFloat32Tensor_) {
-    Float16ToFloat32(reinterpret_cast<float16_t *>(output_), reinterpret_cast<float *>(output_fp32->Data()),
-                     output_fp32->ElementsNum());
-    FreeInputAndOutput();
+  if (is_output_fp32_) {
+    Float16ToFloat32(output_, reinterpret_cast<float *>(output_tensor->Data()), output_tensor->ElementsNum());
   }
+  FreeInputAndOutput();
   return ret;
 }

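The reworked Run() stages the input as fp16, launches the fp16 kernel across threads, converts the result back only when the graph tensor is fp32, and then releases any staging buffers. A self-contained sketch of that tail step, with hypothetical names (`FinishFp16Run`, `staged_out`) and plain `malloc`/`free` standing in for the kernel's allocator:

```cpp
// Illustrative tail of the staged flow, with hypothetical names; plain
// free() stands in for the kernel's allocator.
#include <arm_neon.h>  // float16_t (assumption: ARM toolchain)
#include <cstdlib>

void FinishFp16Run(float16_t *staged_out, float *graph_out, int element_num, bool output_is_fp32) {
  if (output_is_fp32) {
    // same effect as Float16ToFloat32(staged_out, graph_out, element_num)
    for (int i = 0; i < element_num; ++i) {
      graph_out[i] = static_cast<float>(staged_out[i]);
    }
    free(staged_out);  // staging buffer exists only in the fp32 case
  }
  // when the graph tensor is already fp16, staged_out aliases it: nothing to free
}
```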
@@ -84,11 +80,11 @@ int BatchnormFp16CPUKernel::DoExecute(int task_id) {
 }

 void BatchnormFp16CPUKernel::FreeInputAndOutput() {
-  if (input_ != nullptr) {
+  if (is_input_fp32_) {
     context_->allocator->Free(input_);
     input_ = nullptr;
   }
-  if (output_ != nullptr) {
+  if (is_output_fp32_) {
     context_->allocator->Free(output_);
     output_ = nullptr;
   }
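The condition in FreeInputAndOutput() now keys off the tensor dtype flags rather than pointer nullness: when a tensor is already fp16, `input_`/`output_` alias tensor memory that must not be freed. A design-note sketch of the same ownership rule expressed as an RAII guard (illustrative only, not part of the patch, with plain `free()` standing in for `context_->allocator`):

```cpp
// Design-note sketch: release a staging buffer exactly once, and only if it
// was actually allocated for an fp32 tensor, even on early returns.
#include <cstdlib>

class StagingGuard {
 public:
  StagingGuard(void *buf, bool owns) : buf_(buf), owns_(owns) {}
  ~StagingGuard() {
    if (owns_ && buf_ != nullptr) {
      free(buf_);
    }
  }
  StagingGuard(const StagingGuard &) = delete;
  StagingGuard &operator=(const StagingGuard &) = delete;

 private:
  void *buf_;
  bool owns_;  // mirrors is_input_fp32_ / is_output_fp32_: own only converted buffers
};
```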