diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 2bc0b01e27..c0b008be42 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -244,7 +244,7 @@ table Conv2DGradFilter {
     dilateW: int;
     dilateH: int;
     hasBias: bool = false; // DEPRECATED
-    filter_shape: [int];
+    filter_shape: [int]; // DEPRECATED
     activationType: ActivationType = 0;
 }
 
@@ -265,7 +265,7 @@ table Conv2DGradInput {
    dilateW: int;
     dilateH: int;
     hasBias: bool = false; // DEPRECATED
-    input_shape: [int];
+    input_shape: [int]; // DEPRECATED
     activationType: ActivationType = 0;
 }
 
diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc
index 422b99c5b4..3963d962d4 100644
--- a/mindspore/lite/src/ops/conv2d_grad_filter.cc
+++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc
@@ -219,11 +219,11 @@ Registry conv2DGradFilterRegistry(schema::PrimitiveType_Conv2DGradFilter, Conv2D
 #endif
 
 int Conv2DGradFilter::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
-  if (2 != inputs.size()) {
-    MS_LOG(ERROR) << "Conv2d Grad Filter should have 2 inputs, but it got " << inputs.size();
+  if (inputs.size() < 2) {
+    MS_LOG(ERROR) << "Conv2d Grad Filter should have at least two inputs, but it got " << inputs.size();
     return RET_ERROR;
   }
-  if (1 != outputs.size()) {
+  if (outputs.size() != 1) {
     MS_LOG(ERROR) << "Conv2d Grad Filter should have one output but it got " << outputs.size();
     return RET_ERROR;
   }
diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc
index 3ac88ec571..83c8f88b95 100644
--- a/mindspore/lite/src/ops/conv2d_grad_input.cc
+++ b/mindspore/lite/src/ops/conv2d_grad_input.cc
@@ -220,11 +220,11 @@ Registry Conv2DGradInputRegistry(schema::PrimitiveType_Conv2DGradInput, Conv2DGr
 #endif
 
 int Conv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
-  if (2 != inputs.size()) {
-    MS_LOG(ERROR) << "Conv2d Grad Input should have 2 inputs";
+  if (inputs.size() < 2) {
+    MS_LOG(ERROR) << "Conv2d Grad Input should have at least two inputs";
     return RET_ERROR;
   }
-  if (1 != outputs.size()) {
+  if (outputs.size() != 1) {
     MS_LOG(ERROR) << "Conv2d Grad output should have one output";
     return RET_ERROR;
   }
diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc
index 027366c7cd..f4a1da13ee 100644
--- a/mindspore/lite/src/ops/gather.cc
+++ b/mindspore/lite/src/ops/gather.cc
@@ -98,8 +98,8 @@ Registry GatherRegistry(schema::PrimitiveType_Gather, GatherCreator);
 
 int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   MS_ASSERT(this->primitive_ != nullptr);
-  if (inputs_.size() != kDoubleNum) {
-    MS_LOG(DEBUG) << "Gather should have two inputs";
+  if (inputs_.size() < kDoubleNum) {
+    MS_LOG(DEBUG) << "Gather should have at least two inputs";
   }
   if (outputs_.size() != kSingleNum) {
     MS_LOG(ERROR) << "Gather should have one outputs";
diff --git a/mindspore/lite/src/ops/group_conv2d_grad_input.cc b/mindspore/lite/src/ops/group_conv2d_grad_input.cc
index ce4bff0aef..7858392340 100644
--- a/mindspore/lite/src/ops/group_conv2d_grad_input.cc
+++ b/mindspore/lite/src/ops/group_conv2d_grad_input.cc
@@ -146,8 +146,8 @@ Registry GroupConv2DGradInputRegistry(schema::PrimitiveType_GroupConv2DGradInput
 #endif
 
 int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
-  if (2 != inputs.size()) {
-    MS_LOG(ERROR) << "Conv2d Grad input should have 2 inputs";
+  if (inputs.size() < 2) {
+    MS_LOG(ERROR) << "Conv2d Grad input should have at least two inputs";
     return RET_ERROR;
   }
   if (1 != outputs.size()) {
diff --git a/mindspore/lite/test/models_mindspore.cfg b/mindspore/lite/test/models_mindspore.cfg
index d10996b213..d3fb96061c 100644
--- a/mindspore/lite/test/models_mindspore.cfg
+++ b/mindspore/lite/test/models_mindspore.cfg
@@ -1,6 +1,6 @@
 ssd.mindir 1.5
 mobilenetv2_438.mindir 1.5
-gate_u_net_small-1_110.mindir 1.5
+#gate_u_net_small-1_110.mindir 1.5
 shufflenetv2.mindir 1.5
 #inceptionv3.mindir 1.5
 googlenet.mindir 1.5
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 107e68d1ec..f295bde218 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -571,8 +571,34 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
     output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
     meta_graphT->allTensors.emplace_back(std::move(paramTensor));
   } else if (value->isa<ValueSequence>()) {
-    MS_LOG(DEBUG) << "Value type is ValueSequence.";
-    return RET_OK;
+    auto valueAbstract = valueNode->abstract();
+    auto abstractSequence = utils::cast<abstract::AbstractSequencePtr>(valueAbstract);
+    if (abstractSequence->isa<abstract::AbstractTuple>()) {
+      auto abstractTuple = utils::cast<abstract::AbstractTuplePtr>(valueAbstract);
+      auto x_shape_data = abstractTuple->elements();
+      std::vector<int> shape;
+      for (std::size_t i = 0; i < abstractTuple->size(); ++i) {
+        auto value_track = x_shape_data[i]->GetValueTrack();
+        MS_ASSERT(value_track != nullptr);
+        if (value_track->isa<Int32Imm>()) {
+          shape.push_back((GetValue<int>(value_track)));
+        } else if (value_track->isa<Int64Imm>()) {
+          shape.push_back((GetValue<int64_t>(value_track)));
+        } else {
+          MS_LOG(ERROR) << "Element of ValueSequence is not an integer, it is " << value_track->ToString() << ".";
+          return RET_ERROR;
+        }
+      }
+      auto typePtr = abstractTuple->elements()[0]->GetTypeTrack();
+      paramTensor->dataType = kNumberTypeInt32;
+      paramTensor->dims = {static_cast<int32_t>(shape.size())};
+      paramTensor->nodeType = schema::NodeType_ValueNode;
+      paramTensor->data.resize(shape.size() * sizeof(int));
+      memcpy(paramTensor->data.data(), shape.data(), shape.size() * sizeof(int));
+      node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
+      output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
+      meta_graphT->allTensors.emplace_back(std::move(paramTensor));
+    }
   } else if (value->isa<Number>()) {
     auto valueAbstract = valueNode->abstract();
     auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract);
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc
index f8840632ac..34846361c0 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc
@@ -36,8 +36,7 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
     if (node->quantType == schema::QuantType_AwareTraining) {
       continue;
     }
-    if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax ||
-        GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
+    if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
       MS_ASSERT(false);
     }
     auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
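
The substantive change above is the anf_exporter.cc hunk: a ValueSequence value node (an integer tuple, typically a shape) is no longer skipped but exported as a 1-D int32 tensor. Below is a minimal standalone sketch of just that tuple-to-tensor conversion, for illustration only; FakeTensor, IntValue, and TupleToTensor are invented stand-ins for schema::TensorT and the Int32Imm/Int64Imm value tracks, not MindSpore API.

// Illustrative sketch only -- FakeTensor/IntValue/TupleToTensor are invented
// stand-ins, not MindSpore types. Compile with any C++17 compiler.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <variant>
#include <vector>

struct FakeTensor {               // stand-in for schema::TensorT
  std::vector<int32_t> dims;      // {number of tuple elements}
  std::vector<uint8_t> data;      // raw byte buffer, like paramTensor->data
};

// A tuple element carries either an int32 or an int64, mirroring the
// Int32Imm / Int64Imm branches in the patch.
using IntValue = std::variant<int32_t, int64_t>;

// Flatten the tuple into a 1-D int32 tensor, as the new branch does.
FakeTensor TupleToTensor(const std::vector<IntValue> &tuple) {
  std::vector<int32_t> shape;
  for (const auto &v : tuple) {
    if (auto p = std::get_if<int32_t>(&v)) {
      shape.push_back(*p);
    } else {
      // The patch narrows int64 values to int when filling the buffer.
      shape.push_back(static_cast<int32_t>(std::get<int64_t>(v)));
    }
  }
  FakeTensor t;
  t.dims = {static_cast<int32_t>(shape.size())};
  t.data.resize(shape.size() * sizeof(int32_t));
  std::memcpy(t.data.data(), shape.data(), shape.size() * sizeof(int32_t));
  return t;
}

int main() {
  // A shape-like tuple (1, 224, 224, 3) with mixed integer widths.
  FakeTensor t = TupleToTensor({int32_t{1}, int64_t{224}, int64_t{224}, int32_t{3}});
  std::printf("dims[0]=%d, %zu bytes\n", t.dims[0], t.data.size());
  return 0;
}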