!8909 [lite] fix tflite deconv parser and adjust mindir int64

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
pull/8909/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ea5bacec49

@ -51,7 +51,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(INFO) << "BiasAdd's attr axis is set to default"; MS_LOG(INFO) << "BiasAdd's attr axis is set to default";
attr->axis = {1}; attr->axis = {1};
} else { } else {
attr->axis = CastToInt(prim.GetAttr("axis"), true); attr->axis = CastToInt(prim.GetAttr("axis"));
} }
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
} }

@ -49,7 +49,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
MS_LOG(WARNING) << "get axis failed"; MS_LOG(WARNING) << "get axis failed";
attr->axis = {0}; attr->axis = {0};
} else { } else {
attr->axis = CastToInt(prim.GetAttr("axis"), true); attr->axis = CastToInt(prim.GetAttr("axis"));
} }
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
} }

@ -51,7 +51,7 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "new primitiveT value failed"; MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR; return RET_ERROR;
} }
auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front(); auto prim_axis = CastToInt(prim.GetAttr("axis")).front();
attr->axis = prim_axis; attr->axis = prim_axis;
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) { if (this->primitive_->value.value == nullptr) {

@ -143,21 +143,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
} else { } else {
attr->format = schema::Format::Format_NUM_OF_FORMAT; attr->format = schema::Format::Format_NUM_OF_FORMAT;
} }
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true); auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true); auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
@ -179,7 +179,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
int channel_mutiplier = 1; int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) { if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier")).front();
} }
attr->channelMultiplier = channel_mutiplier; attr->channelMultiplier = channel_mutiplier;
@ -220,25 +220,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
} else { } else {
attr->format = schema::Format::Format_NUM_OF_FORMAT; attr->format = schema::Format::Format_NUM_OF_FORMAT;
} }
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true); auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[2]; attr->dilateH = dilation[2];
attr->dilateW = dilation[3]; attr->dilateW = dilation[3];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true); auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {
@ -278,7 +278,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model"; MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
int group = CastToInt(groupAttr, false).front(); int group = CastToInt(groupAttr).front();
if (group > 1) { if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
} else { } else {

@ -94,7 +94,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
MS_LOG(ERROR) << "new primitiveT value failed"; MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR; return RET_ERROR;
} }
attr->group = CastToInt(prim.GetAttr("group"), false).front(); attr->group = CastToInt(prim.GetAttr("group")).front();
auto format = GetValue<std::string>(prim.GetAttr("data_format")); auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") { if (format == "NCHW") {
attr->format = schema::Format_NCHW; attr->format = schema::Format_NCHW;
@ -103,25 +103,25 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true); auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true); auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[0]; attr->strideH = stride[0];
attr->strideW = stride[1]; attr->strideW = stride[1];
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID; attr->padMode = schema::PadMode_VALID;
@ -154,7 +154,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
for (size_t i = 0; i < valTuplPtr->size(); i++) { for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i]; auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
attr->filter_shape[nchw2nhwc[i]] = CastToInt(elem, false).front(); attr->filter_shape[nchw2nhwc[i]] = CastToInt(elem).front();
} }
} }
} }

@ -92,7 +92,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
MS_LOG(ERROR) << "new primitiveT value failed"; MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR; return RET_ERROR;
} }
attr->group = CastToInt(prim.GetAttr("group"), false).front(); attr->group = CastToInt(prim.GetAttr("group")).front();
if (attr->group > 1) { if (attr->group > 1) {
this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput; this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput;
} }
@ -104,25 +104,25 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true); auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true); auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[0]; attr->strideH = stride[0];
attr->strideW = stride[1]; attr->strideW = stride[1];
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {
@ -156,7 +156,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
for (size_t i = 0; i < valTuplPtr->size(); i++) { for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i]; auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
attr->input_shape[nchw2nhwc[i]] = CastToInt(elem, false).front(); attr->input_shape[nchw2nhwc[i]] = CastToInt(elem).front();
} }
} }
} }

