From ee29ccad087eec79c5f6164a72e3e8a22848e3a2 Mon Sep 17 00:00:00 2001
From: xuanyue
Date: Tue, 22 Dec 2020 16:01:36 +0800
Subject: [PATCH] =?UTF-8?q?add=20argmax=E3=80=81layernorm=E3=80=81batchmat?=
 =?UTF-8?q?mul=20for=20mindir?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspore/lite/schema/model.fbs               |   1 +
 mindspore/lite/src/ops/argmax.cc              |  31 ++-
 mindspore/lite/src/ops/argmax.h               |   1 +
 mindspore/lite/src/ops/argmin.cc              |   1 +
 mindspore/lite/src/ops/gelu.cc                |  53 ++++
 mindspore/lite/src/ops/gelu.h                 |  40 +++
 mindspore/lite/src/ops/layer_norm.cc          |  42 +++-
 mindspore/lite/src/ops/layer_norm.h           |   1 +
 mindspore/lite/src/ops/primitive_c.cc         |  36 ++-
 mindspore/lite/test/CMakeLists.txt            |   1 +
 .../lite/tools/anf_exporter/anf_exporter.cc   |   3 +-
 mindspore/lite/tools/converter/CMakeLists.txt |   1 +
 .../lite/tools/converter/anf_transform.cc     |   7 +
 .../graph/mindir_inputs_adjust_pass.cc        | 236 ++++++++++++++++++
 .../graph/mindir_inputs_adjust_pass.h         |  41 +++
 .../optimizer/graph/onnx_inputs_adjust_pass.h |   2 -
 16 files changed, 485 insertions(+), 12 deletions(-)
 create mode 100644 mindspore/lite/src/ops/gelu.cc
 create mode 100644 mindspore/lite/src/ops/gelu.h
 create mode 100644 mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.cc
 create mode 100644 mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.h

diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs
index 20f0746e4e..20c666e814 100644
--- a/mindspore/lite/schema/model.fbs
+++ b/mindspore/lite/schema/model.fbs
@@ -261,6 +261,7 @@ union PrimitiveType {
     Reciprocal,
     Merge,
     Mod,
+    GeLU,
 }
 
 enum QuantType: int {
diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc
index 1e8af0ea8d..e78ba4160c 100644
--- a/mindspore/lite/src/ops/argmax.cc
+++ b/mindspore/lite/src/ops/argmax.cc
@@ -34,7 +34,36 @@ void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgM
 void ArgMax::SetTopK(int top_k) { this->primitive_->value.AsArgMax()->topK = top_k; }
 void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->keepDims = keep_dims; }
 void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; }
-
+int ArgMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitive error";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_ArgMax;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_ArgMax) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto argmax_attr = new (std::nothrow) schema::ArgMaxT();
+    if (argmax_attr == nullptr) {
+      MS_LOG(ERROR) << "new primitive value.value error";
+      return RET_ERROR;
+    }
+    if (prim.GetAttr("axis") != nullptr) {
+      argmax_attr->axis = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("axis")));
+    }
+    if (prim.GetAttr("keep_dims") != nullptr) {
+      argmax_attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims")));
+    }
+    argmax_attr->outMaxValue = false;
+    this->primitive_->value.value = argmax_attr;
+  }
+  return RET_OK;
+}
 #else
 int ArgMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
diff --git a/mindspore/lite/src/ops/argmax.h b/mindspore/lite/src/ops/argmax.h
index 4fe49309af..d208c2b60a 100644
--- a/mindspore/lite/src/ops/argmax.h
+++ b/mindspore/lite/src/ops/argmax.h
@@ -37,6 +37,7 @@ class ArgMax : public PrimitiveC {
   void SetTopK(int top_k);
   void SetKeepDims(bool keep_dims);
   void SetAxisType(int axis_type);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc
index f7a4a72fdd..daf856a21d 100644
--- a/mindspore/lite/src/ops/argmin.cc
+++ b/mindspore/lite/src/ops/argmin.cc
@@ -61,6 +61,7 @@ int ArgMin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
     if (prim.GetAttr("keep_dims") != nullptr) {
       attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims")));
     }
+    attr->outMaxValue = false;
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/src/ops/gelu.cc b/mindspore/lite/src/ops/gelu.cc
new file mode 100644
index 0000000000..234f8e7454
--- /dev/null
+++ b/mindspore/lite/src/ops/gelu.cc
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/gelu.h"
+#include <memory>
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "src/tensor.h"
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
+
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+int GeLU::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_GeLU;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_GeLU) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::GeLUT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#endif
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/ops/gelu.h b/mindspore/lite/src/ops/gelu.h
new file mode 100644
index 0000000000..d2fc914a75
--- /dev/null
+++ b/mindspore/lite/src/ops/gelu.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_MINDSPORE_LITE_C_OPS_GELU_H_
+#define LITE_MINDSPORE_LITE_C_OPS_GELU_H_
+
+#include <vector>
+#include <set>
+#include <cmath>
+#include "src/ops/primitive_c.h"
+
+namespace mindspore {
+namespace lite {
+class GeLU : public PrimitiveC {
+ public:
+  GeLU() = default;
+  ~GeLU() = default;
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(GeLU, PrimitiveC);
+  explicit GeLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+#endif
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_C_OPS_GELU_H_
diff --git a/mindspore/lite/src/ops/layer_norm.cc b/mindspore/lite/src/ops/layer_norm.cc
index edb9272fec..a5b1c597c2 100644
--- a/mindspore/lite/src/ops/layer_norm.cc
+++ b/mindspore/lite/src/ops/layer_norm.cc
@@ -35,7 +35,42 @@ void LayerNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsLayerNorm(
 void LayerNorm::SetElementwiseAffine(bool elementwiseAffine) {
   this->primitive_->value.AsLayerNorm()->elementwiseAffine = elementwiseAffine;
 }
-
+int LayerNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitive error";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_LayerNorm;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_LayerNorm) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto layer_norm_attr = new (std::nothrow) schema::LayerNormT();
+    if (layer_norm_attr == nullptr) {
+      MS_LOG(ERROR) << "new primitive value.value error";
+      return RET_ERROR;
+    }
+    auto value_attr = prim.GetAttr("epsilon");
+    if (value_attr != nullptr) {
+      layer_norm_attr->epsilon = GetValue<float>(value_attr);
+    } else {
+      layer_norm_attr->epsilon = 1e-7;
+    }
+    value_attr = prim.GetAttr("normalized_shape");
+    if (value_attr != nullptr) {
+      layer_norm_attr->normalizedShape = CastToInt(value_attr);
+    }
+    if (inputs.size() == 3) {
+      layer_norm_attr->elementwiseAffine = true;
+    }
+    this->primitive_->value.value = layer_norm_attr;
+  }
+  return RET_OK;
+}
 #else
 int LayerNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
@@ -100,13 +135,12 @@ int LayerNorm::InferShape(std::vector<Tensor *> inputs_, std::vector<Te
-    // instance norm -> layernorm
+    // instance norm -> layernorm only for nchw
     if (input->format() == schema::Format_NCHW) {
       normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 2, input_shape.end());
       elementwise_mode_ = 1;
     } else {
-      MS_LOG(INFO) << "normalized_shape attr invalid";
-      return RET_PARAM_INVALID;
+      normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 1, input_shape.end());
     }
   }
   size_t first_index = input_shape.size() - normlized_shape_.size();
diff --git a/mindspore/lite/src/ops/layer_norm.h b/mindspore/lite/src/ops/layer_norm.h
index 40856264f7..4d83c1863e 100644
--- a/mindspore/lite/src/ops/layer_norm.h
+++ b/mindspore/lite/src/ops/layer_norm.h
@@ -35,6 +35,7 @@ class LayerNorm : public PrimitiveC {
   void SetNormalizedShape(const std::vector<int> &normalizedShape);
   void SetEpsilon(float epsilon);
   void SetElementwiseAffine(bool elementwiseAffine);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc
index 056146919e..ab9ef80043 100644
--- a/mindspore/lite/src/ops/primitive_c.cc
+++ b/mindspore/lite/src/ops/primitive_c.cc
@@ -160,6 +160,7 @@
 #include "src/ops/merge.h"
 #include "src/ops/switch.h"
 #include "src/ops/partial.h"
+#include "src/ops/gelu.h"
 
 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"
@@ -330,9 +331,28 @@ void PrimitiveC::PopulaterOutputQuantParam(const Primitive &prim, bool narrowRan
 
 void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
   auto narrow_range = prim.GetAttr("narrow_range");
-  bool narrowRangeQuantParam = narrow_range != nullptr && GetValue<bool>(narrow_range);
+  bool narrowRangeQuantParam = false;
+  if (narrow_range != nullptr) {
+    if (utils::isa<tensor::TensorPtr>(narrow_range)) {
+      auto narrow_range_tensor = narrow_range->cast<tensor::TensorPtr>();
+      narrowRangeQuantParam = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
+    } else if (utils::isa<ImmTraits<bool>::type>(narrow_range)) {
+      narrowRangeQuantParam = GetValue<bool>(narrow_range);
+    } else {
+      MS_LOG(ERROR) << "valueptr is invalid.";
+      return;
+    }
+  }
   auto num_bits = prim.GetAttr("num_bits");
-  int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int32_t>(num_bits) : 8;
+  int32_t numbitsRangeQuantParam = 8;
+  if (num_bits != nullptr) {
+    if (utils::isa<tensor::TensorPtr>(num_bits)) {
+      auto num_bits_tensor = num_bits->cast<tensor::TensorPtr>();
+      numbitsRangeQuantParam = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
+    } else if (utils::isa<ImmTraits<int64_t>::type>(num_bits)) {
+      numbitsRangeQuantParam = GetValue<int64_t>(num_bits);
+    }
+  }
   PopulaterInputQuantParam(prim, inputs, narrowRangeQuantParam, numbitsRangeQuantParam);
   PopulaterOutputQuantParam(prim, narrowRangeQuantParam, numbitsRangeQuantParam);
 }
@@ -511,7 +531,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC(prim, inputs, quantType);
   } else if (op_type == "make_tuple") {
     return NewPrimitiveC<MakeTuple>(prim, inputs, quantType);
-  } else if (op_type == "MatMul") {
+  } else if (op_type == "MatMul" || op_type == "BatchMatMul") {
     return NewPrimitiveC<MatMul>(prim, inputs, quantType);
   } else if (op_type == "Mul") {
     return NewPrimitiveC<Mul>(prim, inputs, quantType);
@@ -601,7 +621,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC(prim, inputs, quantType);
   } else if (op_type == "Mod") {
     return NewPrimitiveC<Mod>(prim, inputs, quantType);
-  } else if (op_type == "ArgMinWithValue") {
+  } else if (op_type == "ArgMin" || op_type == "ArgMinWithValue") {
    return NewPrimitiveC<ArgMin>(prim, inputs, quantType);
   } else if (op_type == "Range") {
     return NewPrimitiveC<Range>(prim, inputs, quantType);
@@ -621,6 +641,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC(prim, inputs, quantType);
   } else if (op_type == "Merge") {
     return NewPrimitiveC<Merge>(prim, inputs, quantType);
+  } else if (op_type == "LayerNorm") {
+    return NewPrimitiveC<LayerNorm>(prim, inputs, quantType);
+  } else if (op_type == "ArgMax" || op_type == "ArgMaxWithValue") {
+    return NewPrimitiveC<ArgMax>(prim, inputs, quantType);
"Gelu") { + return NewPrimitiveC(prim, inputs, quantType); #ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { @@ -965,6 +991,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) Partial(primitive); case schema::PrimitiveType_Assert: return new (std::nothrow) AssertOP(primitive); + case schema::PrimitiveType_GeLU: + return new (std::nothrow) GeLU(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 4dd031127a..cec95bf66a 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -199,6 +199,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc + ${LITE_DIR}/tools/optimizer/graph/mindir_inputs_adjust_pass.cc ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc ${LITE_DIR}/tools/optimizer/graph/while_pass.cc ) diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index f295bde218..3e3fe2dc2c 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -737,7 +737,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrallTensors.emplace_back(msTensor); if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || - IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) { + IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) || + IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) { break; } #endif diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 8756a9eabe..6992601f4f 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -60,6 +60,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/mindir_adjust_pass.cc ../optimizer/graph/onnx_inputs_adjust_pass.cc ../optimizer/graph/while_pass.cc + ../optimizer/graph/mindir_inputs_adjust_pass.cc ) add_subdirectory(../anf_importer anf_importer) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 84baa2138d..094c54b59d 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -30,6 +30,7 @@ #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" #include "tools/optimizer/fusion/conv_conv_fusion.h" #include "tools/optimizer/graph/mindir_adjust_pass.h" +#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" #include "tools/optimizer/graph/identity_remove_pass.h" #include "tools/optimizer/graph/weight_format_hardcode_pass.h" #include "tools/optimizer/graph/weight_format_transform_pass.h" @@ -77,6 +78,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } + auto mindir_inputs_adjust_pass = std::make_shared(); + if (!mindir_inputs_adjust_pass->Run(old_graph)) { + MS_LOG(ERROR) << "mindir inputs adjust failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } } // onnx pre 
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.cc
new file mode 100644
index 0000000000..490cec179f
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.cc
@@ -0,0 +1,236 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
+#include <memory>
+#include <vector>
+#include "src/common/log_adapter.h"
+#include "src/ops/primitive_c.h"
+#include "src/tensor.h"
+
+using mindspore::lite::PrimitiveC;
+namespace mindspore {
+namespace opt {
+namespace {
+template <typename T>
+void CopyAttrForArgMinMax(T *left, T *right) {
+  MS_ASSERT(left != nullptr && right != nullptr);
+  left->axis = right->axis;
+  left->outMaxValue = right->outMaxValue;
+  left->axisType = right->axisType;
+  left->keepDims = right->keepDims;
+  left->topK = right->topK;
+}
+}  // namespace
+
+bool MindirInputAdjustOpPass::CheckCNodeIsArgMinMax(const CNodePtr &cnode) {
+  MS_ASSERT(cnode != nullptr);
+  auto prim_node = cnode->inputs().at(0);
+  MS_ASSERT(prim_node != nullptr);
+  auto prim_value_node = prim_node->cast<ValueNodePtr>();
+  if (prim_value_node == nullptr) {
+    MS_LOG(DEBUG) << "cnode first input is not valueNode.";
+    return false;
+  }
+  auto value = prim_value_node->value();
+  MS_ASSERT(value != nullptr);
+  auto prim_c = value->cast<PrimitiveCPtr>();
+  if (prim_c == nullptr) {
+    MS_LOG(DEBUG) << "prim is not primitiveC.";
+    return false;
+  }
+  auto prim = prim_c->primitiveT();
+  MS_ASSERT(prim != nullptr);
+  return prim->value.type == schema::PrimitiveType_ArgMax || prim->value.type == schema::PrimitiveType_ArgMin;
+}
+
+int MindirInputAdjustOpPass::AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value) {
+  MS_ASSERT(inputs != nullptr);
+  auto prim_node = inputs->at(0);
+  MS_ASSERT(prim_node != nullptr);
+  auto prim_value_node = prim_node->cast<ValueNodePtr>();
+  if (prim_value_node == nullptr) {
+    MS_LOG(ERROR) << "cnode first input is not valueNode.";
+    return lite::RET_ERROR;
+  }
+  auto prim_value = prim_value_node->value();
+  if (prim_value == nullptr) {
+    MS_LOG(ERROR) << "valueNode value is nullptr.";
+    return lite::RET_ERROR;
+  }
+  auto prim_c = prim_value->cast<PrimitiveCPtr>();
+  if (prim_c == nullptr) {
+    MS_LOG(ERROR) << "value is not primitiveC.";
+    return lite::RET_ERROR;
+  }
+  auto prim = prim_c->primitiveT();
+  MS_ASSERT(prim != nullptr && prim->value.value != nullptr);
+  auto attr = prim->value.value;
+  if (prim->value.type == schema::PrimitiveType_ArgMax) {
+    reinterpret_cast<schema::ArgMaxT *>(attr)->outMaxValue = index_or_value;
+  } else if (prim->value.type == schema::PrimitiveType_ArgMin) {
+    reinterpret_cast<schema::ArgMinT *>(attr)->outMaxValue = index_or_value;
+  }
+  return lite::RET_OK;
+}
+
+int MindirInputAdjustOpPass::CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs) {
+  MS_ASSERT(inputs != nullptr);
+  auto prim_node = inputs->at(0);
+  MS_ASSERT(prim_node != nullptr);
+  auto prim_value_node = prim_node->cast<ValueNodePtr>();
+  if (prim_value_node == nullptr) {
MS_LOG(ERROR) << "cnode first input is not valueNode."; + return lite::RET_ERROR; + } + auto prim_value = prim_value_node->value(); + if (prim_value == nullptr) { + MS_LOG(ERROR) << "valueNode value is nullptr."; + return lite::RET_ERROR; + } + auto prim_c = prim_value->cast(); + if (prim_c == nullptr) { + MS_LOG(ERROR) << "value is not primitiveC."; + return lite::RET_ERROR; + } + auto prim = prim_c->primitiveT(); + MS_ASSERT(prim != nullptr && prim->value.value != nullptr); + auto primitive = std::make_unique(); + if (prim->value.type == schema::PrimitiveType_ArgMax) { + primitive->value.type = schema::PrimitiveType_ArgMax; + auto attr = std::make_unique(); + CopyAttrForArgMinMax(attr.get(), reinterpret_cast(prim->value.value)); + primitive->value.value = attr.release(); + } else { + primitive->value.type = schema::PrimitiveType_ArgMin; + auto attr = std::make_unique(); + CopyAttrForArgMinMax(attr.get(), reinterpret_cast(prim->value.value)); + primitive->value.value = attr.release(); + } + auto primitive_c = PrimitiveC::Create(primitive.release()); + auto value_node = NewValueNode(std::shared_ptr(primitive_c)); + inputs->erase(inputs->begin()); + inputs->insert(inputs->begin(), value_node); + return lite::RET_OK; +} + +int MindirInputAdjustOpPass::BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, + const CNodePtr &argmin_max) { + MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr); + auto inputs = argmin_max->inputs(); + if (CopyPrimitiveCForArgMinMax(&inputs) != lite::RET_OK) { + MS_LOG(ERROR) << "copy argmin or argmax failed."; + return lite::RET_ERROR; + } + if (AdjustArgMinMaxInputs(&inputs, false) != lite::RET_OK) { + MS_LOG(ERROR) << "adjust argmin or argmax attr failed."; + return lite::RET_ERROR; + } + auto new_cnode = graph->NewCNode(inputs); + new_cnode->set_fullname_with_scope(argmin_max->fullname_with_scope() + "_index"); + auto type_ptr = TypeIdToType(kTypeUnknown); + std::vector shape_vector; + new_cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); + auto manager = graph->manager(); + MS_ASSERT(manager != nullptr); + manager->Replace(tuple_get_item, new_cnode); + return lite::RET_OK; +} + +int MindirInputAdjustOpPass::AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, + const CNodePtr &argmin_max) { + MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr); + auto inputs = argmin_max->inputs(); + if (AdjustArgMinMaxInputs(&inputs, true) != lite::RET_OK) { + MS_LOG(ERROR) << "adjust argmin or argmax attr failed."; + return lite::RET_ERROR; + } + auto type_ptr = TypeIdToType(kTypeUnknown); + std::vector shape_vector; + auto abtract_tensor = std::make_shared(type_ptr, shape_vector); + argmin_max->set_abstract(abtract_tensor); + auto manager = graph->manager(); + MS_ASSERT(manager != nullptr); + manager->Replace(tuple_get_item, argmin_max); + return lite::RET_OK; +} + +int MindirInputAdjustOpPass::AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode) { + MS_ASSERT(graph != nullptr && cnode != nullptr); + auto inputs = cnode->inputs(); + if (inputs.size() != 3) { + MS_LOG(ERROR) << "tupleGetItem inputs size is invalid: " << inputs.size(); + return lite::RET_ERROR; + } + auto argmin_max = inputs.at(1); + MS_ASSERT(argmin_max != nullptr); + auto argmin_max_cnode = argmin_max->cast(); + if (argmin_max_cnode == nullptr) { + MS_LOG(ERROR) << "the second input is not a cnode."; + return lite::RET_ERROR; + } + if 
+  if (!CheckCNodeIsArgMinMax(argmin_max_cnode)) {
+    MS_LOG(DEBUG) << "tuple_get_item first input is not argmin or argmax.";
+    return lite::RET_OK;
+  }
+  auto index_vnode = inputs.at(2);
+  auto value_node = index_vnode->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
+    return lite::RET_ERROR;
+  }
+  int index = lite::CastToInt(value_node->value()).front();
+  if (index == 0) {
+    if (BuildCNodeForArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) {
+      MS_LOG(ERROR) << "build new cnode failed.";
+      return lite::RET_ERROR;
+    }
+  } else if (index == 1) {
+    if (AdjustArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) {
+      MS_LOG(ERROR) << "adjust argmin_max failed.";
+      return lite::RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
+bool MindirInputAdjustOpPass::Run(const FuncGraphPtr &graph) {
+  MS_ASSERT(graph != nullptr);
+  auto manager = Manage(graph, true);
+  if (manager == nullptr) {
+    MS_LOG(ERROR) << "manager is nullptr.";
+    return false;
+  }
+  auto node_list = TopoSort(graph->get_return());
+  int status = lite::RET_OK;
+  for (auto &node : node_list) {
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) {
+      MS_LOG(DEBUG) << "node is not cnode.";
+      continue;
+    }
+    auto type = opt::GetCNodeType(node);
+    if (type == schema::PrimitiveType_TupleGetItem) {
+      status = AdjustTupleGetItemWithArgMinMax(graph, cnode);
+    }
+    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
+      MS_LOG(ERROR) << "adjust input pass is failed.";
+      return false;
+    }
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.h
new file mode 100644
index 0000000000..7040f81253
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_
+
+#include <string>
+#include <vector>
+#include "backend/optimizer/common/pass.h"
+#include "tools/converter/converter_flags.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "src/param_value_lite.h"
+
+namespace mindspore::opt {
+class MindirInputAdjustOpPass : public Pass {
+ public:
+  MindirInputAdjustOpPass() : Pass("mindir_inputs_adjust_pass") {}
+  ~MindirInputAdjustOpPass() override = default;
+  bool CheckCNodeIsArgMinMax(const CNodePtr &cnode);
+  int AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value);
+  int CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs);
+  int BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max);
+  int AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max);
+  int AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode);
+  bool Run(const FuncGraphPtr &graph) override;
+};
+}  // namespace mindspore::opt
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_
diff --git a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h
index b8b11097e0..447371cf9a 100644
--- a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h
+++ b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h
@@ -43,8 +43,6 @@ class OnnxInputAdjustOpPass : public Pass {
   STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
   STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
   bool Run(const FuncGraphPtr &func_graph) override;
-
- private:
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_ONNX_INPUTS_ADJUST_PASS_H_
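Reviewer note (commentary, not part of the patch): the TupleGetItem rewrite in mindir_inputs_adjust_pass.cc bridges a representation gap. ArgMaxWithValue/ArgMinWithValue in MindIR return an (index, value) tuple, while the Lite ArgMax/ArgMin kernels emit a single output selected by the outMaxValue flag. TupleGetItem(op, 0) therefore becomes a cloned primitive with outMaxValue = false (the "_index" node built by BuildCNodeForArgMinMax), and TupleGetItem(op, 1) becomes the original node adjusted to outMaxValue = true; CopyPrimitiveCForArgMinMax clones the primitive so that, when a graph consumes both tuple outputs, the two resulting nodes do not share one attribute struct. The standalone C++ sketch below restates that decision table; ArgMinMaxAttr and AdjustForTupleGetItem are illustrative stand-ins, not types or functions from the patch.

#include <cassert>
#include <cstdio>

// Illustrative stand-in for schema::ArgMaxT / schema::ArgMinT.
struct ArgMinMaxAttr {
  int axis = -1;
  bool outMaxValue = false;  // false -> node outputs indices, true -> max/min values
  bool keepDims = false;
  int topK = 1;
};

// Mirrors the pass's rule: TupleGetItem index 0 selects the index output,
// index 1 selects the value output of ArgMaxWithValue / ArgMinWithValue.
ArgMinMaxAttr AdjustForTupleGetItem(ArgMinMaxAttr attr, int tuple_index) {
  assert(tuple_index == 0 || tuple_index == 1);
  attr.outMaxValue = (tuple_index == 1);
  return attr;
}

int main() {
  ArgMinMaxAttr base;
  base.axis = 1;
  const auto index_node = AdjustForTupleGetItem(base, 0);  // like the "_index" clone
  const auto value_node = AdjustForTupleGetItem(base, 1);  // like the adjusted original
  std::printf("index output: outMaxValue=%d\n", index_node.outMaxValue);
  std::printf("value output: outMaxValue=%d\n", value_node.outMaxValue);
  return 0;
}

For reference, the GeLU primitive registered by this patch carries no attributes; the activation itself is conventionally GeLU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), which the runtime kernel (outside this patch) is expected to implement.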