resolve issue

pull/11608/head
cjh9368 4 years ago
parent 659b5d8e10
commit 1cd3b2fe03

@@ -25,6 +25,13 @@
#include "src/param_value_lite.h"
namespace mindspore::lite {
bool IsSkipedLayer(const caffe::LayerParameter &layer) {
if (layer.type() == "Input" || layer.type() == "Dropout" || layer.type() == "Split") {
return true;
}
return layer.include_size() == 1 && layer.include(0).phase() == caffe::TRAIN;
}
CaffeModelParser::CaffeModelParser() = default;
CaffeModelParser::~CaffeModelParser() = default;
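Note: with this change the parser also skips "Split" layers (in addition to "Input" and "Dropout"), and still skips layers included only in the TRAIN phase. A rough, self-contained sketch of the same predicate, using a hypothetical SimpleLayer struct in place of caffe::LayerParameter:

#include <string>
#include <vector>

// Hypothetical stand-in for caffe::LayerParameter (illustration only).
enum class Phase { TRAIN, TEST };
struct SimpleLayer {
  std::string type;
  std::vector<Phase> include;  // mirrors layer.include(i).phase()
};

// Same rule as IsSkipedLayer above: structural layers and TRAIN-only layers are skipped.
bool IsSkippedLayer(const SimpleLayer &layer) {
  if (layer.type == "Input" || layer.type == "Dropout" || layer.type == "Split") {
    return true;
  }
  return layer.include.size() == 1 && layer.include.front() == Phase::TRAIN;
}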
@@ -68,6 +75,11 @@ STATUS CaffeModelParser::ConvertLayers() {
}
for (int i = 0; i < caffe_model_.layer_size(); i++) {
auto layer = caffe_model_.layer(i);
// save caffe layers
for (int top_idx = 0; top_idx < layer.top_size(); top_idx++) {
caffe_layers_[layer.top(top_idx)] = layer;
}
caffe::LayerParameter weight;
if (weight_layers.find(layer.name()) != weight_layers.end()) {
weight = weight_layers.find(layer.name())->second;
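Note: the added loop records, for every top blob a layer produces, which layer produced it; GetOriginLayerName later walks this caffe_layers_ map. A minimal sketch of the same bookkeeping over a simplified layer type (names here are illustrative, not the converter's API):

#include <string>
#include <unordered_map>
#include <vector>

struct SimpleLayer {                // stand-in for caffe::LayerParameter
  std::string name;
  std::string type;
  std::vector<std::string> tops;    // blob names this layer produces
};

// Index every produced blob name to its producing layer, as ConvertLayers() now does.
std::unordered_map<std::string, SimpleLayer> IndexLayersByTop(const std::vector<SimpleLayer> &layers) {
  std::unordered_map<std::string, SimpleLayer> by_top;
  for (const auto &layer : layers) {
    for (const auto &top : layer.tops) {
      by_top[top] = layer;          // a later producer of the same blob name overwrites an earlier one
    }
  }
  return by_top;
}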
@@ -385,11 +397,17 @@ STATUS CaffeModelParser::ConvertBottom(const caffe::LayerParameter &layer, std::
return RET_NULL_PTR;
}
for (int i = 0; i < layer.bottom_size(); i++) {
if (nodes_.find(layer.bottom(i)) == nodes_.end()) {
string origin_layer = GetOriginLayerName(layer.bottom(i));
if (origin_layer.empty()) {
MS_LOG(ERROR) << "layer not found";
return RET_ERROR;
}
if (nodes_.find(origin_layer) == nodes_.end()) {
MS_LOG(ERROR) << "layer bottom " << layer.bottom(i) << " is not found";
return RET_NOT_FIND_OP;
}
input_nodes->emplace_back(nodes_.find(layer.bottom(i))->second);
input_nodes->emplace_back(nodes_.find(origin_layer)->second);
}
return RET_OK;
}
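Note: ConvertBottom now resolves each bottom blob to its original (non-Split) producer before looking it up in nodes_, so inputs that pass through a Split still connect to a converted node. A hedged sketch of that resolution loop with simplified types (Node and the resolve callback are illustrative stand-ins for AnfNodePtr and GetOriginLayerName):

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {};  // stand-in for AnfNodePtr

bool CollectInputs(const std::vector<std::string> &bottoms,
                   const std::unordered_map<std::string, Node> &nodes,
                   const std::function<std::string(const std::string &)> &resolve,
                   std::vector<Node> *inputs) {
  for (const auto &bottom : bottoms) {
    const std::string origin = resolve(bottom);  // maps a Split output back to its real producer
    if (origin.empty()) {
      return false;                              // producer layer not found
    }
    auto it = nodes.find(origin);
    if (it == nodes.end()) {
      return false;                              // bottom has no converted node
    }
    inputs->push_back(it->second);
  }
  return true;
}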
@@ -422,11 +440,22 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN
return RET_OK;
}
bool CaffeModelParser::IsSkipedLayer(const caffe::LayerParameter &layer) {
if (layer.type() == "Input" || layer.type() == "Dropout") {
return true;
std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name) {
if (caffe_layers_.find(layer_name) == caffe_layers_.end()) {
return layer_name;
}
return layer.include_size() == 1 && layer.include(0).phase() == caffe::TRAIN;
auto layer = caffe_layers_.at(layer_name);
if (layer.type() != "Split") {
return layer_name;
}
while (layer.type() == "Split") {
string input_name = layer.bottom(0);
if (caffe_layers_.find(input_name) == caffe_layers_.end()) {
return input_name;
}
layer = caffe_layers_.at(input_name);
}
return layer.name();
}
MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
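Note: GetOriginLayerName follows a chain of Split layers back to the first non-Split producer (or to a graph input), which is what lets ConvertBottom above find the node that was actually converted. A compact sketch of the same walk over a map from top-blob name to producing layer (simplified types, illustrative names):

#include <string>
#include <unordered_map>

struct SimpleLayer {
  std::string name;
  std::string type;
  std::string bottom0;  // first input blob; enough for the Split walk
};

std::string ResolveOrigin(const std::string &blob_name,
                          const std::unordered_map<std::string, SimpleLayer> &by_top) {
  auto it = by_top.find(blob_name);
  if (it == by_top.end() || it->second.type != "Split") {
    return blob_name;                    // not produced by a Split: use the blob name as-is
  }
  SimpleLayer layer = it->second;
  while (layer.type == "Split") {        // hop over consecutive Split layers
    const std::string input = layer.bottom0;
    auto parent = by_top.find(input);
    if (parent == by_top.end()) {
      return input;                      // reached a graph input
    }
    layer = parent->second;
  }
  return layer.name;                     // name of the first non-Split producer
}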

@@ -55,10 +55,11 @@ class CaffeModelParser : public ModelParser {
STATUS ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode);
bool IsSkipedLayer(const caffe::LayerParameter &layer);
std::string GetOriginLayerName(const std::string &layer_name);
caffe::NetParameter caffe_model_;
caffe::NetParameter caffe_weight_;
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
std::unordered_map<std::string, AnfNodePtr> nodes_;
FuncGraphPtr func_graph_ptr_;
};

@@ -50,11 +50,9 @@ PrimitiveC *CaffeReduceParser::ParseLitePrimitive(const caffe::LayerParameter &p
std::vector<int32_t> axes;
if (reduce_param.has_axis()) {
axes.push_back(1);
axes.push_back(reduce_param.axis());
axes = std::vector<int>(1, reduce_param.axis());
} else {
axes.push_back(1);
axes.push_back(0);
axes = std::vector<int>(1, 0);
}
attr->axes = axes;
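Note: the old code pushed a stray leading 1 before the real axis, so axes ended up as {1, axis} (or {1, 0}); the new code builds a one-element vector. A minimal sketch of the corrected behaviour (the helper name is illustrative):

#include <cstdint>
#include <vector>

// After the fix, axes holds exactly one entry: the declared axis, or 0 by default.
std::vector<int32_t> MakeReduceAxes(bool has_axis, int32_t axis) {
  return has_axis ? std::vector<int32_t>(1, axis)
                  : std::vector<int32_t>(1, 0);
}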

@@ -78,18 +78,21 @@ lite::PrimitiveC *OnnxPReluParser::ParseLitePrimitive(const onnx::GraphProto &on
MS_LOG(ERROR) << "input error: params[0] is null";
return nullptr;
}
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
if (slope_size == 1) {
attr->slope.push_back(*slope_raw_data);
attr->channelShared = true;
if (slope->float_data_size() > 0) {
const int64_t slope_size = slope->float_data_size();
for (int64_t i = 0; i < slope_size; i++) {
attr->slope.emplace_back(slope->float_data(i));
}
attr->channelShared = slope_size == 1;
} else {
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
attr->slope.resize(slope_size);
attr->channelShared = false;
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return nullptr;
}
attr->channelShared = slope_size == 1;
}
} else {
MS_LOG(WARNING) << "The slope of prelu is null, which may cause errors.";
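Note: the PRelu parser now reads slope values from the tensor's typed float_data field when it is populated and only falls back to reinterpreting raw_data, and channelShared is set whenever exactly one slope value is found. A simplified sketch of that decision, using plain memcpy in place of the bounds-checked memcpy_s and a stand-in struct for onnx::TensorProto:

#include <cstdint>
#include <cstring>
#include <vector>

struct SlopeTensor {                 // stand-in for onnx::TensorProto
  std::vector<float> float_data;     // typed storage
  std::vector<uint8_t> raw_data;     // packed little-endian floats
};

void ReadSlope(const SlopeTensor &src, std::vector<float> *slope, bool *channel_shared) {
  if (!src.float_data.empty()) {
    *slope = src.float_data;                       // typed field wins when present
  } else {
    const std::size_t count = src.raw_data.size() / sizeof(float);
    slope->resize(count);
    if (count > 0) {
      std::memcpy(slope->data(), src.raw_data.data(), count * sizeof(float));
    }
  }
  *channel_shared = (slope->size() == 1);          // a single slope is shared across channels
}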

@@ -36,7 +36,10 @@ PrimitiveC *TfliteFillParser::ParseLitePrimitive(const std::unique_ptr<tflite::O
}
if (tflite_op->inputs.size() > 1) {
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) {
const auto &tflite_model_buffers = tflite_model->buffers;
const auto &data = tflite_model_buffers.at(tflite_op->inputs[1])->data;
if (!data.empty() &&
GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) {
MS_LOG(ERROR) << "get fill -> dims failed";
return nullptr;
}
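Note: the fill parser now checks that the buffer backing the second input actually carries data before trying to read the dims from it, so a dynamically shaped fill no longer trips the error path. A hedged sketch of that guard (hypothetical helper, assuming the shape tensor stores int32 values):

#include <cstdint>
#include <cstring>
#include <vector>

// Read constant fill dims only when the backing buffer is non-empty;
// an empty buffer means the shape is provided at runtime and is not an error.
void ReadFillDimsIfConst(const std::vector<uint8_t> &buffer_data, std::vector<int32_t> *dims) {
  if (buffer_data.empty()) {
    return;                           // dynamic shape: nothing to copy
  }
  const std::size_t count = buffer_data.size() / sizeof(int32_t);
  dims->resize(count);
  std::memcpy(dims->data(), buffer_data.data(), count * sizeof(int32_t));
}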

@@ -81,6 +81,20 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
return func_graph_;
}
std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) {
std::string tensor_name = op_name + "/input-" + std::to_string(index);
if (op_type == tflite::BuiltinOperator_CONV_2D || op_type == tflite::BuiltinOperator_TRANSPOSE_CONV ||
op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D || op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
if (index == 1) {
tensor_name = op_name + "/weight";
}
if (index == 2) {
tensor_name = op_name + "/bias";
}
}
return tensor_name;
}
STATUS TfliteModelParser::ConvertOps() {
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
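Note: GetTensorName centralises the naming convention for unnamed constant inputs: index 1 of a conv/deconv/depthwise-conv/fully-connected op is its weight, index 2 its bias, and anything else a numbered input. A small usage-style sketch with a simplified op-kind enum (illustrative only):

#include <cstddef>
#include <iostream>
#include <string>

enum class OpKind { CONV_LIKE, OTHER };  // stand-in for tflite::BuiltinOperator

std::string NameInput(std::size_t index, OpKind kind, const std::string &op_name) {
  if (kind == OpKind::CONV_LIKE) {
    if (index == 1) return op_name + "/weight";
    if (index == 2) return op_name + "/bias";
  }
  return op_name + "/input-" + std::to_string(index);
}

int main() {
  std::cout << NameInput(1, OpKind::CONV_LIKE, "conv1") << "\n";  // conv1/weight
  std::cout << NameInput(2, OpKind::CONV_LIKE, "conv1") << "\n";  // conv1/bias
  std::cout << NameInput(0, OpKind::OTHER, "add3") << "\n";       // add3/input-0
}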
@@ -136,18 +150,7 @@ STATUS TfliteModelParser::ConvertOps() {
if (!input_tensor->name.empty()) {
tensor_name = input_tensor->name;
} else {
tensor_name = op_name + "/input-" + std::to_string(op_inputs.size());
if (tflite_op_type == tflite::BuiltinOperator_CONV_2D ||
tflite_op_type == tflite::BuiltinOperator_TRANSPOSE_CONV ||
tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D ||
tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
if (i == 1) {
tensor_name = op_name + "/weight";
}
if (i == 2) {
tensor_name = op_name + "/bias";
}
}
tensor_name = GetTensorName(i, tflite_op_type, op_name);
}
auto parameter = func_graph_->add_parameter();
status = ConvertConstTensor(input_tensor.get(), parameter.get(), tensor_name);
@@ -155,18 +158,7 @@ STATUS TfliteModelParser::ConvertOps() {
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
continue;
}
if (tflite_op_type == tflite::BuiltinOperator_CONV_2D ||
tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D ||
tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
if (op_inputs.size() == 2) {
parameter->set_name(op_name + "/weight");
} else if (op_inputs.size() == 3) {
parameter->set_name(op_name + "/bias");
}
} else {
parameter->set_name(op_name + "/input-" + std::to_string(op_inputs.size() - 1));
}
parameter->set_name(tensor_name);
op_inputs.emplace_back(parameter);
nodes_.insert(std::pair(input_idx, parameter));
}
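Note: with the tensor name computed once via GetTensorName, the parameter node for an unnamed const input is now given that same tensor_name instead of re-deriving "/weight" and "/bias" from op_inputs.size(), which keeps the two code paths from drifting apart. A tiny sketch of the pattern (Parameter is an illustrative stand-in for the ParameterPtr node):

#include <string>

struct Parameter {
  std::string name;
  void set_name(const std::string &n) { name = n; }
};

// Compute the name once, then reuse it for the parameter node.
void AttachConstInput(Parameter *parameter, const std::string &tensor_name) {
  parameter->set_name(tensor_name);   // single source of truth for the name
}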
@@ -364,7 +356,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
MS_LOG(ERROR) << "parameter is null, get const tensor failed.";
return RET_NULL_PTR;
}
const auto &tfliteModelBuffers = tflite_model_->buffers;
const auto &tflite_model_buffers = tflite_model_->buffers;
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
@@ -378,7 +370,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
param_value->set_tensor_shape(tensor->shape);
param_value->set_tensor_type(GetTfliteDataType(tensor->type));
param_value->set_format(schema::Format::Format_NHWC);
const auto &data = tfliteModelBuffers.at(tensor->buffer)->data;
const auto &data = tflite_model_buffers.at(tensor->buffer)->data;
if (!data.empty()) {
auto size = data.size();
char *tensor_data = new (std::nothrow) char[size];
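Note: the trailing context shows the copy pattern used for constant tensors: the flatbuffer-backed byte vector is duplicated into a heap block owned by the param value, with nothrow new so an allocation failure can be reported instead of thrown. A hedged sketch of that pattern (hypothetical helper, ownership handling simplified):

#include <cstdint>
#include <cstring>
#include <new>
#include <vector>

// Duplicate a (possibly empty) buffer into a freshly allocated block.
// Returns nullptr both for an empty buffer and for an allocation failure.
char *CopyTensorData(const std::vector<uint8_t> &data) {
  if (data.empty()) {
    return nullptr;                    // weightless tensor: nothing to copy
  }
  char *tensor_data = new (std::nothrow) char[data.size()];
  if (tensor_data == nullptr) {
    return nullptr;                    // allocation failed
  }
  std::memcpy(tensor_data, data.data(), data.size());
  return tensor_data;
}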