@ -136,21 +136,21 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
} else { } else {
attr->format = schema::Format::Format_NUM_OF_FORMAT; attr->format = schema::Format::Format_NUM_OF_FORMAT;
} }
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true); auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true); auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[0]; attr->strideH = stride[0];
attr->strideW = stride[1]; attr->strideW = stride[1];
@ -172,7 +172,7 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
int channel_mutiplier = 1; int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) { if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier")).front();
} }
attr->channelMultiplier = channel_mutiplier; attr->channelMultiplier = channel_mutiplier;
@ -203,25 +203,25 @@ void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::Primi
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true); auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true); auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[0]; attr->strideH = stride[0];
attr->strideW = stride[1]; attr->strideW = stride[1];
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid" || pad_mode == "VALID") { if (pad_mode == "valid" || pad_mode == "VALID") {
@ -256,7 +256,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR; return RET_ERROR;
} }
int group = CastToInt(prim.GetAttr("group"), false).front(); int group = CastToInt(prim.GetAttr("group")).front();
if (group == 1) { if (group == 1) {
PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
} else if (group > 1) { } else if (group > 1) {

@ -86,27 +86,27 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else { } else {
attr->format = schema::Format::Format_NUM_OF_FORMAT; attr->format = schema::Format::Format_NUM_OF_FORMAT;
} }
auto pad_list = CastToInt(prim.GetAttr("pads"), true); auto pad_list = CastToInt(prim.GetAttr("pads"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true); auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) { if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) {
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
} else { } else {
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), false).front(); auto kernel_size = CastToInt(prim.GetAttr("kernel_size")).front();
attr->kernelH = kernel_size; attr->kernelH = kernel_size;
attr->kernelW = kernel_size; attr->kernelW = kernel_size;
} }
auto stride = CastToInt(prim.GetAttr("stride"), true); auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
@ -124,7 +124,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else { } else {
attr->activationType = schema::ActivationType_NO_ACTIVATION; attr->activationType = schema::ActivationType_NO_ACTIVATION;
} }
auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier")).front();
attr->channelMultiplier = channel_multiplier; attr->channelMultiplier = channel_multiplier;
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo); MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);

@ -53,7 +53,7 @@ int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
// use axis instead of dim // use axis instead of dim
if (inputs[1]->isa<ValueNode>()) { if (inputs[1]->isa<ValueNode>()) {
auto axis_tensor = inputs[1]->cast<ValueNodePtr>(); auto axis_tensor = inputs[1]->cast<ValueNodePtr>();
int axis = CastToInt(axis_tensor->value(), false).front(); int axis = CastToInt(axis_tensor->value()).front();
attr->dim = axis; attr->dim = axis;
} else { } else {
MS_LOG(ERROR) << "input axis is not value node."; MS_LOG(ERROR) << "input axis is not value node.";

@ -59,7 +59,7 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
} }
if (inputs[2]->isa<ValueNode>()) { if (inputs[2]->isa<ValueNode>()) {
ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>(); ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>();
int axis = CastToInt(axis_tensor->value(), false).front(); int axis = CastToInt(axis_tensor->value()).front();
gather_attr->axis = axis; gather_attr->axis = axis;
} else { } else {
MS_LOG(ERROR) << "input axis is not value node."; MS_LOG(ERROR) << "input axis is not value node.";

@ -48,7 +48,7 @@ int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
} }
attr->axis = -1; attr->axis = -1;
if (prim.GetAttr("axis") != nullptr) { if (prim.GetAttr("axis") != nullptr) {
attr->axis = CastToInt(prim.GetAttr("axis"), false).front(); attr->axis = CastToInt(prim.GetAttr("axis")).front();
} }
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) { if (this->primitive_->value.value == nullptr) {

@ -110,11 +110,11 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
attr->padMode = schema::PadMode_NOTSET; attr->padMode = schema::PadMode_NOTSET;
} }
auto kernel_size = CastToInt(prim.GetAttr("ksize"), true); auto kernel_size = CastToInt(prim.GetAttr("ksize"));
attr->windowH = kernel_size[2]; attr->windowH = kernel_size[2];
attr->windowW = kernel_size[3]; attr->windowW = kernel_size[3];
auto stride = CastToInt(prim.GetAttr("strides"), true); auto stride = CastToInt(prim.GetAttr("strides"));
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
this->primitive_->value.value = attr; this->primitive_->value.value = attr;

