!9498 [lite]adjust onnx upsample parser

From: @xu_anyue
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
pull/9498/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ab70bffb3a

@ -352,7 +352,7 @@ table FakeQuantWithMinMaxVars {
}
table BiasAdd {
axis: [int];
axis: [int]; // DEPRECATED
}
table ROIPooling {

@ -24,10 +24,6 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BiasAdd::GetAxis() const { return this->primitive_->value.AsBiasAdd()->axis; }
void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; }
int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
@ -67,21 +63,11 @@ int BiasAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
MS_LOG(ERROR) << "value_as_BiasAdd return nullptr";
return RET_ERROR;
}
std::vector<int32_t> axis;
if (attr->axis() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
axis.push_back(attr->axis()->data()[i]);
}
}
auto val_offset = schema::CreateBiasAddDirect(*fbb, &axis);
auto val_offset = schema::CreateBiasAddDirect(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasAdd, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> BiasAdd::GetAxis() const {
auto fb_vector = this->primitive_->value_as_BiasAdd()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<BiasAdd>(primitive); }
Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator);

@ -33,11 +33,9 @@ class BiasAdd : public PrimitiveC {
MS_DECLARE_PARENT(BiasAdd, PrimitiveC);
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAxis(const std::vector<int> &axis);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
std::vector<int> GetAxis() const;
};
} // namespace lite
} // namespace mindspore

@ -38,22 +38,17 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
return RET_NULL_PTR;
}
attr->method = schema::ResizeMethod_NEAREST;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "mode") {
if ("nearest" == onnx_node_attr.s()) {
attr->method = schema::ResizeMethod_NEAREST;
} else if ("bilinear" == onnx_node_attr.s()) {
attr->method = schema::ResizeMethod_LINEAR;
} else {
MS_LOG(ERROR) << "Resize do not support upsample mode";
return RET_ERROR;
if (onnx_node_attr.s() != "nearest" && onnx_node_attr.s() != "linear") {
MS_LOG(ERROR) << "the upsample mode don't support now.";
return RET_NOT_SUPPORT;
}
attr->method = onnx_node_attr.s() == "nearest" ? schema::ResizeMethod_NEAREST : schema::ResizeMethod_LINEAR;
}
}
attr->newWidth = 1;
attr->newHeight = 1;
attr->alignCorners = false;
op->primitive->value.type = schema::PrimitiveType_Resize;
op->primitive->value.value = attr.release();
return RET_OK;

Loading…
Cancel
Save