diff --git a/mindspore/lite/include/errorcode.h b/mindspore/lite/include/errorcode.h index 65519b6f73..edd1d7da52 100644 --- a/mindspore/lite/include/errorcode.h +++ b/mindspore/lite/include/errorcode.h @@ -32,6 +32,7 @@ constexpr int RET_PARAM_INVALID = -3; /**< Invalid parameter.*/ constexpr int RET_NO_CHANGE = -4; /**< No change. */ constexpr int RET_SUCCESS_EXIT = -5; /**< No error but exit. */ constexpr int RET_MEMORY_FAILED = -6; /**< Fail to create memory. */ +constexpr int RET_NOT_SUPPORT = -7; /**< Fail to support. */ /* Executor error code, range: [-101,-200] */ constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to check range. */ @@ -53,6 +54,10 @@ constexpr int RET_FORMAT_ERR = -401; /**< Failed to checking tensor format. */ /* InferShape error code, range: [-501,-600] */ constexpr int RET_INFER_ERR = -501; /**< Failed to infer shape. */ constexpr int RET_INFER_INVALID = -502; /**< Invalid infer shape before runtime. */ + +/* User input param error code, range: [-601, -700]*/ +constexpr int RET_INPUT_PARAM_INVALID = -601; /**< Invalid input param by user. */ +constexpr int RET_INPUT_PARAM_LACK = -602; /**< Lack of input param by user. 
*/ } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 52a6b77962..fa885526e7 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -203,6 +203,8 @@ union PrimitiveType { LogGrad, BatchToSpaceND, LshProjection, + HashtableLookup, + SkipGram, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index b32aebb856..c728751791 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -948,3 +948,12 @@ table BlackBox { table LshProjection { type : LshProjectionType; } + +table HashtableLookup { +} + +table SkipGram { + includeAllGrams : bool; + maxSkipSize : int; + ngramSize : int; +} diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 1940bc3e28..039f78ed0f 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -109,15 +109,15 @@ int RunConverter(int argc, const char **argv) { std::unique_ptr flags(new (std::nothrow) converter::Flags); if (flags == nullptr) { MS_LOG(ERROR) << "new flags error "; + std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << std::endl; return RET_MEMORY_FAILED; } auto status = flags->Init(argc, argv); - if (status == RET_SUCCESS_EXIT) { - return status; - } - if (status != 0) { - MS_LOG(ERROR) << "converter::Flags Init failed: " << status; - std::cout << "CONVERTER::FLAGS INIT FAILED" << std::endl; + if (status != RET_OK) { + if (status != RET_SUCCESS_EXIT) { + MS_LOG(ERROR) << "converter::Flags Init failed: " << status; + std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << std::endl; + } return status; } // Load graph @@ -148,13 +148,14 @@ int RunConverter(int argc, const char **argv) { } break; default: { MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk; - return 1; + std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << 
RET_INPUT_PARAM_INVALID << std::endl; + return RET_INPUT_PARAM_INVALID; } } status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); if (fb_graph == nullptr) { MS_LOG(ERROR) << "Convert model return nullptr"; - std::cout << "CONVERT RESULT: FAILED!" << std::endl; + std::cout << "CONVERT RESULT FAILED:" << status << std::endl; return status; } @@ -164,14 +165,14 @@ int RunConverter(int argc, const char **argv) { status = storage.Save(*fb_graph, flags->outputFile); if (status != 0) { MS_LOG(ERROR) << "Save graph failed"; - std::cout << "SAVE GRAPH FAILED!" << std::endl; - return RET_ERROR; + std::cout << "SAVE GRAPH FAILED:" << status << std::endl; + return status; } delete fb_graph; MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!"; - std::cout << "CONVERT RESULT: SUCCESS!" << std::endl; - return RET_OK; + std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl; + return status; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 833a3ae40b..7095675789 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -55,7 +55,7 @@ int Flags::Init(int argc, const char **argv) { if (err.IsSome()) { std::cerr << err.Get(); std::cerr << this->Usage() << std::endl; - return 1; + return RET_INPUT_PARAM_INVALID; } if (this->help) { @@ -64,21 +64,21 @@ int Flags::Init(int argc, const char **argv) { } if (this->modelFile.empty()) { std::cerr << "INPUT MISSING: model file path is necessary"; - return 1; + return RET_INPUT_PARAM_LACK; } if (this->outputFile.empty()) { std::cerr << "INPUT MISSING: output file path is necessary"; - return 1; + return RET_INPUT_PARAM_LACK; } if (this->outputFile.rfind('/') == this->outputFile.length() - 1) { std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path"; - return 1; + return RET_INPUT_PARAM_INVALID; } if (this->fmkIn.empty()) { std::cerr << "INPUT 
MISSING: fmk is necessary"; - return 1; + return RET_INPUT_PARAM_LACK; } if (this->inputInferenceTypeIn == "FLOAT") { this->inputInferenceType = TypeId::kNumberTypeFloat; @@ -87,7 +87,7 @@ int Flags::Init(int argc, const char **argv) { } else { std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8", this->inputInferenceTypeIn.c_str(); - return 1; + return RET_INPUT_PARAM_INVALID; } if (this->inferenceTypeIn == "FLOAT") { @@ -97,7 +97,7 @@ int Flags::Init(int argc, const char **argv) { } else { std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8", this->inferenceTypeIn.c_str(); - return 1; + return RET_INPUT_PARAM_INVALID; } if (this->fmkIn == "CAFFE") { @@ -110,12 +110,12 @@ int Flags::Init(int argc, const char **argv) { this->fmk = FmkType_ONNX; } else { std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX"; - return 1; + return RET_INPUT_PARAM_INVALID; } if (this->fmk != FmkType_CAFFE && !weightFile.empty()) { std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; - return 1; + return RET_INPUT_PARAM_INVALID; } if (this->quantTypeIn == "AwareTraining") { this->quantType = QuantType_AwareTraining; @@ -127,7 +127,7 @@ int Flags::Init(int argc, const char **argv) { this->quantType = QuantType_QUANT_NONE; } else { std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining"; - return 1; + return RET_INPUT_PARAM_INVALID; } @@ -137,9 +137,9 @@ int Flags::Init(int argc, const char **argv) { this->trainModel = false; } else { std::cerr << "INPUT ILLEGAL: trainModel must be true|false "; - return 1; + return RET_INPUT_PARAM_INVALID; } - return 0; + return RET_OK; } } // namespace converter } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc index 5a67a11abc..7a2a474eff 100644 --- 
a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -176,6 +176,9 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; return RET_ERROR; } + } else if (attr->group != 1) { + MS_LOG(ERROR) << "group conv hasn't supported"; + return RET_NOT_SUPPORT; } else { op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.value = attr.release(); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc index de1f234603..356df10005 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc @@ -78,5 +78,6 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node } OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); +OnnxNodeRegistrar g_onnxLRNxParser("LRN", new OnnxLrnParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc index 654f66c9b9..efa0f1d408 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -42,19 +42,19 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &t MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - - const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions(); - if (tflite_attr == nullptr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; + std::vector dims; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, dims)) { + MS_LOG(ERROR) << "get expand_dims -> dim failed"; + return 
RET_ERROR; } - - attr->dim = -1; - - MS_LOG(ERROR) << "The attr dim is folded by TFLite."; - return RET_ERROR; + attr->dim = dims[0]; + op->primitive->value.type = schema::PrimitiveType_ExpandDims; + op->primitive->value.value = attr.release(); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), + tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), + tflite_tensors.size(), schema::Format::Format_NHWC); } - TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc new file mode 100644 index 0000000000..8e989ae164 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_HashtableLookup; + op->primitive->value.value = attr.release(); + for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), + tflite_tensors.size(), schema::Format::Format_NHWC); + } + for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), + tflite_tensors.size(), schema::Format::Format_NHWC); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h new file mode 100644 index 0000000000..0a37d9f259 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H + +#include +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteHashtableLookupParser : public TfliteNodeParser { + public: + TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, + std::vector *tensors_id, std::vector *tensors_format, + std::map *tensors_id_map) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc index 0502afc269..3e50c489d4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -42,18 +42,43 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_o MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - - const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); - if (tflite_attr == 
nullptr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - - attr->paddingMode = schema::PaddingMode_CONSTANT; - attr->constantValue = 0.0f; - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) { - MS_LOG(ERROR) << "get pad -> paddings failed"; - return RET_ERROR; + std::vector node_name_str; + Split(op->name, &node_name_str, "-"); + const char *node_name = node_name_str.data()->c_str(); + if (std::strcmp(node_name, "Pad") == 0) { + const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + attr->paddingMode = schema::PaddingMode_CONSTANT; + attr->constantValue = 0.0f; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) { + MS_LOG(ERROR) << "get pad -> paddings failed"; + return RET_ERROR; + } + } else if (std::strcmp(node_name, "MirrorPad") == 0) { + const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + switch (tflite_attr->mode) { + case tflite::MirrorPadMode_REFLECT: + attr->paddingMode = schema::PaddingMode_REFLECT; + break; + case tflite::MirrorPadMode_SYMMETRIC: + attr->paddingMode = schema::PaddingMode_SYMMETRIC; + break; + default: + MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; + return RET_INVALID_OP_ATTR; + } + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), + tflite_tensors.size(), schema::Format::Format_NHWC); + } else { + MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; + return RET_NOT_SUPPORT; } op->primitive->value.type = schema::PrimitiveType_Pad; @@ -67,5 +92,6 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_o 
} TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser()); +TfliteNodeRegister g_tfliteMirrorPadParser("MirrorPad", new TflitePadParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc new file mode 100644 index 0000000000..0329520140 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tools/converter/parser/tflite/tflite_skip_gram_parser.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteSkipGramParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, std::vector *tensors_id, + std::vector *tensors_format, std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + attr->includeAllGrams = tflite_attr->include_all_ngrams; + attr->maxSkipSize = tflite_attr->max_skip_size; + attr->ngramSize = tflite_attr->ngram_size; + + op->primitive->value.type = schema::PrimitiveType_SkipGram; + op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), + tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), + tflite_tensors.size(), schema::Format::Format_NHWC); + return RET_OK; +} + +TfliteNodeRegister g_TfliteSkipGramParser("SkipGram", new TfliteSkipGramParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h new file mode 100644 index 0000000000..5ebe6f5846 --- +++ 
b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H + +#include +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSkipGramParser : public TfliteNodeParser { + public: + TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, + std::vector *tensors_id, std::vector *tensors_format, + std::map *tensors_id_map) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index 352003645e..a32ffaaefb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -57,7 +57,7 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_POW, "Pow"}, {tflite::BuiltinOperator_ARG_MIN, "Argmin"}, 
{tflite::BuiltinOperator_CEIL, "Ceil"}, - // {tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"}, + {tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"}, {tflite::BuiltinOperator_FILL, "Fill"}, {tflite::BuiltinOperator_DIV, "Div"}, {tflite::BuiltinOperator_FLOOR, "flOOR"}, @@ -117,6 +117,7 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_UNIQUE, "Unique"}, {tflite::BuiltinOperator_UNPACK, "Unstack"}, {tflite::BuiltinOperator_CUSTOM, "Custom"}, + {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"}, }; std::map tfMsActivationFunctionMap{ diff --git a/mindspore/lite/tools/converter/return_code.h b/mindspore/lite/tools/converter/return_code.h index 9a0fc87e9f..be5df87c76 100644 --- a/mindspore/lite/tools/converter/return_code.h +++ b/mindspore/lite/tools/converter/return_code.h @@ -33,7 +33,7 @@ class ReturnCode { statusCode = status; } } - STATUS GetReturnCode() { + STATUS GetReturnCode() const { return statusCode; } private: diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index e6e8ec12d1..320ccc75a8 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -85,8 +85,8 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { param_value->set_tensor_type(type_id); param_value->set_format(tensor->GetFormat()); if (tensor->MutableData() != nullptr) { - auto size = tensor->ElementsNum(); - auto tensor_data = new (std::nothrow) float[size]; + auto size = tensor->Size(); + auto tensor_data = new (std::nothrow) uint8_t[size]; if (tensor_data == nullptr) { MS_LOG(ERROR) << "tensor_data is nullptr"; return nullptr; @@ -98,7 +98,7 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { return nullptr; } param_value->set_tensor_addr(tensor_data); - param_value->set_tensor_size(size * sizeof(float) / sizeof(uint8_t)); + 
param_value->set_tensor_size(size); } parameter->set_default_param(param_value); return parameter;