@ -99,11 +99,11 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
attr->padMode = schema::PadMode_NOTSET; attr->padMode = schema::PadMode_NOTSET;
} }
auto kernel_size = CastToInt(prim.GetAttr("ksize"), true); auto kernel_size = CastToInt(prim.GetAttr("ksize"));
attr->windowH = kernel_size[2]; attr->windowH = kernel_size[2];
attr->windowW = kernel_size[3]; attr->windowW = kernel_size[3];
auto stride = CastToInt(prim.GetAttr("strides"), true); auto stride = CastToInt(prim.GetAttr("strides"));
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
this->primitive_->value.value = attr; this->primitive_->value.value = attr;

@ -181,17 +181,13 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
std::vector<int> CastToInt(const ValuePtr value, bool is_vector) { std::vector<int> CastToInt(const ValuePtr value) {
if (value == nullptr) { if (value == nullptr) {
MS_LOG(WARNING) << "valueptr is nullptr."; MS_LOG(WARNING) << "valueptr is nullptr.";
return {}; return {};
} }
std::vector<int> cur_value; std::vector<int> cur_value;
if (is_vector) { if (utils::isa<ValueSequeuePtr>(value)) {
if (!utils::isa<ValueSequeuePtr>(value)) {
MS_LOG(WARNING) << "valueptr is not a sequence, value may be a scalar.";
return {};
}
if (value->cast<ValueSequeuePtr>()->value().front()->type()->number_type() == kNumberTypeInt64) { if (value->cast<ValueSequeuePtr>()->value().front()->type()->number_type() == kNumberTypeInt64) {
auto origin_value = GetValue<std::vector<int64_t>>(value); auto origin_value = GetValue<std::vector<int64_t>>(value);
for (size_t index = 0; index < origin_value.size(); ++index) { for (size_t index = 0; index < origin_value.size(); ++index) {
@ -337,7 +333,7 @@ void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<i
for (size_t i = 0; i < tuple->size(); i++) { for (size_t i = 0; i < tuple->size(); i++) {
auto elem = tuple->value()[i]; auto elem = tuple->value()[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
data->emplace_back(CastToInt(elem, false).front()); data->emplace_back(CastToInt(elem).front());
} }
} }
} }

@ -57,7 +57,7 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{
{"LeakyRelu", schema::ActivationType_LEAKY_RELU}, {"LeakyRelu", schema::ActivationType_LEAKY_RELU},
{"Tanh", schema::ActivationType_TANH}, {"Tanh", schema::ActivationType_TANH},
{"Logistic", schema::ActivationType_SIGMOID}}; {"Logistic", schema::ActivationType_SIGMOID}};
std::vector<int> CastToInt(const ValuePtr value, bool is_vector); std::vector<int> CastToInt(const ValuePtr value);
class PrimitiveC : public mindspore::Primitive { class PrimitiveC : public mindspore::Primitive {
public: public:
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().

@ -84,10 +84,10 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
for (size_t i = 0; i < valTuplPtr->size(); i++) { for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i]; auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
attr->axes.emplace_back(CastToInt(elem, false).front()); attr->axes.emplace_back(CastToInt(elem).front());
} }
} else { } else {
int axes_item = CastToInt(value, false).front(); int axes_item = CastToInt(value).front();
attr->axes.push_back(axes_item); attr->axes.push_back(axes_item);
} }
} }

@ -60,10 +60,10 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
for (size_t i = 0; i < tuple->size(); ++i) { for (size_t i = 0; i < tuple->size(); ++i) {
auto elem = tuple->value()[i]; auto elem = tuple->value()[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
attr->shape.emplace_back(CastToInt(elem, false).front()); attr->shape.emplace_back(CastToInt(elem).front());
} }
} else { } else {
int dim = CastToInt(val, false).front(); int dim = CastToInt(val).front();
attr->shape = {dim}; attr->shape = {dim};
} }
} }

@ -67,7 +67,7 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "wrong resize type"; MS_LOG(ERROR) << "wrong resize type";
return RET_ERROR; return RET_ERROR;
} }
std::vector<int> targetSize = CastToInt(prim.GetAttr("size"), true); std::vector<int> targetSize = CastToInt(prim.GetAttr("size"));
attr->newHeight = targetSize[0]; attr->newHeight = targetSize[0];
attr->newWidth = targetSize[1]; attr->newWidth = targetSize[1];
attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners")); attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners"));

