add tflite parsers for PRELU and NEG operators

pull/6793/head
AGroupofProbiotocs 5 years ago
parent 685a77c7c3
commit 7786342511

@ -95,8 +95,8 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found.";
return RET_ERROR;
}
auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front()
: out_tensors_.front()->GetQuantParams().front();
auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ? out_tensors_.front()->GetQuantParams().front()
: in_tensors_.front()->GetQuantParams().front();
int ret;
if (uint8_ptr_ == nullptr) {
if (inverse_) {

@ -97,10 +97,6 @@ int ResizeBaseCPUKernel::CheckInputsOuputs() {
MS_LOG(ERROR) << "Resize input num should be no more than" << kMaxInputNum << ", but got " << in_tensors_.size();
return RET_ERROR;
}
auto input = in_tensors_.at(0);
if (input == nullptr) {
return RET_NULL_PTR;
}
if (out_tensors_.size() != kOutputNum) {
MS_LOG(ERROR) << "Resize output num should be " << kOutputNum << ", but got " << out_tensors_.size();
return RET_ERROR;

@ -300,6 +300,15 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
}
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "NEG") == 0) {
MS_LOG(DEBUG) << "parse TfliteNegParser";
auto attr = std::make_unique<schema::NegT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Neg;
op->primitive->value.value = attr.release();
}
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
@ -415,6 +424,7 @@ TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser());
TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser());
TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser());
TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser());
TfliteNodeRegister g_tfliteNegParser("NEG", new TfliteNegParser());
TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser());
TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser());

@ -157,6 +157,11 @@ class TfliteFloorParser : public TfliteSingleInputOpParser {
TfliteFloorParser() : TfliteSingleInputOpParser() {}
};
class TfliteNegParser : public TfliteSingleInputOpParser {
public:
TfliteNegParser() : TfliteSingleInputOpParser() {}
};
class TfliteCompareOpParser : public TfliteNodeParser {
public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}

@ -46,7 +46,9 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR;
}
if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8) {
if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) &&
(GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 ||
GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_prelu_parser.h"
#include <vector>
#include <memory>
#include <map>
namespace mindspore {
namespace lite {
STATUS TflitePReLUParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TflitePReLUParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->channelShared = true;
op->primitive->value.type = schema::PrimitiveType_PReLU;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
TfliteNodeRegister g_tflitePReLUParser("PRELU", new TflitePReLUParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PRELU_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PRELU_PARSER_H
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TflitePReLUParser : public TfliteNodeParser {
public:
TflitePReLUParser() : TfliteNodeParser("PRELU") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PRELU_PARSER_H

@ -46,8 +46,9 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl
MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR;
}
if (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8) {
if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) &&
(GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";

Loading…
Cancel
Save