diff --git a/mindspore/lite/src/ops/bias_add.cc b/mindspore/lite/src/ops/bias_add.cc index dc8c4f933f..13ecc02d4e 100644 --- a/mindspore/lite/src/ops/bias_add.cc +++ b/mindspore/lite/src/ops/bias_add.cc @@ -51,7 +51,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector &in MS_LOG(INFO) << "BiasAdd's attr axis is set to default"; attr->axis = {1}; } else { - attr->axis = CastToInt(prim.GetAttr("axis"), true); + attr->axis = CastToInt(prim.GetAttr("axis")); } this->primitive_->value.value = attr; } diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc index 85017d3561..95385a0534 100644 --- a/mindspore/lite/src/ops/bias_grad.cc +++ b/mindspore/lite/src/ops/bias_grad.cc @@ -49,7 +49,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector &i MS_LOG(WARNING) << "get axis failed"; attr->axis = {0}; } else { - attr->axis = CastToInt(prim.GetAttr("axis"), true); + attr->axis = CastToInt(prim.GetAttr("axis")); } this->primitive_->value.value = attr; } diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index 457f8de8e6..d119cb28b6 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -51,7 +51,7 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector &inp MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } - auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front(); + auto prim_axis = CastToInt(prim.GetAttr("axis")).front(); attr->axis = prim_axis; this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index 2e673b1aad..ff59613349 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -143,21 +143,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT } else { attr->format = schema::Format::Format_NUM_OF_FORMAT; } - auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); + auto pad_list = CastToInt(prim.GetAttr("pad_list")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = CastToInt(prim.GetAttr("dilation"), true); + auto dilation = CastToInt(prim.GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = CastToInt(prim.GetAttr("stride"), true); + auto stride = CastToInt(prim.GetAttr("stride")); attr->strideH = stride[2]; attr->strideW = stride[3]; @@ -179,7 +179,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT int channel_mutiplier = 1; if (prim.GetAttr("channel_mutiplier") != nullptr) { - channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); + channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier")).front(); } attr->channelMultiplier = channel_mutiplier; @@ -220,25 +220,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive } else { attr->format = schema::Format::Format_NUM_OF_FORMAT; } - auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); + auto pad_list = CastToInt(prim.GetAttr("pad_list")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = CastToInt(prim.GetAttr("dilation"), true); + auto dilation = CastToInt(prim.GetAttr("dilation")); attr->dilateH = dilation[2]; attr->dilateW = dilation[3]; - auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = CastToInt(prim.GetAttr("stride"), true); + auto stride = CastToInt(prim.GetAttr("stride")); attr->strideH = stride[2]; attr->strideW = stride[3]; - attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); + attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); auto pad_mode = GetValue(prim.GetAttr("pad_mode")); if (pad_mode == "valid") { @@ -278,7 +278,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector &inp MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model"; return RET_NULL_PTR; } - int group = CastToInt(groupAttr, false).front(); + int group = CastToInt(groupAttr).front(); if (group > 1) { PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); } else { diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc index 6556d3b86f..e7bd75da16 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.cc +++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc @@ -94,7 +94,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vectorgroup = CastToInt(prim.GetAttr("group"), false).front(); + attr->group = CastToInt(prim.GetAttr("group")).front(); auto format = GetValue(prim.GetAttr("data_format")); if (format == "NCHW") { attr->format = schema::Format_NCHW; @@ -103,25 +103,25 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vectorformat = schema::Format_NUM_OF_FORMAT; } - auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); + auto pad_list = CastToInt(prim.GetAttr("pad_list")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = CastToInt(prim.GetAttr("dilation"), true); + auto dilation = CastToInt(prim.GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = CastToInt(prim.GetAttr("stride"), true); + auto stride = CastToInt(prim.GetAttr("stride")); attr->strideH = stride[0]; attr->strideW = stride[1]; - attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); + attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); auto pad_mode = GetValue(prim.GetAttr("pad_mode")); if (pad_mode == "valid") { attr->padMode = schema::PadMode_VALID; @@ -154,7 +154,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vectorsize(); i++) { auto elem = (*valTuplPtr)[i]; MS_ASSERT(elem != nullptr); - attr->filter_shape[nchw2nhwc[i]] = CastToInt(elem, false).front(); + attr->filter_shape[nchw2nhwc[i]] = CastToInt(elem).front(); } } } diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc index 0a8501dc47..1173a659f9 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/conv2d_grad_input.cc @@ -92,7 +92,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vectorgroup = CastToInt(prim.GetAttr("group"), false).front(); + attr->group = CastToInt(prim.GetAttr("group")).front(); if (attr->group > 1) { this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput; } @@ -104,25 +104,25 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vectorformat = schema::Format_NUM_OF_FORMAT; } - auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); + auto pad_list = CastToInt(prim.GetAttr("pad_list")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = CastToInt(prim.GetAttr("dilation"), true); + auto dilation = CastToInt(prim.GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = CastToInt(prim.GetAttr("stride"), true); + auto stride = CastToInt(prim.GetAttr("stride")); attr->strideH = stride[0]; attr->strideW = stride[1]; - attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); + attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); auto pad_mode = GetValue(prim.GetAttr("pad_mode")); if (pad_mode == "valid") { @@ -156,7 +156,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vectorsize(); i++) { auto elem = (*valTuplPtr)[i]; MS_ASSERT(elem != nullptr); - attr->input_shape[nchw2nhwc[i]] = CastToInt(elem, false).front(); + attr->input_shape[nchw2nhwc[i]] = CastToInt(elem).front(); } } } diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 19f1f90a8f..722bc8ee57 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -136,21 +136,21 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv } else { attr->format = schema::Format::Format_NUM_OF_FORMAT; } - auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); + auto pad_list = CastToInt(prim.GetAttr("pad_list")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = CastToInt(prim.GetAttr("dilation"), true); + auto dilation = CastToInt(prim.GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = CastToInt(prim.GetAttr("stride"), true); + auto stride = CastToInt(prim.GetAttr("stride")); attr->strideH = stride[0]; attr->strideW = stride[1]; @@ -172,7 +172,7 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv int channel_mutiplier = 1; if (prim.GetAttr("channel_mutiplier") != nullptr) { - channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); + channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier")).front(); } attr->channelMultiplier = channel_mutiplier; @@ -203,25 +203,25 @@ void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::Primi } else { attr->format = schema::Format_NUM_OF_FORMAT; } - auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); + auto pad_list = CastToInt(prim.GetAttr("pad_list")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = CastToInt(prim.GetAttr("dilation"), true); + auto dilation = CastToInt(prim.GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; - auto stride = CastToInt(prim.GetAttr("stride"), true); + auto stride = CastToInt(prim.GetAttr("stride")); attr->strideH = stride[0]; attr->strideW = stride[1]; - attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); + attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); auto pad_mode = GetValue(prim.GetAttr("pad_mode")); if (pad_mode == "valid" || pad_mode == "VALID") { @@ -256,7 +256,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector &i MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; return RET_ERROR; } - int group = CastToInt(prim.GetAttr("group"), false).front(); + int group = CastToInt(prim.GetAttr("group")).front(); if (group == 1) { PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); } else if (group > 1) { diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index fe66ec13e4..5ab48bfb4f 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -86,27 +86,27 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vectorformat = schema::Format::Format_NUM_OF_FORMAT; } - auto pad_list = CastToInt(prim.GetAttr("pads"), true); + auto pad_list = CastToInt(prim.GetAttr("pads")); attr->padUp = pad_list[0]; attr->padDown = pad_list[1]; attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; - auto dilation = CastToInt(prim.GetAttr("dilation"), true); + auto dilation = CastToInt(prim.GetAttr("dilation")); attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; if (utils::isa(prim.GetAttr("kernel_size"))) { - auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; } else { - auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), false).front(); + auto kernel_size = CastToInt(prim.GetAttr("kernel_size")).front(); attr->kernelH = kernel_size; attr->kernelW = kernel_size; } - auto stride = CastToInt(prim.GetAttr("stride"), true); + auto stride = CastToInt(prim.GetAttr("stride")); attr->strideH = stride[2]; attr->strideW = stride[3]; @@ -124,7 +124,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vectoractivationType = schema::ActivationType_NO_ACTIVATION; } - auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); + auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier")).front(); attr->channelMultiplier = channel_multiplier; MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo); diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index 35d5840974..bbeef87bbc 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -53,7 +53,7 @@ int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector // use axis instead of dim if (inputs[1]->isa()) { auto axis_tensor = inputs[1]->cast(); - int axis = CastToInt(axis_tensor->value(), false).front(); + int axis = CastToInt(axis_tensor->value()).front(); attr->dim = axis; } else { MS_LOG(ERROR) << "input axis is not value node."; diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index 99b97c3205..6e13952cec 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -59,7 +59,7 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector &inp } if (inputs[2]->isa()) { ValueNodePtr axis_tensor = inputs[2]->cast(); - int axis = CastToInt(axis_tensor->value(), false).front(); + int axis = CastToInt(axis_tensor->value()).front(); gather_attr->axis = axis; } else { MS_LOG(ERROR) << "input axis is not value node."; diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index b3bd9c9672..440e7959a9 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -48,7 +48,7 @@ int OneHot::UnPackAttr(const Primitive &prim, const std::vector &inp } attr->axis = -1; if (prim.GetAttr("axis") != nullptr) { - attr->axis = CastToInt(prim.GetAttr("axis"), false).front(); + attr->axis = CastToInt(prim.GetAttr("axis")).front(); } this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index 30a9f54b31..329665e807 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -110,11 +110,11 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector &in attr->padMode = schema::PadMode_NOTSET; } - auto kernel_size = CastToInt(prim.GetAttr("ksize"), true); + auto kernel_size = CastToInt(prim.GetAttr("ksize")); attr->windowH = kernel_size[2]; attr->windowW = kernel_size[3]; - auto stride = CastToInt(prim.GetAttr("strides"), true); + auto stride = CastToInt(prim.GetAttr("strides")); attr->strideH = stride[2]; attr->strideW = stride[3]; this->primitive_->value.value = attr; diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc index e27360deaf..7afa0e4102 100644 --- a/mindspore/lite/src/ops/pooling_grad.cc +++ b/mindspore/lite/src/ops/pooling_grad.cc @@ -99,11 +99,11 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector attr->padMode = schema::PadMode_NOTSET; } - auto kernel_size = CastToInt(prim.GetAttr("ksize"), true); + auto kernel_size = CastToInt(prim.GetAttr("ksize")); attr->windowH = kernel_size[2]; attr->windowW = kernel_size[3]; - auto stride = CastToInt(prim.GetAttr("strides"), true); + auto stride = CastToInt(prim.GetAttr("strides")); attr->strideH = stride[2]; attr->strideW = stride[3]; this->primitive_->value.value = attr; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 68ac35e3ed..5932ca09a9 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -181,17 +181,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector CastToInt(const ValuePtr value, bool is_vector) { +std::vector CastToInt(const ValuePtr value) { if (value == nullptr) { MS_LOG(WARNING) << "valueptr is nullptr."; return {}; } std::vector cur_value; - if (is_vector) { - if (!utils::isa(value)) { - MS_LOG(WARNING) << "valueptr is not a sequence, value may be a scalar."; - return {}; - } + if (utils::isa(value)) { if (value->cast()->value().front()->type()->number_type() == kNumberTypeInt64) { auto origin_value = GetValue>(value); for (size_t index = 0; index < origin_value.size(); ++index) { @@ -337,7 +333,7 @@ void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vectorsize(); i++) { auto elem = tuple->value()[i]; MS_ASSERT(elem != nullptr); - data->emplace_back(CastToInt(elem, false).front()); + data->emplace_back(CastToInt(elem).front()); } } } diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index b5c8c1fa54..5b5e49c49d 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -57,7 +57,7 @@ static std::map kActivationTypeMap{ {"LeakyRelu", schema::ActivationType_LEAKY_RELU}, {"Tanh", schema::ActivationType_TANH}, {"Logistic", schema::ActivationType_SIGMOID}}; -std::vector CastToInt(const ValuePtr value, bool is_vector); +std::vector CastToInt(const ValuePtr value); class PrimitiveC : public mindspore::Primitive { public: // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index cc92c8a660..de770a4f95 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -84,10 +84,10 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector &inp for (size_t i = 0; i < valTuplPtr->size(); i++) { auto elem = (*valTuplPtr)[i]; MS_ASSERT(elem != nullptr); - attr->axes.emplace_back(CastToInt(elem, false).front()); + attr->axes.emplace_back(CastToInt(elem).front()); } } else { - int axes_item = CastToInt(value, false).front(); + int axes_item = CastToInt(value).front(); attr->axes.push_back(axes_item); } } diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index 6cbdb7bdb5..daf466622c 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -60,10 +60,10 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector &in for (size_t i = 0; i < tuple->size(); ++i) { auto elem = tuple->value()[i]; MS_ASSERT(elem != nullptr); - attr->shape.emplace_back(CastToInt(elem, false).front()); + attr->shape.emplace_back(CastToInt(elem).front()); } } else { - int dim = CastToInt(val, false).front(); + int dim = CastToInt(val).front(); attr->shape = {dim}; } } diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index bab8d9dd4b..12f57ba512 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -67,7 +67,7 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector &inp MS_LOG(ERROR) << "wrong resize type"; return RET_ERROR; } - std::vector targetSize = CastToInt(prim.GetAttr("size"), true); + std::vector targetSize = CastToInt(prim.GetAttr("size")); attr->newHeight = targetSize[0]; attr->newWidth = targetSize[1]; attr->alignCorners = GetValue(prim.GetAttr("align_corners")); diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 7b306c9db9..d25e4d6d80 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -73,7 +73,7 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector &inpu for (size_t i = 0; i < valTuplPtr->size(); i++) { auto elem = (*valTuplPtr)[i]; MS_ASSERT(elem != nullptr); - attr->begin.emplace_back(CastToInt(elem, false).front()); + attr->begin.emplace_back(CastToInt(elem).front()); } } } @@ -90,7 +90,7 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector &inpu for (size_t i = 0; i < valTuplPtr->size(); i++) { auto elem = (*valTuplPtr)[i]; MS_ASSERT(elem != nullptr); - attr->size.emplace_back(CastToInt(elem, false).front()); + attr->size.emplace_back(CastToInt(elem).front()); } } } diff --git a/mindspore/lite/src/ops/softmax.cc b/mindspore/lite/src/ops/softmax.cc index d89521f5bb..62adaada57 100644 --- a/mindspore/lite/src/ops/softmax.cc +++ b/mindspore/lite/src/ops/softmax.cc @@ -43,7 +43,7 @@ int SoftMax::UnPackAttr(const Primitive &prim, const std::vector &in MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } - auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front(); + auto prim_axis = CastToInt(prim.GetAttr("axis")).front(); attr->axis = prim_axis; this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index 7d73e13b63..3cf4435d02 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -50,7 +50,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector &in MS_LOG(INFO) << "Squeeze's attr xis is set to default"; attr->axis = {0}; } else { - attr->axis = CastToInt(prim.GetAttr("axis"), true); + attr->axis = CastToInt(prim.GetAttr("axis")); } this->primitive_->value.value = attr; } diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index ea8c0b3d74..bd8038ff24 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -74,11 +74,11 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vectorbeginMask = CastToInt(prim.GetAttr("begin_mask"), false).front(); - attr->endMask = CastToInt(prim.GetAttr("end_mask"), false).front(); - attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask"), false).front(); - attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask"), false).front(); - attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask"), false).front(); + attr->beginMask = CastToInt(prim.GetAttr("begin_mask")).front(); + attr->endMask = CastToInt(prim.GetAttr("end_mask")).front(); + attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask")).front(); + attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask")).front(); + attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask")).front(); auto inputNodeFirst = inputs[kAnfPopulaterInputNumOne]; std::vector beginVec; GetAttrDataFromInput(inputNodeFirst, &beginVec); diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 412e2565d5..5dd481cbc2 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -56,7 +56,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input MS_LOG(INFO) << "Tile's attr dims is set to default"; attr->dims = {1}; } else { - attr->dims = CastToInt(prim.GetAttr("dims"), true); + attr->dims = CastToInt(prim.GetAttr("dims")); } if (inputs.size() == kAnfPopulaterInputNumTwo) { auto inputNode = inputs[kAnfPopulaterInputNumOne]; @@ -72,10 +72,10 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input for (size_t i = 0; i < valTuplPtr->size(); i++) { auto elem = (*valTuplPtr)[i]; MS_ASSERT(elem != nullptr); - attr->multiples.emplace_back(CastToInt(elem, false).front()); + attr->multiples.emplace_back(CastToInt(elem).front()); } } else { - int multiple = CastToInt(value, false).front(); + int multiple = CastToInt(value).front(); attr->multiples = {multiple}; } } diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc index c895377e10..f6ef25a9f9 100644 --- a/mindspore/lite/src/ops/transpose.cc +++ b/mindspore/lite/src/ops/transpose.cc @@ -64,7 +64,7 @@ int Transpose::UnPackAttr(const Primitive &prim, const std::vector & for (size_t i = 0; i < tuple->size(); i++) { auto elem = tuple->value()[i]; MS_ASSERT(elem != nullptr); - attr->perm.emplace_back(CastToInt(elem, false).front()); + attr->perm.emplace_back(CastToInt(elem).front()); } } } diff --git a/mindspore/lite/src/ops/unsorted_segment_sum.cc b/mindspore/lite/src/ops/unsorted_segment_sum.cc index 67c8242d10..5daa78def7 100644 --- a/mindspore/lite/src/ops/unsorted_segment_sum.cc +++ b/mindspore/lite/src/ops/unsorted_segment_sum.cc @@ -48,7 +48,7 @@ int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector attr = std::make_unique(); if (inputs[2]->isa()) { ValuePtr value = inputs[2]->cast()->value(); - attr->numSegments = CastToInt(value, false).front(); + attr->numSegments = CastToInt(value).front(); this->primitive_->value.value = attr.release(); } } diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 610f29374d..d4a0b22a5d 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -423,7 +423,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano paramTensor->dataType = typePtr->type_id(); paramTensor->dims = {1}; paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; - int real_data = CastToInt(value, false).front(); + int real_data = CastToInt(value).front(); paramTensor->data.resize(sizeof(int32_t)); auto ret = memcpy_s(paramTensor->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t)); if (ret != EOK) { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 74df4db85e..1befe246df 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -43,7 +43,11 @@ PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptrpadMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format::Format_NHWC; attr->activationType = schema::ActivationType_NO_ACTIVATION; - attr->hasBias = true; + if (tflite_op->inputs.size() > 3) { + attr->hasBias = true; + } else { + attr->hasBias = false; + } // get the conv op weight tensor auto weight_index = tflite_op->inputs[1]; diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 36a058a76b..65ab60821a 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -833,7 +833,7 @@ STATUS PostTrainingQuantizer::QuantNode() { MS_LOG(WARNING) << "index value node is null"; continue; } - size_t index = CastToInt(index_value_node->value(), false).front(); + size_t index = CastToInt(index_value_node->value()).front(); auto input_node = cnode->input(1); MS_ASSERT(input_node != nullptr); auto input_cnode = std::dynamic_pointer_cast(input_node); diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 4c47422e37..cb548e3495 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -589,7 +589,7 @@ size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) { MS_ASSERT(output_index_value_node != nullptr); auto value_node = output_index_value_node->cast(); MS_ASSERT(value_node != nullptr); - return IntToSize(lite::CastToInt(value_node->value(), false).front()); + return IntToSize(lite::CastToInt(value_node->value()).front()); } std::shared_ptr>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, const AnfNodePtr &node, diff --git a/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc index 6931a3642a..f7f1e30afc 100644 --- a/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc @@ -67,7 +67,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; return lite::RET_ERROR; } - int index = lite::CastToInt(index_vnode->cast()->value(), false).front(); + int index = lite::CastToInt(index_vnode->cast()->value()).front(); int input_cnode_inputs_size = get_item_input_cnode->inputs().size(); if ((index + 1) >= input_cnode_inputs_size) { MS_LOG(ERROR) << "value node index is out of range.";