diff --git a/mindspore/core/ir/dtype/type_id.h b/mindspore/core/ir/dtype/type_id.h index 028fb5ec89..7933346157 100644 --- a/mindspore/core/ir/dtype/type_id.h +++ b/mindspore/core/ir/dtype/type_id.h @@ -78,6 +78,7 @@ enum TypeId : int { kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, + kNumberTypeComplex64, kNumberTypeEnd }; } // namespace mindspore diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index ba529639d0..3705b2ad0a 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -212,6 +212,9 @@ union PrimitiveType { CustomExtractFeatures, AudioSpectrogram, Mfcc, + Rfft, + FftReal, + FftImag, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 874c8f399f..509b4b298b 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -987,4 +987,14 @@ table Mfcc { freqLowerLimit : float; filterBankChannelNum : int; dctCoeffNum : int; -} \ No newline at end of file +} + +table Rfft { + fftLength : int; +} + +table FftReal { +} + +table FftImag { +} diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index 43a7a07f69..8ccd593a2d 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -110,7 +110,8 @@ std::vector LiteKernelUtil::SubgraphInputTensors(const std::vect for (const auto &kernel : input_kernels) { for (const auto &tensor : kernel->in_tensors()) { auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor); - if (iter == all_output_tensors.end() && tensor->data_c() == nullptr) { + if (iter == all_output_tensors.end() && + !(tensor->category() == mindspore::lite::Tensor::CONST && tensor->data_c() != nullptr)) { input_tensors.emplace_back(tensor); } } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index ea2cc3d724..92c983d57c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -171,16 +171,16 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & auto conv_param = reinterpret_cast(op_parameter); int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; - conv_param->input_h_ = inputs.front()->Height(); - conv_param->input_w_ = inputs.front()->Width(); - conv_param->input_channel_ = inputs.front()->Channel(); - conv_param->output_h_ = outputs.front()->Height(); - conv_param->output_w_ = outputs.front()->Width(); - conv_param->output_channel_ = outputs.front()->Channel(); - conv_param->op_parameter_.thread_num_ = ctx->thread_num_; bool use_winograd = false; int out_unit; if (primitive != nullptr && primitive->GetInferFlag()) { + conv_param->input_h_ = inputs.front()->Height(); + conv_param->input_w_ = inputs.front()->Width(); + conv_param->input_channel_ = inputs.front()->Channel(); + conv_param->output_h_ = outputs.front()->Height(); + conv_param->output_w_ = outputs.front()->Width(); + conv_param->output_channel_ = outputs.front()->Channel(); + conv_param->op_parameter_.thread_num_ = ctx->thread_num_; CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc index 91d1a0fe68..630fdaf9ff 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc @@ -137,6 +137,49 @@ STATUS TfliteCustomParser::ExtractFeatures(const std::vector &custom_at return RET_OK; } +STATUS TfliteCustomParser::Rfft(const std::vector &custom_attr, schema::CNodeT *op, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + std::vector fft_length; + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, fft_length)) { + MS_LOG(ERROR) << "rfft -> fftLength get failed"; + return RET_ERROR; + } + attr->fftLength = fft_length[0]; + op->primitive->value.type = schema::PrimitiveType_Rfft; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +STATUS TfliteCustomParser::FftReal(const std::vector &custom_attr, schema::CNodeT *op, + const std::unique_ptr &tflite_op) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + op->primitive->value.type = schema::PrimitiveType_FftReal; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +STATUS TfliteCustomParser::FftImag(const std::vector &custom_attr, schema::CNodeT *op, + const std::unique_ptr &tflite_op) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + op->primitive->value.type = schema::PrimitiveType_FftImag; + op->primitive->value.value = attr.release(); + return RET_OK; +} + STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteCustomParser"; @@ -163,6 +206,12 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni status = ExtractFeatures(custom_attr, op, tflite_op); } else if (custom_type == "AudioSpectrogram") { status = AudioSpectrogram(custom_attr, op, tflite_op); + } else if (custom_type == "FlexRFFT") { + status = Rfft(custom_attr, op, tflite_op, tflite_model); + } else if (custom_type == "FlexReal") { + status = FftReal(custom_attr, op, tflite_op); + } else if (custom_type == "FlexImag") { + status = FftImag(custom_attr, op, tflite_op); } else { MS_LOG(ERROR) << "the custom op hasn't been supported now"; status = RET_NOT_FIND_OP; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h index 91a0c7a669..17ad6a515c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h @@ -49,6 +49,15 @@ class TfliteCustomParser : public TfliteNodeParser { STATUS ExtractFeatures(const std::vector &custom_attr, schema::CNodeT *op, const std::unique_ptr &tflite_op); + + STATUS Rfft(const std::vector &custom_attr, schema::CNodeT *op, + const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model); + + STATUS FftReal(const std::vector &custom_attr, schema::CNodeT *op, + const std::unique_ptr &tflite_op); + + STATUS FftImag(const std::vector &custom_attr, schema::CNodeT *op, + const std::unique_ptr &tflite_op); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index 0b103d0d12..635f09e77c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -133,12 +133,12 @@ std::map tfMsActivationF }; std::map type_map = { - {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, - {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, - {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, - {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, - {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString}, -}; + {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, + {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, + {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString}, + {tflite::TensorType_COMPLEX64, TypeId::kNumberTypeComplex64}}; schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { return tfMsActivationFunctionMap.at(tfliteAFType);