parent
7a3b6667d7
commit
5b76e8f3d7
@ -0,0 +1,151 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/gpu/insert_format_transform_op.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
std::vector<int> TransposeAxis(const std::string &src_format, const std::string &dst_format) {
|
||||
if ((src_format == kOpFormat_NCHW) && (dst_format == kOpFormat_NHWC)) {
|
||||
return {0, 2, 3, 1};
|
||||
} else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) {
|
||||
return {0, 3, 1, 2};
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invaild format transform, from " << src_format << " to " << dst_format;
|
||||
}
|
||||
}
|
||||
|
||||
void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
|
||||
auto output_type = AnfAlgo::GetOutputInferDataType(node, 0);
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({input_format});
|
||||
builder.SetInputsDeviceType({input_type});
|
||||
builder.SetOutputsFormat({output_format});
|
||||
builder.SetOutputsDeviceType({output_type});
|
||||
builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
|
||||
builder.SetProcessor(kernel::Processor::CUDA);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
|
||||
}
|
||||
|
||||
// Insert transpose op between node and used_node whose position is used_node_index.
|
||||
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
|
||||
int used_node_index, const std::vector<int> &transpose_perm) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// 1.Create a transpose node.
|
||||
auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
|
||||
MS_EXCEPTION_IF_NULL(transpose_prim);
|
||||
// 2.Set the input of transpose.
|
||||
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
|
||||
auto transpose_op = graph->NewCNode(transpose_input);
|
||||
// 3.Set the output info of transpose.
|
||||
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
|
||||
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
|
||||
// 4.Set the input of used_node.
|
||||
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
||||
<< ", index: " << used_node_index;
|
||||
AnfAlgo::SetNodeInput(utils::cast<CNodePtr>(used_node), transpose_op, used_node_index);
|
||||
// 5. Update the manager info of transpose op.
|
||||
FuncGraphManagerPtr manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Clear();
|
||||
manager->AddFuncGraph(graph);
|
||||
return transpose_op;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr InsertFormatTransformOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto iter = device::gpu::kKernelFormatPositionMap.find(AnfAlgo::GetCNodeName(node));
|
||||
if (iter == device::gpu::kKernelFormatPositionMap.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto origin_data_format = AnfAlgo::GetOriginDataFormat(node);
|
||||
if (origin_data_format == kOpFormat_DEFAULT) {
|
||||
origin_data_format = kOpFormat_NCHW;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Process node: " << node->fullname_with_scope();
|
||||
// Insert input transpose from origin_data_format to input_format.
|
||||
auto inputs_format = AnfAlgo::GetAllInputFormats(node);
|
||||
for (size_t i = 0; i < inputs_format.size(); i++) {
|
||||
if ((inputs_format[i] != kOpFormat_DEFAULT) && (inputs_format[i] != origin_data_format)) {
|
||||
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto transpose_perm = TransposeAxis(origin_data_format, inputs_format[i]);
|
||||
auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm);
|
||||
SetTransposeOpBuildInfo(kOpFormat_DEFAULT, inputs_format[i], transpose_op);
|
||||
}
|
||||
}
|
||||
|
||||
// Insert output transpose from output_format to origin_data_format.
|
||||
auto outputs_format = AnfAlgo::GetAllOutputFormats(node);
|
||||
for (size_t i = 0; i < outputs_format.size(); i++) {
|
||||
if ((outputs_format[i] != kOpFormat_DEFAULT) && (outputs_format[i] != origin_data_format)) {
|
||||
// Find all nodes connected with node output, and change their inputs to transpose.
|
||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
|
||||
for (size_t j = 0; j < used_node_list->size(); j++) {
|
||||
auto used_node = used_node_list->at(j).first;
|
||||
auto used_node_index = used_node_list->at(j).second - 1;
|
||||
auto transpose_perm = TransposeAxis(outputs_format[i], origin_data_format);
|
||||
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
|
||||
MS_LOG(DEBUG) << "The used node of [" << node->fullname_with_scope() << "] is tuple item.";
|
||||
// The tuple item need get next used nodes again.
|
||||
ProcessForTupleItem(graph, used_node, used_node_index, transpose_perm, outputs_format[i]);
|
||||
continue;
|
||||
}
|
||||
auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
|
||||
SetTransposeOpBuildInfo(outputs_format[i], kOpFormat_DEFAULT, transpose_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
void InsertFormatTransformOp::ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
|
||||
const std::vector<int> &transpose_perm,
|
||||
const std::string &transpose_format) const {
|
||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
|
||||
for (size_t i = 0; i < used_node_list->size(); i++) {
|
||||
auto used_node = used_node_list->at(i).first;
|
||||
auto used_node_index = used_node_list->at(i).second - 1;
|
||||
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
|
||||
MS_LOG(EXCEPTION) << "The used node of tuple item can't be tuple item.";
|
||||
}
|
||||
auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
|
||||
SetTransposeOpBuildInfo(transpose_format, kOpFormat_DEFAULT, transpose_op);
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,39 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class InsertFormatTransformOp : public PatternProcessPass {
|
||||
public:
|
||||
explicit InsertFormatTransformOp(bool multigraph = true)
|
||||
: PatternProcessPass("insert_format_transform_op", multigraph) {}
|
||||
~InsertFormatTransformOp() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
void ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
|
||||
const std::vector<int> &transpose_perm, const std::string &transpose_format) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_
|
@ -0,0 +1,65 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef RemoveFormatTransformPair::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
MS_EXCEPTION_IF_NULL(X);
|
||||
VectorRef transpose1 = VectorRef({prim::kPrimTranspose, X});
|
||||
VectorRef transpose2 = VectorRef({prim::kPrimTranspose, transpose1});
|
||||
return transpose2;
|
||||
}
|
||||
|
||||
const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope();
|
||||
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (AnfAlgo::GetCNodeName(node) != prim::kPrimTranspose->name() ||
|
||||
AnfAlgo::GetCNodeName(input_node) != prim::kPrimTranspose->name()) {
|
||||
MS_LOG(EXCEPTION) << "The pattern is not transpose pair, "
|
||||
<< "node:" << AnfAlgo::GetCNodeName(node) << " node input:" << AnfAlgo::GetCNodeName(input_node);
|
||||
}
|
||||
// If transpose operator used by more than one other operators, it cant not be deleted directly.
|
||||
if (IsUsedByOthers(graph, input_node)) {
|
||||
MS_LOG(DEBUG) << "The transpose node [" << input_node->fullname_with_scope()
|
||||
<< "] is used by more than one other operators.";
|
||||
return nullptr;
|
||||
}
|
||||
auto transpose1_input_shape = AnfAlgo::GetInputDeviceShape(input_node, 0);
|
||||
auto transpose2_output_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
|
||||
if (transpose2_output_shape == transpose1_input_shape) {
|
||||
auto transpose1_input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(input_node), 0);
|
||||
MS_EXCEPTION_IF_NULL(transpose1_input_node);
|
||||
return transpose1_input_node;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class RemoveFormatTransformPair : public PatternProcessPass {
|
||||
public:
|
||||
explicit RemoveFormatTransformPair(bool multigraph = true)
|
||||
: PatternProcessPass("remove_format_transform_pair", multigraph) {}
|
||||
~RemoveFormatTransformPair() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_
|
Loading…
Reference in new issue