From 5953855106eb59d58b1a67509d9fab05933d443e Mon Sep 17 00:00:00 2001 From: wsc Date: Wed, 16 Sep 2020 15:59:36 +0800 Subject: [PATCH] Add converter method for operator 'Split'. --- mindspore/lite/src/ops/primitive_c.cc | 2 + mindspore/lite/src/ops/split.cc | 37 ++++++++++++++++++- mindspore/lite/src/ops/split.h | 1 + .../lite/tools/anf_exporter/anf_exporter.cc | 9 +++-- 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index e8191d0ef7..29f101416a 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -418,6 +418,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Cast") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Split") { + return NewPrimitiveC(prim, inputs, quantType); #ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc index 02941142fa..d10cb05676 100644 --- a/mindspore/lite/src/ops/split.cc +++ b/mindspore/lite/src/ops/split.cc @@ -29,6 +29,37 @@ void Split::SetSizeSplits(const std::vector &size_splits) { } void Split::SetSplitDim(int split_dim) { this->primitive_->value.AsSplit()->splitDim = split_dim; } +int Split::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Split; + } + if (this->primitive_->value.type != schema::PrimitiveType_Split) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SplitT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + attr->splitDim = GetValue(prim.GetAttr("axis")); + attr->numberSplit = GetValue(prim.GetAttr("output_num")); + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + + return RET_OK; +} + #else int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); } @@ -99,12 +130,14 @@ int Split::InferShape(std::vector inputs_, std::vector outpu output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end()); int split_dim_i = input_shape[split_dim]; // support split size is -1 in the end. - if (i == number_split - 1 && size_split[i] == -1) { + if (size_split.empty()) { + split_dim_i = input_shape[split_dim] / number_split; + } else if (i == number_split - 1 && size_split[i] == -1) { for (size_t j = 0; j < size_split.size() - 1; ++j) { split_dim_i -= size_split[j]; } } else { - split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i]; + split_dim_i = size_split[i]; } output_shape[split_dim] = split_dim_i; outputs_[i]->set_shape(output_shape); diff --git a/mindspore/lite/src/ops/split.h b/mindspore/lite/src/ops/split.h index dc33c7141c..95cb4fe638 100644 --- a/mindspore/lite/src/ops/split.h +++ b/mindspore/lite/src/ops/split.h @@ -35,6 +35,7 @@ class Split : public PrimitiveC { void SetNumberSplit(int number_split); void SetSizeSplits(const std::vector &size_splits); void SetSplitDim(int split_dim); + int UnPackAttr(const Primitive &prim, const std::vector &inputs); #else Split() = default; diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 1283f8815a..3a4c82262a 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -419,10 +419,13 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr input_anode, MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; } #endif + } else if (value->isa()) { + MS_LOG(INFO) << "Value is a number."; + return RET_OK; } else { - MS_LOG(ERROR) << "Not support value type , need add support."; - return RET_ERROR; - } + MS_LOG(ERROR) << "Not support value type , need add support."; + return RET_ERROR; + } return RET_OK; }