|
|
@ -23,26 +23,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace lite {
|
|
|
|
namespace lite {
|
|
|
|
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
|
|
|
STATUS TfliteCustomParser::DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
|
|
|
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
|
|
|
|
MS_LOG(DEBUG) << "parse TfliteCustomParser";
|
|
|
|
|
|
|
|
if (op == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "op is null";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
|
|
|
if (op->primitive == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "op->primitive is null";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<schema::DetectionPostProcessT> attr = std::make_unique<schema::DetectionPostProcessT>();
|
|
|
|
std::unique_ptr<schema::DetectionPostProcessT> attr = std::make_unique<schema::DetectionPostProcessT>();
|
|
|
|
if (attr == nullptr) {
|
|
|
|
if (attr == nullptr) {
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const auto &custom_attr = tflite_op->custom_options;
|
|
|
|
|
|
|
|
auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
|
|
|
|
auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
|
|
|
|
attr->format = schema::Format::Format_NHWC;
|
|
|
|
attr->format = schema::Format::Format_NHWC;
|
|
|
|
attr->inputSize = tflite_op->inputs.size();
|
|
|
|
attr->inputSize = tflite_op->inputs.size();
|
|
|
@ -73,7 +61,115 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|
|
|
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_DetectionPostProcess;
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_DetectionPostProcess;
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteCustomParser::AudioSpectrogram(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
|
|
|
|
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
|
|
|
|
|
|
|
|
std::unique_ptr<schema::AudioSpectrogramT> attr = std::make_unique<schema::AudioSpectrogramT>();
|
|
|
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
|
|
|
|
|
|
|
|
attr->windowSize = attr_map["window_size"].AsInt64();
|
|
|
|
|
|
|
|
attr->stride = attr_map["stride"].AsInt64();
|
|
|
|
|
|
|
|
attr->magSquare = attr_map["magnitude_squared"].AsBool();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_AudioSpectrogram;
|
|
|
|
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteCustomParser::Mfcc(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
|
|
|
|
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
|
|
|
|
|
|
|
|
std::unique_ptr<schema::MfccT> attr = std::make_unique<schema::MfccT>();
|
|
|
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
|
|
|
|
|
|
|
|
attr->freqUpperLimit = attr_map["upper_frequency_limit"].AsInt64();
|
|
|
|
|
|
|
|
attr->freqLowerLimit = attr_map["lower_frequency_limit"].AsInt64();
|
|
|
|
|
|
|
|
attr->filterBankChannelNum = attr_map["filterbank_channel_count"].AsInt64();
|
|
|
|
|
|
|
|
attr->dctCoeffNum = attr_map["dct_coefficient_count"].AsInt64();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_Mfcc;
|
|
|
|
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteCustomParser::Predict(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
|
|
|
|
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
|
|
|
|
|
|
|
|
std::unique_ptr<schema::CustomPredictT> attr = std::make_unique<schema::CustomPredictT>();
|
|
|
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
attr->outputNum = reinterpret_cast<const int *>(custom_attr.data())[0];
|
|
|
|
|
|
|
|
attr->weightThreshold = reinterpret_cast<const float *>(custom_attr.data())[1];
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_CustomPredict;
|
|
|
|
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteCustomParser::Normalize(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
|
|
|
|
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
|
|
|
|
|
|
|
|
std::unique_ptr<schema::CustomNormalizeT> attr = std::make_unique<schema::CustomNormalizeT>();
|
|
|
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_CustomNormalize;
|
|
|
|
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
|
|
|
|
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
|
|
|
|
|
|
|
|
std::unique_ptr<schema::CustomExtractFeaturesT> attr = std::make_unique<schema::CustomExtractFeaturesT>();
|
|
|
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_CustomExtractFeatures;
|
|
|
|
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
|
|
|
|
|
|
|
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "parse TfliteCustomParser";
|
|
|
|
|
|
|
|
if (op == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "op is null";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
|
|
|
if (op->primitive == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "op->primitive is null";
|
|
|
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
const auto &custom_attr = tflite_op->custom_options;
|
|
|
|
|
|
|
|
const auto &opcode_index = tflite_op->opcode_index;
|
|
|
|
|
|
|
|
const auto &custom_type = tflite_model->operator_codes[opcode_index]->custom_code;
|
|
|
|
|
|
|
|
int status = RET_OK;
|
|
|
|
|
|
|
|
if (custom_type == "TFLite_Detection_PostProcess") {
|
|
|
|
|
|
|
|
status = DetectPostProcess(custom_attr, op, tflite_op);
|
|
|
|
|
|
|
|
} else if (custom_type == "Predict") {
|
|
|
|
|
|
|
|
status = Predict(custom_attr, op, tflite_op);
|
|
|
|
|
|
|
|
} else if (custom_type == "Normalize") {
|
|
|
|
|
|
|
|
status = Normalize(custom_attr, op, tflite_op);
|
|
|
|
|
|
|
|
} else if (custom_type == "ExtractFeatures") {
|
|
|
|
|
|
|
|
status = ExtractFeatures(custom_attr, op, tflite_op);
|
|
|
|
|
|
|
|
} else if (custom_type == "AudioSpectrogram") {
|
|
|
|
|
|
|
|
status = AudioSpectrogram(custom_attr, op, tflite_op);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "the custom op hasn't been supported now";
|
|
|
|
|
|
|
|
status = RET_NOT_FIND_OP;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
|
|
|
return status;
|
|
|
|
|
|
|
|
}
|
|
|
|
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
|
|
|
|
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
|
|
|
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
|
|
|
schema::Format::Format_NHWC);
|
|
|
|
schema::Format::Format_NHWC);
|
|
|
@ -82,7 +178,7 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|
|
|
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
|
|
|
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
|
|
|
schema::Format::Format_NHWC);
|
|
|
|
schema::Format::Format_NHWC);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return RET_OK;
|
|
|
|
return status;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser());
|
|
|
|
TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser());
|
|
|
|