!10324 [lite]add argmax、layernorm、batchmatmul for minidr
	
		
	
				
					
				
			From: @xu_anyue Reviewed-by: Signed-off-by:pull/10324/MERGE
						commit
						00455b9559
					
				| @ -0,0 +1,53 @@ | ||||
| /**
 | ||||
|  * Copyright 2019-2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| #include "src/ops/gelu.h" | ||||
| #include <memory> | ||||
| #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | ||||
| #include "src/tensor.h" | ||||
| #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | ||||
| #endif | ||||
| 
 | ||||
| namespace mindspore { | ||||
| namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | ||||
| int GeLU::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | ||||
|   if (this->primitive_ == nullptr) { | ||||
|     this->primitive_ = new (std::nothrow) schema::PrimitiveT; | ||||
|     if (this->primitive_ == nullptr) { | ||||
|       MS_LOG(ERROR) << "new primitiveT failed"; | ||||
|       return RET_ERROR; | ||||
|     } | ||||
|     this->primitive_->value.type = schema::PrimitiveType_GeLU; | ||||
|   } | ||||
|   if (this->primitive_->value.type != schema::PrimitiveType_GeLU) { | ||||
|     MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | ||||
|     return RET_ERROR; | ||||
|   } | ||||
|   if (this->primitive_->value.value == nullptr) { | ||||
|     this->primitive_->value.value = new (std::nothrow) schema::GeLUT(); | ||||
|     if (this->primitive_->value.value == nullptr) { | ||||
|       MS_LOG(ERROR) << "new primitiveT value failed"; | ||||
|       return RET_ERROR; | ||||
|     } | ||||
|   } | ||||
|   return RET_OK; | ||||
| } | ||||
| #endif | ||||
| }  // namespace lite
 | ||||
| }  // namespace mindspore
 | ||||
| @ -0,0 +1,40 @@ | ||||
| /**
 | ||||
|  * Copyright 2019-2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_GELU_H_ | ||||
| #define LITE_MINDSPORE_LITE_C_OPS_GELU_H_ | ||||
| 
 | ||||
| #include <vector> | ||||
| #include <set> | ||||
| #include <cmath> | ||||
| #include "src/ops/primitive_c.h" | ||||
| 
 | ||||
| namespace mindspore { | ||||
| namespace lite { | ||||
| class GeLU : public PrimitiveC { | ||||
|  public: | ||||
|   GeLU() = default; | ||||
|   ~GeLU() = default; | ||||
| #ifdef PRIMITIVE_WRITEABLE | ||||
|   MS_DECLARE_PARENT(GeLU, PrimitiveC); | ||||
|   explicit GeLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
|   int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | ||||
| #endif | ||||
| }; | ||||
| }  // namespace lite
 | ||||
| }  // namespace mindspore
 | ||||
| 
 | ||||
| #endif  // LITE_MINDSPORE_LITE_C_OPS_GELU_H_
 | ||||
| @ -0,0 +1,236 @@ | ||||
| /**
 | ||||
|  * Copyright 2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" | ||||
| #include <vector> | ||||
| #include <memory> | ||||
| #include "src/common/log_adapter.h" | ||||
| #include "src/ops/primitive_c.h" | ||||
| #include "src/tensor.h" | ||||
| 
 | ||||
| using mindspore::lite::PrimitiveC; | ||||
| namespace mindspore { | ||||
| namespace opt { | ||||
| namespace { | ||||
| template <typename T> | ||||
| void CopyAttrForArgMinMax(T *left, T *right) { | ||||
|   MS_ASSERT(left != null && right != nullptr); | ||||
|   left->axis = right->axis; | ||||
|   left->outMaxValue = right->outMaxValue; | ||||
|   left->axisType = right->axisType; | ||||
|   left->keepDims = right->keepDims; | ||||
|   left->topK = right->topK; | ||||
| } | ||||
| }  // namespace
 | ||||
| 
 | ||||
| bool MindirInputAdjustOpPass::CheckCNodeIsArgMinMax(const CNodePtr &cnode) { | ||||
|   MS_ASSERT(cnode != nullptr); | ||||
|   auto prim_node = cnode->inputs().at(0); | ||||
|   MS_ASSERT(prim_node != nullptr); | ||||
|   auto prim_value_node = prim_node->cast<ValueNodePtr>(); | ||||
|   if (prim_value_node == nullptr) { | ||||
|     MS_LOG(DEBUG) << "cnode first input is not valueNode."; | ||||
|     return false; | ||||
|   } | ||||
|   auto value = prim_value_node->value(); | ||||
|   MS_ASSERT(value != nullptr); | ||||
|   auto prim_c = value->cast<PrimitiveCPtr>(); | ||||
|   if (prim_c == nullptr) { | ||||
|     MS_LOG(DEBUG) << "prim is not primitiveC."; | ||||
|     return false; | ||||
|   } | ||||
|   auto prim = prim_c->primitiveT(); | ||||
|   MS_ASSERT(prim != nullptr); | ||||
|   return prim->value.type == schema::PrimitiveType_ArgMax || prim->value.type == schema::PrimitiveType_ArgMin; | ||||
| } | ||||
| 
 | ||||
| int MindirInputAdjustOpPass::AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value) { | ||||
|   MS_ASSERT(inputs != nullptr); | ||||
|   auto prim_node = inputs->at(0); | ||||
|   MS_ASSERT(prim_node != nullptr); | ||||
|   auto prim_value_node = prim_node->cast<ValueNodePtr>(); | ||||
|   if (prim_value_node == nullptr) { | ||||
|     MS_LOG(ERROR) << "cnode first input is not valueNode."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto prim_value = prim_value_node->value(); | ||||
|   if (prim_value == nullptr) { | ||||
|     MS_LOG(ERROR) << "valueNode value is nullptr."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto prim_c = prim_value->cast<PrimitiveCPtr>(); | ||||
|   if (prim_c == nullptr) { | ||||
|     MS_LOG(ERROR) << "value is not primitiveC."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto prim = prim_c->primitiveT(); | ||||
|   MS_ASSERT(prim != nullptr && prim->value.value != nullptr); | ||||
|   auto attr = prim->value.value; | ||||
|   if (prim->value.type == schema::PrimitiveType_ArgMax) { | ||||
|     reinterpret_cast<schema::ArgMaxT *>(attr)->outMaxValue = index_or_value; | ||||
|   } else if (prim->value.type == schema::PrimitiveType_ArgMin) { | ||||
|     reinterpret_cast<schema::ArgMinT *>(attr)->outMaxValue = index_or_value; | ||||
|   } | ||||
|   return lite::RET_OK; | ||||
| } | ||||
| 
 | ||||
| int MindirInputAdjustOpPass::CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs) { | ||||
|   MS_ASSERT(inputs != nullptr); | ||||
|   auto prim_node = inputs->at(0); | ||||
|   MS_ASSERT(prim_node != nullptr); | ||||
|   auto prim_value_node = prim_node->cast<ValueNodePtr>(); | ||||
|   if (prim_value_node == nullptr) { | ||||
|     MS_LOG(ERROR) << "cnode first input is not valueNode."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto prim_value = prim_value_node->value(); | ||||
|   if (prim_value == nullptr) { | ||||
|     MS_LOG(ERROR) << "valueNode value is nullptr."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto prim_c = prim_value->cast<PrimitiveCPtr>(); | ||||
|   if (prim_c == nullptr) { | ||||
|     MS_LOG(ERROR) << "value is not primitiveC."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto prim = prim_c->primitiveT(); | ||||
|   MS_ASSERT(prim != nullptr && prim->value.value != nullptr); | ||||
|   auto primitive = std::make_unique<schema::PrimitiveT>(); | ||||
|   if (prim->value.type == schema::PrimitiveType_ArgMax) { | ||||
|     primitive->value.type = schema::PrimitiveType_ArgMax; | ||||
|     auto attr = std::make_unique<schema::ArgMaxT>(); | ||||
|     CopyAttrForArgMinMax<schema::ArgMaxT>(attr.get(), reinterpret_cast<schema::ArgMaxT *>(prim->value.value)); | ||||
|     primitive->value.value = attr.release(); | ||||
|   } else { | ||||
|     primitive->value.type = schema::PrimitiveType_ArgMin; | ||||
|     auto attr = std::make_unique<schema::ArgMinT>(); | ||||
|     CopyAttrForArgMinMax<schema::ArgMinT>(attr.get(), reinterpret_cast<schema::ArgMinT *>(prim->value.value)); | ||||
|     primitive->value.value = attr.release(); | ||||
|   } | ||||
|   auto primitive_c = PrimitiveC::Create(primitive.release()); | ||||
|   auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitive_c)); | ||||
|   inputs->erase(inputs->begin()); | ||||
|   inputs->insert(inputs->begin(), value_node); | ||||
|   return lite::RET_OK; | ||||
| } | ||||
| 
 | ||||
| int MindirInputAdjustOpPass::BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, | ||||
|                                                     const CNodePtr &argmin_max) { | ||||
|   MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr); | ||||
|   auto inputs = argmin_max->inputs(); | ||||
|   if (CopyPrimitiveCForArgMinMax(&inputs) != lite::RET_OK) { | ||||
|     MS_LOG(ERROR) << "copy argmin or argmax failed."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   if (AdjustArgMinMaxInputs(&inputs, false) != lite::RET_OK) { | ||||
|     MS_LOG(ERROR) << "adjust argmin or argmax attr failed."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto new_cnode = graph->NewCNode(inputs); | ||||
|   new_cnode->set_fullname_with_scope(argmin_max->fullname_with_scope() + "_index"); | ||||
|   auto type_ptr = TypeIdToType(kTypeUnknown); | ||||
|   std::vector<int64_t> shape_vector; | ||||
|   new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | ||||
|   auto manager = graph->manager(); | ||||
|   MS_ASSERT(manager != nullptr); | ||||
|   manager->Replace(tuple_get_item, new_cnode); | ||||
|   return lite::RET_OK; | ||||
| } | ||||
| 
 | ||||
| int MindirInputAdjustOpPass::AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, | ||||
|                                              const CNodePtr &argmin_max) { | ||||
|   MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr); | ||||
|   auto inputs = argmin_max->inputs(); | ||||
|   if (AdjustArgMinMaxInputs(&inputs, true) != lite::RET_OK) { | ||||
|     MS_LOG(ERROR) << "adjust argmin or argmax attr failed."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto type_ptr = TypeIdToType(kTypeUnknown); | ||||
|   std::vector<int64_t> shape_vector; | ||||
|   auto abtract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | ||||
|   argmin_max->set_abstract(abtract_tensor); | ||||
|   auto manager = graph->manager(); | ||||
|   MS_ASSERT(manager != nullptr); | ||||
|   manager->Replace(tuple_get_item, argmin_max); | ||||
|   return lite::RET_OK; | ||||
| } | ||||
| 
 | ||||
| int MindirInputAdjustOpPass::AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode) { | ||||
|   MS_ASSERT(graph != nullptr && cnode != nullptr); | ||||
|   auto inputs = cnode->inputs(); | ||||
|   if (inputs.size() != 3) { | ||||
|     MS_LOG(ERROR) << "tupleGetItem inputs size is invalid: " << inputs.size(); | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   auto argmin_max = inputs.at(1); | ||||
|   MS_ASSERT(argmin_max != nullptr); | ||||
|   auto argmin_max_cnode = argmin_max->cast<CNodePtr>(); | ||||
|   if (argmin_max_cnode == nullptr) { | ||||
|     MS_LOG(ERROR) << "the second input is not a cnode."; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   if (!CheckCNodeIsArgMinMax(argmin_max_cnode)) { | ||||
|     MS_LOG(DEBUG) << "tuple_get_item first input is not argmin and argmax."; | ||||
|     return lite::RET_OK; | ||||
|   } | ||||
|   auto index_vnode = inputs.at(2); | ||||
|   auto value_node = index_vnode->cast<ValueNodePtr>(); | ||||
|   if (value_node == nullptr) { | ||||
|     MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; | ||||
|     return lite::RET_ERROR; | ||||
|   } | ||||
|   int index = lite::CastToInt(value_node->value()).front(); | ||||
|   if (index == 0) { | ||||
|     if (BuildCNodeForArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) { | ||||
|       MS_LOG(ERROR) << "build new cnode failed."; | ||||
|       return lite::RET_ERROR; | ||||
|     } | ||||
|   } else if (index == 1) { | ||||
|     if (AdjustArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) { | ||||
|       MS_LOG(ERROR) << "adjust argmin_max failed."; | ||||
|       return lite::RET_ERROR; | ||||
|     } | ||||
|   } | ||||
|   return lite::RET_OK; | ||||
| } | ||||
| 
 | ||||
| bool MindirInputAdjustOpPass::Run(const FuncGraphPtr &graph) { | ||||
|   MS_ASSERT(graph != nullptr); | ||||
|   auto manager = Manage(graph, true); | ||||
|   if (manager == nullptr) { | ||||
|     MS_LOG(ERROR) << "manager is nullptr."; | ||||
|     return lite::RET_NULL_PTR; | ||||
|   } | ||||
|   auto node_list = TopoSort(graph->get_return()); | ||||
|   int status = lite::RET_OK; | ||||
|   for (auto &node : node_list) { | ||||
|     auto cnode = node->cast<CNodePtr>(); | ||||
|     if (cnode == nullptr) { | ||||
|       MS_LOG(DEBUG) << "node is not cnode."; | ||||
|       continue; | ||||
|     } | ||||
|     auto type = opt::GetCNodeType(node); | ||||
|     if (type == schema::PrimitiveType_TupleGetItem) { | ||||
|       status = AdjustTupleGetItemWithArgMinMax(graph, cnode); | ||||
|     } | ||||
|     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | ||||
|       MS_LOG(ERROR) << "adjust input pass is failed."; | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
| }  // namespace opt
 | ||||
| }  // namespace mindspore
 | ||||
| @ -0,0 +1,41 @@ | ||||
| /**
 | ||||
|  * Copyright 2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_ | ||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_ | ||||
| 
 | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include "backend/optimizer/common/pass.h" | ||||
| #include "tools/converter/converter_flags.h" | ||||
| #include "tools/optimizer/common/gllo_utils.h" | ||||
| #include "src/param_value_lite.h" | ||||
| 
 | ||||
| namespace mindspore::opt { | ||||
| class MindirInputAdjustOpPass : public Pass { | ||||
|  public: | ||||
|   MindirInputAdjustOpPass() : Pass("mindir_inputs_adjust_pass") {} | ||||
|   ~MindirInputAdjustOpPass() override = default; | ||||
|   bool CheckCNodeIsArgMinMax(const CNodePtr &cnode); | ||||
|   int AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value); | ||||
|   int CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs); | ||||
|   int BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max); | ||||
|   int AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max); | ||||
|   int AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode); | ||||
|   bool Run(const FuncGraphPtr &graph) override; | ||||
| }; | ||||
| }  // namespace mindspore::opt
 | ||||
| #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_
 | ||||
					Loading…
					
					
				
		Reference in new issue