!6990 MSLITE adjust conv_param position and add tflite custom parser

Merge pull request !6990 from 徐安越/master
pull/6990/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f9353bb963

@ -78,6 +78,7 @@ enum TypeId : int {
kNumberTypeFloat16,
kNumberTypeFloat32,
kNumberTypeFloat64,
kNumberTypeComplex64,
kNumberTypeEnd
};
} // namespace mindspore

@ -212,6 +212,9 @@ union PrimitiveType {
CustomExtractFeatures,
AudioSpectrogram,
Mfcc,
Rfft,
FftReal,
FftImag,
}
enum QuantType: int {

@ -988,3 +988,13 @@ table Mfcc {
filterBankChannelNum : int;
dctCoeffNum : int;
}
table Rfft {
fftLength : int;
}
table FftReal {
}
table FftImag {
}

@ -110,7 +110,8 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
for (const auto &kernel : input_kernels) {
for (const auto &tensor : kernel->in_tensors()) {
auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor);
if (iter == all_output_tensors.end() && tensor->data_c() == nullptr) {
if (iter == all_output_tensors.end() &&
!(tensor->category() == mindspore::lite::Tensor::CONST && tensor->data_c() != nullptr)) {
input_tensors.emplace_back(tensor);
}
}

@ -171,6 +171,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
bool use_winograd = false;
int out_unit;
if (primitive != nullptr && primitive->GetInferFlag()) {
conv_param->input_h_ = inputs.front()->Height();
conv_param->input_w_ = inputs.front()->Width();
conv_param->input_channel_ = inputs.front()->Channel();
@ -178,9 +181,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
conv_param->output_w_ = outputs.front()->Width();
conv_param->output_channel_ = outputs.front()->Channel();
conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
bool use_winograd = false;
int out_unit;
if (primitive != nullptr && primitive->GetInferFlag()) {
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
}

@ -137,6 +137,49 @@ STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_at
return RET_OK;
}
STATUS TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
std::unique_ptr<schema::RfftT> attr = std::make_unique<schema::RfftT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
std::vector<int> fft_length;
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, fft_length)) {
MS_LOG(ERROR) << "rfft -> fftLength get failed";
return RET_ERROR;
}
attr->fftLength = fft_length[0];
op->primitive->value.type = schema::PrimitiveType_Rfft;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
std::unique_ptr<schema::FftRealT> attr = std::make_unique<schema::FftRealT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_FftReal;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
std::unique_ptr<schema::FftImagT> attr = std::make_unique<schema::FftImagT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_FftImag;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteCustomParser";
@ -163,6 +206,12 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
status = ExtractFeatures(custom_attr, op, tflite_op);
} else if (custom_type == "AudioSpectrogram") {
status = AudioSpectrogram(custom_attr, op, tflite_op);
} else if (custom_type == "FlexRFFT") {
status = Rfft(custom_attr, op, tflite_op, tflite_model);
} else if (custom_type == "FlexReal") {
status = FftReal(custom_attr, op, tflite_op);
} else if (custom_type == "FlexImag") {
status = FftImag(custom_attr, op, tflite_op);
} else {
MS_LOG(ERROR) << "the custom op hasn't been supported now";
status = RET_NOT_FIND_OP;

@ -49,6 +49,15 @@ class TfliteCustomParser : public TfliteNodeParser {
STATUS ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
STATUS Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op, const std::unique_ptr<tflite::ModelT> &tflite_model);
STATUS FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
STATUS FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
};
} // namespace lite
} // namespace mindspore

@ -138,7 +138,7 @@ std::map<int, TypeId> type_map = {
{tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
{tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
{tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString},
};
{tflite::TensorType_COMPLEX64, TypeId::kNumberTypeComplex64}};
schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) {
return tfMsActivationFunctionMap.at(tfliteAFType);

Loading…
Cancel
Save