From c1ce164e42f368887886736ccc3b0027ed27b2e5 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Tue, 17 Nov 2020 14:53:34 +0800 Subject: [PATCH] adjust mindir model attr changes --- mindspore/lite/src/ops/bias_add.cc | 2 +- mindspore/lite/src/ops/bias_grad.cc | 2 +- mindspore/lite/src/ops/concat.cc | 2 +- mindspore/lite/src/ops/conv2d.cc | 22 +++++++------- mindspore/lite/src/ops/conv2d_grad_filter.cc | 12 ++++---- mindspore/lite/src/ops/conv2d_grad_input.cc | 12 ++++---- mindspore/lite/src/ops/deconv2d.cc | 22 +++++++------- mindspore/lite/src/ops/depthwise_conv2d.cc | 12 ++++---- mindspore/lite/src/ops/expand_dims.cc | 2 +- mindspore/lite/src/ops/gather.cc | 2 +- mindspore/lite/src/ops/one_hot.cc | 2 +- mindspore/lite/src/ops/pooling.cc | 4 +-- mindspore/lite/src/ops/pooling_grad.cc | 4 +-- mindspore/lite/src/ops/primitive_c.cc | 29 +++++++++++++++++++ mindspore/lite/src/ops/primitive_c.h | 2 ++ mindspore/lite/src/ops/reduce.cc | 2 +- mindspore/lite/src/ops/reshape.cc | 2 +- mindspore/lite/src/ops/resize.cc | 2 +- mindspore/lite/src/ops/softmax.cc | 2 +- mindspore/lite/src/ops/squeeze.cc | 2 +- mindspore/lite/src/ops/strided_slice.cc | 10 +++---- mindspore/lite/src/ops/tile.cc | 4 +-- .../lite/src/ops/unsorted_segment_sum.cc | 2 +- .../lite/tools/anf_exporter/anf_exporter.cc | 4 ++- 24 files changed, 97 insertions(+), 64 deletions(-) diff --git a/mindspore/lite/src/ops/bias_add.cc b/mindspore/lite/src/ops/bias_add.cc index 849738cbb3..a8863264cb 100644 --- a/mindspore/lite/src/ops/bias_add.cc +++ b/mindspore/lite/src/ops/bias_add.cc @@ -51,7 +51,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector &in MS_LOG(INFO) << "BiasAdd's attr axis is set to default"; attr->axis = {1}; } else { - attr->axis = GetValue>(prim.GetAttr("axis")); + attr->axis = CastToInt(prim.GetAttr("axis"), true); } this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc index 27cc25bca0..35ae4903e8 100644 --- a/mindspore/lite/src/ops/bias_grad.cc +++ b/mindspore/lite/src/ops/bias_grad.cc @@ -49,7 +49,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector &i MS_LOG(WARNING) << "get axis failed"; attr->axis = {0}; } else { - attr->axis = GetValue>(prim.GetAttr("axis")); + attr->axis = CastToInt(prim.GetAttr("axis"), true); } this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index 898adca7d8..aa739c61e7 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -51,7 +51,7 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector &inp MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } - auto prim_axis = GetValue(prim.GetAttr("axis")); + auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front(); attr->axis = prim_axis; this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index 3e663955b9..d28618cee9 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -139,21 +139,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT } else { attr->format = schema::Format::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(prim.GetAttr("pad_list")); + auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(prim.GetAttr("dilation")); + auto dilation = CastToInt(prim.GetAttr("dilation"), true); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(prim.GetAttr("stride")); + auto stride = CastToInt(prim.GetAttr("stride"), true); attr->strideH = stride[2]; attr->strideW = stride[3]; @@ -175,7 +175,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT int channel_mutiplier = 1; if (prim.GetAttr("channel_mutiplier") != nullptr) { - channel_mutiplier = GetValue(prim.GetAttr("channel_multiplier")); + channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); } attr->channelMultiplier = channel_mutiplier; @@ -212,25 +212,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive } else { attr->format = schema::Format::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(prim.GetAttr("pad_list")); + auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(prim.GetAttr("dilation")); + auto dilation = CastToInt(prim.GetAttr("dilation"), true); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(prim.GetAttr("stride")); + auto stride = CastToInt(prim.GetAttr("stride"), true); attr->strideH = stride[2]; attr->strideW = stride[3]; - attr->channelOut = GetValue(prim.GetAttr("out_channel")); + attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); auto pad_mode = GetValue(prim.GetAttr("pad_mode")); if (pad_mode == "valid") { @@ -270,7 +270,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector &inp MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model"; return RET_NULL_PTR; } - int group = GetValue(groupAttr); + int group = CastToInt(groupAttr, false).front(); if (group > 1) { PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); } else { diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc index 2199c9f0b4..aa403d2e45 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.cc +++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc @@ -94,7 +94,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vectorgroup = GetValue(prim.GetAttr("group")); + attr->group = CastToInt(prim.GetAttr("group"), false).front(); auto format = GetValue(prim.GetAttr("data_format")); if (format == "NCHW") { attr->format = schema::Format_NCHW; @@ -103,25 +103,25 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vectorformat = schema::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(prim.GetAttr("pad_list")); + auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(prim.GetAttr("dilation")); + auto dilation = CastToInt(prim.GetAttr("dilation"), true); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(prim.GetAttr("stride")); + auto stride = CastToInt(prim.GetAttr("stride"), true); attr->strideH = stride[0]; attr->strideW = stride[1]; - attr->channelOut = GetValue(prim.GetAttr("out_channel")); + attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); auto pad_mode = GetValue(prim.GetAttr("pad_mode")); if (pad_mode == "valid") { attr->padMode = schema::PadMode_VALID; diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc index 37889e5bb7..26a85610c4 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/conv2d_grad_input.cc @@ -92,7 +92,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vectorgroup = GetValue(prim.GetAttr("group")); + attr->group = CastToInt(prim.GetAttr("group"), false).front(); if (attr->group > 1) { this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput; } @@ -104,25 +104,25 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vectorformat = schema::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(prim.GetAttr("pad_list")); + auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(prim.GetAttr("dilation")); + auto dilation = CastToInt(prim.GetAttr("dilation"), true); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(prim.GetAttr("stride")); + auto stride = CastToInt(prim.GetAttr("stride"), true); attr->strideH = stride[0]; attr->strideW = stride[1]; - attr->channelOut = GetValue(prim.GetAttr("out_channel")); + attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); auto pad_mode = GetValue(prim.GetAttr("pad_mode")); if (pad_mode == "valid") { diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 641372f93a..41300e6b33 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -132,21 +132,21 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv } else { attr->format = schema::Format::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(prim.GetAttr("pad_list")); + auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(prim.GetAttr("dilation")); + auto dilation = CastToInt(prim.GetAttr("dilation"), true); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(prim.GetAttr("stride")); + auto stride = CastToInt(prim.GetAttr("stride"), true); attr->strideH = stride[0]; attr->strideW = stride[1]; @@ -168,7 +168,7 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv int channel_mutiplier = 1; if (prim.GetAttr("channel_mutiplier") != nullptr) { - channel_mutiplier = GetValue(prim.GetAttr("channel_multiplier")); + channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); } attr->channelMultiplier = channel_mutiplier; @@ -195,25 +195,25 @@ void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::Primi } else { attr->format = schema::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(prim.GetAttr("pad_list")); + auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(prim.GetAttr("dilation")); + auto dilation = CastToInt(prim.GetAttr("dilation"), true); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = GetValue>(prim.GetAttr("stride")); + auto stride = CastToInt(prim.GetAttr("stride"), true); attr->strideH = stride[0]; attr->strideW = stride[1]; - attr->channelOut = GetValue(prim.GetAttr("out_channel")); + attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); auto pad_mode = GetValue(prim.GetAttr("pad_mode")); if (pad_mode == "valid" || pad_mode == "VALID") { @@ -248,7 +248,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector &i MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; return RET_ERROR; } - int group = GetValue(prim.GetAttr("group")); + int group = CastToInt(prim.GetAttr("group"), false).front(); if (group == 1) { PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); } else if (group > 1) { diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index f54bd3141f..4bc0765bb1 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -86,27 +86,27 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vectorformat = schema::Format::Format_NUM_OF_FORMAT; } - auto pad_list = GetValue>(prim.GetAttr("pads")); + auto pad_list = CastToInt(prim.GetAttr("pads"), true); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = GetValue>(prim.GetAttr("dilation")); + auto dilation = CastToInt(prim.GetAttr("dilation"), true); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; if (utils::isa(prim.GetAttr("kernel_size"))) { - auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; } else { - auto kernel_size = GetValue(prim.GetAttr("kernel_size")); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), false).front(); attr->kernelH = kernel_size; attr->kernelW = kernel_size; } - auto stride = GetValue>(prim.GetAttr("stride")); + auto stride = CastToInt(prim.GetAttr("stride"), true); attr->strideH = stride[2]; attr->strideW = stride[3]; @@ -124,7 +124,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vectoractivationType = schema::ActivationType_NO_ACTIVATION; } - auto channel_multiplier = GetValue(prim.GetAttr("channel_multiplier")); + auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); attr->channelMultiplier = channel_multiplier; MS_ASSERT(inputs.size() == kAnfPopulaterTwo); diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index ee8bc767cc..d15438ceef 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -53,7 +53,7 @@ int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector // use axis instead of dim if (inputs[1]->isa()) { auto axis_tensor = inputs[1]->cast(); - int axis = GetValue(axis_tensor->value()); + int axis = CastToInt(axis_tensor->value(), false).front(); attr->dim = axis; } else { MS_LOG(ERROR) << "input axis is not value node."; diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index 16a966099a..5777a81ba0 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -59,7 +59,7 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector &inp } if (inputs[2]->isa()) { ValueNodePtr axis_tensor = inputs[2]->cast(); - int axis = GetValue(axis_tensor->value()); + int axis = CastToInt(axis_tensor->value(), false).front(); gather_attr->axis = axis; } else { MS_LOG(ERROR) << "input axis is not value node."; diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index 93f1857bcd..e979897647 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -48,7 +48,7 @@ int OneHot::UnPackAttr(const Primitive &prim, const std::vector &inp } attr->axis = -1; if (prim.GetAttr("axis") != nullptr) { - attr->axis = GetValue(prim.GetAttr("axis")); + attr->axis = CastToInt(prim.GetAttr("axis"), false).front(); } this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index 3e9ceb1b40..e5ea31cc1d 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -110,11 +110,11 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector &in attr->padMode = schema::PadMode_NOTSET; } - auto kernel_size = GetValue>(prim.GetAttr("ksize")); + auto kernel_size = CastToInt(prim.GetAttr("ksize"), true); attr->windowH = kernel_size[2]; attr->windowW = kernel_size[3]; - auto stride = GetValue>(prim.GetAttr("strides")); + auto stride = CastToInt(prim.GetAttr("strides"), true); attr->strideH = stride[2]; attr->strideW = stride[3]; this->primitive_->value.value = attr; diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc index 9d2e633d2b..9a2dd6a853 100644 --- a/mindspore/lite/src/ops/pooling_grad.cc +++ b/mindspore/lite/src/ops/pooling_grad.cc @@ -99,11 +99,11 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector attr->padMode = schema::PadMode_NOTSET; } - auto kernel_size = GetValue>(prim.GetAttr("ksize")); + auto kernel_size = CastToInt(prim.GetAttr("ksize"), true); attr->windowH = kernel_size[2]; attr->windowW = kernel_size[3]; - auto stride = GetValue>(prim.GetAttr("strides")); + auto stride = CastToInt(prim.GetAttr("strides"), true); attr->strideH = stride[2]; attr->strideW = stride[3]; this->primitive_->value.value = attr; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 5cef1e50d0..15d7c2c541 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -180,6 +180,35 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE +std::vector CastToInt(const ValuePtr value, bool is_vector) { + if (value == nullptr) { + MS_LOG(WARNING) << "valueptr is nullptr."; + return {}; + } + std::vector cur_value; + if (is_vector) { + if (!utils::isa(value)) { + MS_LOG(WARNING) << "valueptr is not a sequence, value may be a scalar."; + return {}; + } + if (value->cast()->value().front()->type()->type_name() == "Int64Imm") { + auto origin_value = GetValue>(value); + for (size_t index = 0; index < origin_value.size(); ++index) { + cur_value.push_back(static_cast(origin_value[index])); + } + } else { + cur_value = GetValue>(value); + } + } else { + if (value->type_name() == "Int64Imm") { + cur_value.push_back(static_cast(GetValue(value))); + } else { + cur_value.push_back(GetValue(value)); + } + } + return cur_value; +} + void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) { const float qmin = 0; const float qmax = 255; diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 542157d231..5efaba4ab3 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -52,6 +52,8 @@ static std::map kActivationTypeMap{{"ReLU", {"Sigmoid", schema::ActivationType_SIGMOID}, {"HSwish", schema::ActivationType_HSWISH}, {"HSigmoid", schema::ActivationType_HSIGMOID}}; +std::vector CastToInt(const ValuePtr value, bool is_vector); + class PrimitiveC : public mindspore::Primitive { public: // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 03c5eea50b..205092f20a 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -87,7 +87,7 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector &inp attr->axes.emplace_back(elem->value()); } } else { - int axes_item = GetValue(value); + int axes_item = CastToInt(value, false).front(); attr->axes.push_back(axes_item); } } diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index 1a0f11a640..4203375265 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -63,7 +63,7 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector &in attr->shape.emplace_back(static_cast(elem->value())); } } else { - int dim = GetValue(val); + int dim = CastToInt(val, false).front(); attr->shape = {dim}; } } diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index a2638be972..9f5e21df97 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -67,7 +67,7 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector &inp MS_LOG(ERROR) << "wrong resize type"; return RET_ERROR; } - std::vector targetSize = GetValue>(prim.GetAttr("size")); + std::vector targetSize = CastToInt(prim.GetAttr("size"), true); attr->newHeight = targetSize[0]; attr->newWidth = targetSize[1]; attr->alignCorners = GetValue(prim.GetAttr("align_corners")); diff --git a/mindspore/lite/src/ops/softmax.cc b/mindspore/lite/src/ops/softmax.cc index 099c6ee3ac..d079f9ed9c 100644 --- a/mindspore/lite/src/ops/softmax.cc +++ b/mindspore/lite/src/ops/softmax.cc @@ -43,7 +43,7 @@ int SoftMax::UnPackAttr(const Primitive &prim, const std::vector &in MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } - auto prim_axis = GetValue(prim.GetAttr("axis")); + auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front(); attr->axis = prim_axis; this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index eaa1c654dd..faef4bf90e 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -50,7 +50,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector &in MS_LOG(INFO) << "Squeeze's attr xis is set to default"; attr->axis = {0}; } else { - attr->axis = GetValue>(prim.GetAttr("axis")); + attr->axis = CastToInt(prim.GetAttr("axis"), true); } this->primitive_->value.value = attr; } diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index e40bde14bf..b5075f3549 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -73,11 +73,11 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vectorbeginMask = GetValue(prim.GetAttr("begin_mask")); - attr->endMask = GetValue(prim.GetAttr("end_mask")); - attr->ellipsisMask = GetValue(prim.GetAttr("ellipsis_mask")); - attr->newAxisMask = GetValue(prim.GetAttr("new_axis_mask")); - attr->shrinkAxisMask = GetValue(prim.GetAttr("shrink_axis_mask")); + attr->beginMask = CastToInt(prim.GetAttr("begin_mask"), false).front(); + attr->endMask = CastToInt(prim.GetAttr("end_mask"), false).front(); + attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask"), false).front(); + attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask"), false).front(); + attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask"), false).front(); auto inputNodeFirst = inputs[kAnfPopulaterOne]; std::vector beginVec; GetAttrDataFromInput(inputNodeFirst, &beginVec); diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 8b18d709fe..303be8ed37 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -56,7 +56,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input MS_LOG(INFO) << "Tile's attr dims is set to default"; attr->dims = {1}; } else { - attr->dims = GetValue>(prim.GetAttr("dims")); + attr->dims = CastToInt(prim.GetAttr("dims"), true); } if (inputs.size() == kAnfPopulaterTwo) { auto inputNode = inputs[kAnfPopulaterOne]; @@ -75,7 +75,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input attr->multiples.emplace_back(elem->value()); } } else { - int multiple = GetValue(value); + int multiple = CastToInt(value, false).front(); attr->multiples = {multiple}; } } diff --git a/mindspore/lite/src/ops/unsorted_segment_sum.cc b/mindspore/lite/src/ops/unsorted_segment_sum.cc index 3356128d5d..5394472200 100644 --- a/mindspore/lite/src/ops/unsorted_segment_sum.cc +++ b/mindspore/lite/src/ops/unsorted_segment_sum.cc @@ -48,7 +48,7 @@ int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector attr = std::make_unique(); if (inputs[2]->isa()) { ValuePtr value = inputs[2]->cast()->value(); - attr->numSegments = GetValue(value); + attr->numSegments = CastToInt(value, false).front(); this->primitive_->value.value = attr.release(); } } diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 61799f64b5..11895b0583 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -314,7 +314,9 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr input_anode, s return RET_ERROR; } auto input_index_key = - get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(GetValue(value_node->value())); + get_item_input_cnode->fullname_with_scope() + "_o:" + + std::to_string(value_node->value()->type_name() == "Int64Imm" ? GetValue(value_node->value()) + : GetValue(value_node->value())); auto iter = node_id_map_.find(input_index_key); if (iter == node_id_map_.end()) { #ifdef SUPPORT_TRAIN