@ -73,7 +73,7 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
for (size_t i = 0; i < valTuplPtr->size(); i++) { for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i]; auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
attr->begin.emplace_back(CastToInt(elem, false).front()); attr->begin.emplace_back(CastToInt(elem).front());
} }
} }
} }
@ -90,7 +90,7 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
for (size_t i = 0; i < valTuplPtr->size(); i++) { for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i]; auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
attr->size.emplace_back(CastToInt(elem, false).front()); attr->size.emplace_back(CastToInt(elem).front());
} }
} }
} }

@ -43,7 +43,7 @@ int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(ERROR) << "new primitiveT value failed"; MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR; return RET_ERROR;
} }
auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front(); auto prim_axis = CastToInt(prim.GetAttr("axis")).front();
attr->axis = prim_axis; attr->axis = prim_axis;
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) { if (this->primitive_->value.value == nullptr) {

@ -50,7 +50,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(INFO) << "Squeeze's attr xis is set to default"; MS_LOG(INFO) << "Squeeze's attr xis is set to default";
attr->axis = {0}; attr->axis = {0};
} else { } else {
attr->axis = CastToInt(prim.GetAttr("axis"), true); attr->axis = CastToInt(prim.GetAttr("axis"));
} }
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
} }

@ -74,11 +74,11 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr
MS_LOG(ERROR) << "new StridedSlice failed"; MS_LOG(ERROR) << "new StridedSlice failed";
return RET_ERROR; return RET_ERROR;
} }
attr->beginMask = CastToInt(prim.GetAttr("begin_mask"), false).front(); attr->beginMask = CastToInt(prim.GetAttr("begin_mask")).front();
attr->endMask = CastToInt(prim.GetAttr("end_mask"), false).front(); attr->endMask = CastToInt(prim.GetAttr("end_mask")).front();
attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask"), false).front(); attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask")).front();
attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask"), false).front(); attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask")).front();
attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask"), false).front(); attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask")).front();
auto inputNodeFirst = inputs[kAnfPopulaterInputNumOne]; auto inputNodeFirst = inputs[kAnfPopulaterInputNumOne];
std::vector<int> beginVec; std::vector<int> beginVec;
GetAttrDataFromInput(inputNodeFirst, &beginVec); GetAttrDataFromInput(inputNodeFirst, &beginVec);

@ -56,7 +56,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
MS_LOG(INFO) << "Tile's attr dims is set to default"; MS_LOG(INFO) << "Tile's attr dims is set to default";
attr->dims = {1}; attr->dims = {1};
} else { } else {
attr->dims = CastToInt(prim.GetAttr("dims"), true); attr->dims = CastToInt(prim.GetAttr("dims"));
} }
if (inputs.size() == kAnfPopulaterInputNumTwo) { if (inputs.size() == kAnfPopulaterInputNumTwo) {
auto inputNode = inputs[kAnfPopulaterInputNumOne]; auto inputNode = inputs[kAnfPopulaterInputNumOne];
@ -72,10 +72,10 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
for (size_t i = 0; i < valTuplPtr->size(); i++) { for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i]; auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
attr->multiples.emplace_back(CastToInt(elem, false).front()); attr->multiples.emplace_back(CastToInt(elem).front());
} }
} else { } else {
int multiple = CastToInt(value, false).front(); int multiple = CastToInt(value).front();
attr->multiples = {multiple}; attr->multiples = {multiple};
} }
} }

@ -64,7 +64,7 @@ int Transpose::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
for (size_t i = 0; i < tuple->size(); i++) { for (size_t i = 0; i < tuple->size(); i++) {
auto elem = tuple->value()[i]; auto elem = tuple->value()[i];
MS_ASSERT(elem != nullptr); MS_ASSERT(elem != nullptr);
attr->perm.emplace_back(CastToInt(elem, false).front()); attr->perm.emplace_back(CastToInt(elem).front());
} }
} }
} }

@ -48,7 +48,7 @@ int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfN
std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>(); std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>();
if (inputs[2]->isa<ValueNode>()) { if (inputs[2]->isa<ValueNode>()) {
ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value(); ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value();
attr->numSegments = CastToInt(value, false).front(); attr->numSegments = CastToInt(value).front();
this->primitive_->value.value = attr.release(); this->primitive_->value.value = attr.release();
} }
} }

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save