!9565 remove unique optype attr

From: @cjh9368
Reviewed-by: @hangangqiang
Signed-off-by:
pull/9565/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 7831c337c2

@ -936,7 +936,7 @@ table ScatterND {
} }
table Unique { table Unique {
outType: int; outType: int; // DEPRECATED
} }
table Unstack { table Unstack {

@ -176,7 +176,7 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
} }
flatbuffers::Verifier verify((const uint8_t *)model_buf, size); flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
int schema_version = VersionVerify(&verify); int schema_version = VersionVerify(&verify);
if (schema_version == -1) { if (schema_version == SCHEMA_INVALID) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
return nullptr; return nullptr;
} }

@ -22,14 +22,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
#ifdef PRIMITIVE_WRITEABLE #ifndef PRIMITIVE_WRITEABLE
int Unique::GetOutType() const { return this->primitive_->value.AsUnique()->outType; }
void Unique::SetOutType(int out_type) { this->primitive_->value.AsUnique()->outType = out_type; }
#else
int Unique::GetOutType() const { return this->primitive_->value_as_Unique()->outType(); }
int Unique::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { int Unique::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb); MS_ASSERT(nullptr != fbb);

@ -32,14 +32,12 @@ class Unique : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Unique, PrimitiveC); MS_DECLARE_PARENT(Unique, PrimitiveC);
explicit Unique(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} explicit Unique(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetOutType(int out_type);
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetOutType() const;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -34,8 +34,6 @@ TEST_F(TestTfliteParserUnique, OpType) {
TEST_F(TestTfliteParserUnique, AttrValue) { TEST_F(TestTfliteParserUnique, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsUnique();
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
ASSERT_EQ(val->outType, 34);
} }
} // namespace mindspore } // namespace mindspore

@ -182,7 +182,7 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract_tensor);
parameter->set_name("graph-input-" + std::to_string(i)); parameter->set_name("graph_input-" + std::to_string(i));
nodes_.insert(std::pair(layer.top(0), parameter)); nodes_.insert(std::pair(layer.top(0), parameter));
return RET_OK; return RET_OK;
} }
@ -205,7 +205,7 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract_tensor);
parameter->set_name("graph-input-" + caffe_model_.input(i)); parameter->set_name("graph_input-" + caffe_model_.input(i));
nodes_.insert(std::pair(caffe_model_.input(i), parameter)); nodes_.insert(std::pair(caffe_model_.input(i), parameter));
} }
} else { } else {
@ -219,7 +219,7 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract_tensor);
parameter->set_name("graph-input-" + caffe_model_.input(i)); parameter->set_name("graph_input-" + caffe_model_.input(i));
nodes_.insert(std::pair(caffe_model_.input(i), parameter)); nodes_.insert(std::pair(caffe_model_.input(i), parameter));
} }
} }

@ -153,6 +153,18 @@ STATUS TfliteModelParser::ConvertOps() {
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
continue; continue;
} }
if (tflite_op_type == tflite::BuiltinOperator_CONV_2D ||
tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D ||
tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
if (op_inputs.size() == 2) {
parameter->set_name(op_name + "/weight");
} else if (op_inputs.size() == 3) {
parameter->set_name(op_name + "/bias");
}
} else {
parameter->set_name(op_name + "/input-" + std::to_string(op_inputs.size() - 1));
}
op_inputs.emplace_back(parameter); op_inputs.emplace_back(parameter);
nodes_.insert(std::pair(input_idx, parameter)); nodes_.insert(std::pair(input_idx, parameter));
} }

@ -40,7 +40,6 @@ PrimitiveC *TfliteUniqueParser::ParseLitePrimitive(const std::unique_ptr<tflite:
MS_LOG(ERROR) << "get op unique attr failed"; MS_LOG(ERROR) << "get op unique attr failed";
return nullptr; return nullptr;
} }
attr->outType = GetTfliteDataType(tflite_attr->idx_out_type);
primitive->value.type = schema::PrimitiveType_Unique; primitive->value.type = schema::PrimitiveType_Unique;
primitive->value.value = attr.release(); primitive->value.value = attr.release();

Loading…
Cancel
Save