parent
							
								
									0db846978e
								
							
						
					
					
						commit
						46b8ab3c40
					
				| @ -1,65 +0,0 @@ | ||||
| /**
 | ||||
|  * Copyright 2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| #include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h" | ||||
| #include <memory> | ||||
| #include "backend/optimizer/ascend/ascend_helper.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "utils/utils.h" | ||||
| #include "base/core_ops.h" | ||||
| 
 | ||||
| namespace mindspore { | ||||
| namespace opt { | ||||
| const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const { | ||||
|   VarPtr Xs = std::make_shared<SeqVar>(); | ||||
|   return VectorRef({prim::kPrimExtractImagePatches, Xs}); | ||||
| } | ||||
| 
 | ||||
| const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
|                                                                 const EquivPtr &equiv) const { | ||||
|   MS_EXCEPTION_IF_NULL(func_graph); | ||||
|   MS_EXCEPTION_IF_NULL(equiv); | ||||
|   auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2); | ||||
|   MS_EXCEPTION_IF_NULL(extract); | ||||
|   auto in_node = extract->input(1); | ||||
|   MS_EXCEPTION_IF_NULL(in_node); | ||||
|   auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract); | ||||
|   auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node); | ||||
|   MS_EXCEPTION_IF_NULL(extract_kernel_build_info); | ||||
|   MS_EXCEPTION_IF_NULL(in_node_kernel_build_info); | ||||
|   std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), | ||||
|                                             in_node}; | ||||
|   auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | ||||
|   reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0}); | ||||
|   reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0}); | ||||
|   reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)}); | ||||
|   reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)}); | ||||
|   reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type()); | ||||
|   reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type()); | ||||
|   reshape_builder->SetProcessor(in_node_kernel_build_info->processor()); | ||||
| 
 | ||||
|   auto reshape = func_graph->NewCNode(reshape_inputs); | ||||
|   reshape->set_scope(in_node->scope()); | ||||
|   auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0); | ||||
|   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)}, | ||||
|                                       {{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get()); | ||||
|   AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get()); | ||||
|   AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape); | ||||
|   AnfAlgo::SetNodeInput(extract, reshape, 0); | ||||
|   return extract; | ||||
| } | ||||
| }  // namespace opt
 | ||||
| }  // namespace mindspore
 | ||||
| @ -1,41 +0,0 @@ | ||||
| /**
 | ||||
|  * Copyright 2020 Huawei Technologies Co., Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H | ||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H | ||||
| #include <vector> | ||||
| #include <string> | ||||
| #include <utility> | ||||
| #include <memory> | ||||
| #include "ir/anf.h" | ||||
| #include "backend/optimizer/common/pattern_engine.h" | ||||
| #include "backend/optimizer/common/helper.h" | ||||
| #include "backend/optimizer/common/optimizer.h" | ||||
| 
 | ||||
| namespace mindspore { | ||||
| namespace opt { | ||||
| class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass { | ||||
|  public: | ||||
|   explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true) | ||||
|       : PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {} | ||||
|   ~InsertReshapeForExtractImagePatchesOp() override = default; | ||||
|   const BaseRef DefinePattern() const override; | ||||
|   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| }; | ||||
| }  // namespace opt
 | ||||
| }  // namespace mindspore
 | ||||
| 
 | ||||
| #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
 | ||||
					Loading…
					
					
				
		Reference in new issue