!9270 remove attr of sparseToDense op

From: @lyvette
Reviewed-by: 
Signed-off-by:
pull/9270/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 311ee57a21

@ -23,16 +23,13 @@ namespace mindspore {
namespace lite {
OpParameter *PopulateSparseToDenseParameter(const mindspore::lite::PrimitiveC *primitive) {
SparseToDenseParameter *sparse_to_dense_param =
reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter)));
auto *sparse_to_dense_param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter)));
if (sparse_to_dense_param == nullptr) {
MS_LOG(ERROR) << "malloc SparseToDenseParameter failed.";
return nullptr;
}
memset(sparse_to_dense_param, 0, sizeof(SparseToDenseParameter));
sparse_to_dense_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::SparseToDense *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
sparse_to_dense_param->validate_indices_ = param->GetValidateIndices();
return reinterpret_cast<OpParameter *>(sparse_to_dense_param);
}

@ -22,16 +22,7 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool SparseToDense::GetValidateIndices() const { return this->primitive_->value.AsSparseToDense()->validateIndices; }
void SparseToDense::SetValidateIndices(bool validate_indices) {
this->primitive_->value.AsSparseToDense()->validateIndices = validate_indices;
}
#else
bool SparseToDense::GetValidateIndices() const { return this->primitive_->value_as_SparseToDense()->validateIndices(); }
#ifndef PRIMITIVE_WRITEABLE
int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
@ -40,7 +31,7 @@ int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
MS_LOG(ERROR) << "value_as_SparseToDense return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateSparseToDense(*fbb, attr->validateIndices());
auto val_offset = schema::CreateSparseToDense(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SparseToDense, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;

@ -36,14 +36,12 @@ class SparseToDense : public PrimitiveC {
void SetOutputShape(const std::vector<int> &output_shape);
void SetSparseValue(const std::vector<int> &sparse_value);
void SetDefaultValue(const std::vector<int> &default_value);
void SetValidateIndices(bool validate_indices);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
std::vector<int> GetOutputShape() const;
std::vector<int> GetSparseValue() const;
std::vector<int> GetDefaultValue() const;
bool GetValidateIndices() const;
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite

@ -31,10 +31,4 @@ TEST_F(TestTfliteParserSparseToDense, OpType) {
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SparseToDense) << "wrong Op Type";
}
TEST_F(TestTfliteParserSparseToDense, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense();
ASSERT_EQ(val->validateIndices, false);
}
} // namespace mindspore

@ -35,7 +35,6 @@ PrimitiveC *TfliteSparseToDenseParser::ParseLitePrimitive(const std::unique_ptr<
return nullptr;
}
attr->validateIndices = false;
primitive->value.type = schema::PrimitiveType_SparseToDense;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());

Loading…
Cancel
Save