You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
356 lines
17 KiB
356 lines
17 KiB
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "backend/optimizer/ascend/ascend_helper.h"
|
|
#include <set>
|
|
#include "common/trans.h"
|
|
#include "common/utils.h"
|
|
#include "backend/optimizer/common/helper.h"
|
|
#include "utils/utils.h"
|
|
#include "runtime/device/kernel_info.h"
|
|
#include "backend/kernel_compiler/oplib/oplib.h"
|
|
#include "backend/kernel_compiler/common_utils.h"
|
|
#include "frontend/operator/ops.h"
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
#include "backend/session/kernel_graph.h"
|
|
#include "utils/context/ms_context.h"
|
|
|
|
namespace mindspore {
|
|
namespace opt {
|
|
// Shorthand for the nested builder type of kernel::KernelBuildInfo.
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
|
namespace {
|
|
// Formats for which the helpers in this file do not insert a trans op: a trans op is only added when a
// node's input/output format is NOT a member of this set (and the inferred shape has more than one dimension).
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
|
|
// Creates a Reshape CNode in `func_graph` that takes `input_node` as its only data input.
//
// The new node's inferred output type is copied from output 0 of `input_node`, its inferred output shape and
// its kAttrShape attribute are set to `dst_shape`, it is marked with kAttrVisited, it inherits the scope of
// `input_node`, and finally `kernel_select` is asked to select a kernel for it.
//
// Raises (via MS_EXCEPTION_IF_NULL) when `func_graph`, `input_node`, `kernel_select` or the created node is null;
// all of them are dereferenced below.
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                             const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input_node);
  MS_EXCEPTION_IF_NULL(kernel_select);
  auto prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  std::vector<AnfNodePtr> trans_inputs = {NewValueNode(prim), input_node};
  auto reshape = func_graph->NewCNode(trans_inputs);
  MS_EXCEPTION_IF_NULL(reshape);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
  reshape->set_scope(input_node->scope());
  kernel_select->SelectKernel(reshape);
  return reshape;
}
|
|
|
|
// Inserts a TransData node (plus a Reshape node when padding is required) either in front of input
// `insert_index` of `node` (is_insert_input == true) or behind output 0 of `node` (is_insert_input == false).
//
//   * input side : source format is kOpFormat_DEFAULT, destination format is the format of that input of `node`.
//   * output side: source format is the format of output 0 of `node`, destination format is kOpFormat_DEFAULT.
//
// Returns the last node of the inserted chain (the TransData node, or the trailing Reshape node on the
// padded output path).
//
// Raises when `func_graph` or `node` is null, or when an input-side insertion is requested on a non-CNode.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  // Validate before the first use of `node`.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  AnfNodePtr input_node = node;
  std::string input_format;
  std::string dst_format;
  std::vector<kernel::Axis> padding_axis;
  if (is_insert_input) {
    // if insert transdata for input we need to change the input
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
    }
    auto cnode = node->cast<CNodePtr>();
    input_format = kOpFormat_DEFAULT;
    dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
    padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
  } else {
    input_format = AnfAlgo::GetOutputFormat(node, 0);
    dst_format = kOpFormat_DEFAULT;
    padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
  }

  // Padding is decided by the non-default side of the conversion and the rank of the incoming data.
  const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
  const std::string &device_format = is_insert_input ? dst_format : input_format;
  const bool need_padding = trans::IsNeedPadding(device_format, origin_shape.size());

  CNodePtr trans_data = nullptr;
  AnfNodePtr trans_node = nullptr;
  if (!need_padding) {
    // don't need padding insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else if (is_insert_input) {
    // if need padding & is input need insert a transdata
    // reshape[padding shape] -> transdata[padding shape] -> node
    // Use the reshape type of the input actually being processed (insert_index), not always input 0.
    auto padding_shape = trans::PaddingShapeTo4d(origin_shape, padding_axis);
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else {
    // if need padding & is output need insert a transdata
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = CreateReshapeNode(func_graph, trans_data, kernel_select, origin_shape);
  }
  // refresh the transdata's format to ori format & dst format
  RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
  return trans_node;
}
|
|
|
|
// Returns the node that should feed input `index` of `node` after format handling.
//
// When the real producer of that input (as resolved by AnfAlgo::VisitKernel) is a ValueNode or a Parameter,
// the input is first passed through InsertTransOpForOutput and `node` is rewired to the result.
// Afterwards, if the input format of `node` is outside kCommonFormatSet and the inferred shape of the
// producer has more than one dimension, a trans op is inserted on the input side and returned;
// otherwise the (possibly rewired) input node is returned.
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  auto input_node = AnfAlgo::GetInputNode(node, index);
  const auto real_input = AnfAlgo::VisitKernel(input_node, 0).first;
  MS_EXCEPTION_IF_NULL(real_input);

  const bool fed_by_value_or_param = real_input->isa<ValueNode>() || real_input->isa<Parameter>();
  if (fed_by_value_or_param) {
    input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetNodeInput(node, input_node, index);
  }

  const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  const std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  const bool is_common_format = kCommonFormatSet.count(dest_format) != 0;
  if (is_common_format || origin_shape.size() <= 1) {
    return input_node;
  }

  MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
                << " To DefaultFormat , index: " << index;
  return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
}
|
|
|
|
// Handles a node with a single output: inserts a trans op behind output 0 when its format is outside
// kCommonFormatSet and its inferred shape has more than one dimension; otherwise returns `node` unchanged.
// Raises when the output format is kOpFormat_NC1KHKWHWC0.
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  const std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
  const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  if (output_format == kOpFormat_NC1KHKWHWC0) {
    MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
                      << node->DebugString();
  }

  const bool is_common_format = kCommonFormatSet.count(output_format) != 0;
  if (is_common_format || origin_shape.size() <= 1) {
    return node;
  }

  MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
  return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
}
|
|
|
|
// Handles a node with several outputs: builds a MakeTuple whose i-th input is either a TupleGetItem of
// output i of `node`, or a trans op chain appended to that TupleGetItem when the output format is outside
// kCommonFormatSet and the inferred shape has more than one dimension.
//
// When `func_graph` is a KernelGraph and `node` is registered as an internal output, the registration for the
// converted output index is redirected to the inserted trans op.
//
// Raises when `func_graph` or `node` is null, or when any output format is kOpFormat_NC1KHKWHWC0.
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  // The output count is queried once instead of on every loop iteration.
  const size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.reserve(output_num + 1);
  make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "
                        << node->DebugString();
    }
    auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
    const bool is_common_format = kCommonFormatSet.find(output_format) != kCommonFormatSet.end();
    if (is_common_format || origin_shape.size() <= 1) {
      // No need insert trans op.
      make_tuple_inputs.emplace_back(tuple_getitem);
      continue;
    }
    auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
    // NOTE(review): IsInternalOutput is deliberately re-queried per output; the ReplaceInternalOutput call
    // below may change its answer — confirm against KernelGraph before hoisting.
    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
      kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
    }
    make_tuple_inputs.emplace_back(trans_op);
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
|
|
} // namespace
|
|
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
|
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
|
|
MS_EXCEPTION_IF_NULL(trans_data);
|
|
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
|
|
MS_EXCEPTION_IF_NULL(ori_build_info);
|
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
|
|
builder->SetInputsFormat({input_format});
|
|
builder->SetInputReshapeType({reshape_type});
|
|
builder->SetOutputReshapeType({reshape_type});
|
|
builder->SetOutputsFormat({output_format});
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
|
|
}
|
|
|
|
// Creates a CNode running primitive `op_name` on `input` inside `func_graph`.
//
// The node's inferred output type is copied from output 0 of `input`. Its inferred output shape is the shape
// of output 0 of `input`, padded through trans::PaddingShapeTo4d (using the output reshape type of `input`)
// when `need_padding` is true. A KernelInfo is attached if the node has none, a kernel is selected through
// `kernel_select`, the node is marked with kAttrVisited and it inherits the scope of `input`.
//
// Raises when `func_graph`, `input`, `kernel_select` or the created node is null. All arguments are validated
// before the node is created, so a failed precondition does not leave a half-initialised node in the graph.
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
                        const bool need_padding, const std::string &op_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(kernel_select);
  auto prim = std::make_shared<Primitive>(op_name);
  std::vector<AnfNodePtr> trans_inputs = {NewValueNode(prim), input};
  CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
  MS_EXCEPTION_IF_NULL(trans_node);

  const auto infer_type = AnfAlgo::GetOutputInferDataType(input, 0);
  auto infer_shape = AnfAlgo::GetOutputInferShape(input, 0);
  if (need_padding) {
    // if need padding we should set the transdata node's shape to the padding shape
    infer_shape = trans::PaddingShapeTo4d(infer_shape, AnfAlgo::GetOutputReshapeType(input, 0));
  }
  AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {infer_shape}, trans_node.get());

  // special handle for ut
  if (trans_node->kernel_info() == nullptr) {
    auto kernel_info = std::make_shared<device::KernelInfo>();
    trans_node->set_kernel_info(kernel_info);
  }
  kernel_select->SelectKernel(trans_node);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
  trans_node->set_scope(input->scope());
  return trans_node;
}
|
|
|
|
// Creates a Cast CNode on `input` inside `func_graph` and fills in its kernel build info by hand:
// `format` is used for both the input and the output, the device types are `input_type` -> `output_type`,
// fusion type is OPAQUE, processor is AICORE, and the kernel type is TBE_KERNEL when the op library knows a
// TBE Cast, AKG_KERNEL otherwise.
// The node's inferred output is set to (`origin_type`, `origin_shape`) and it is tagged with kIsBackendCast.
// Raises when `func_graph` or the created node is null.
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
                                const TypeId &input_type, const TypeId &output_type,
                                const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto cast_prim = std::make_shared<Primitive>(prim::kPrimCast->name());
  std::vector<AnfNodePtr> cast_inputs;
  cast_inputs.push_back(NewValueNode(cast_prim));
  cast_inputs.push_back(input);
  CNodePtr cast = func_graph->NewCNode(cast_inputs);
  MS_EXCEPTION_IF_NULL(cast);

  // set kernel build info
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({format});
  builder.SetOutputsFormat({format});
  builder.SetInputsDeviceType({input_type});
  builder.SetOutputsDeviceType({output_type});
  builder.SetFusionType(kernel::FusionType::OPAQUE);
  builder.SetProcessor(kernel::Processor::AICORE);
  const bool has_tbe_cast = kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr;
  builder.SetKernelType(has_tbe_cast ? KernelType::TBE_KERNEL : KernelType::AKG_KERNEL);

  // if kernel info is null , it remarks this function is running ut
  if (cast->kernel_info() == nullptr) {
    cast->set_kernel_info(std::make_shared<device::KernelInfo>());
  }
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
  AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
  return cast;
}
|
|
|
|
// Inserts trans ops behind the output(s) of `node` where required.
//
//   * no outputs        : `node` is returned unchanged.
//   * one non-tuple out : delegates to InsertTransOpForSingleOutput; when `func_graph` is a KernelGraph and
//                         `node` is registered as an internal output, the registration is redirected to the
//                         returned node.
//   * otherwise         : delegates to InsertTransOpForMultipleOutput.
//
// Raises when `func_graph` is null on a path that needs the graph.
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select) {
  size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
  if (outputs_num == 0) {
    return node;
  }
  // Checked after the early return so that the zero-output path keeps accepting any graph pointer.
  MS_EXCEPTION_IF_NULL(func_graph);
  // Multiple output
  if (outputs_num != 1 || AnfAlgo::IsTupleOutput(node)) {
    return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
  }
  // Single output
  auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
    kernel_graph->ReplaceInternalOutput(node, new_node);
  }
  return new_node;
}
|
|
|
|
// Builds a copy of `node` (which must be a CNode) whose data inputs have been run through
// GetTransInputNodePtr, i.e. each input is replaced by a trans op chain where one is required.
// The copy is created through KernelGraph::NewCNode when `func_graph` is a KernelGraph, otherwise by
// copy-constructing the CNode. Raises when `node` is null, is not a CNode, or when any produced node is null.
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  std::vector<AnfNodePtr> converted_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  for (size_t idx = 0; idx < AnfAlgo::GetInputTensorNum(cnode); ++idx) {
    AnfNodePtr converted = GetTransInputNodePtr(func_graph, cnode, idx, kernel_select);
    MS_EXCEPTION_IF_NULL(converted);
    converted_inputs.push_back(converted);
  }

  // cnode changed so make a new cnode to differ from original one.
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  CNodePtr new_cnode =
    (kernel_graph == nullptr) ? std::make_shared<CNode>(*cnode) : kernel_graph->NewCNode(cnode);
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_inputs(converted_inputs);
  return new_cnode;
}
|
|
|
|
// Builds a copy of `cnode` in which every data input whose "origin" data type differs from the device data
// type selected for that input is routed through a backend Cast node.
//
// The origin type of an input is:
//   * for a weight-boundary producer, or when `func_graph` carries FUNC_GRAPH_ATTR_GRAPH_KERNEL:
//     the producer's output precision, falling back to its output device data type when that is kTypeUnknown;
//   * otherwise: the producer's inferred output data type.
//
// Inside a graph-kernel graph, tensor ValueNode inputs are never cast.
// Raises when `func_graph`, `cnode`, a created Cast node or the copied node is null.
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  // Both values are constant for the whole loop, so query them once.
  const bool is_graph_kernel = func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
  const size_t input_num = AnfAlgo::GetInputTensorNum(cnode);

  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    TypeId origin_type(kTypeUnknown);
    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
    auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
    auto real_input_node = kernel_with_index.first;
    if (kernel::IsWeightBoundary(real_input_node) || is_graph_kernel) {
      // weight
      origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
      if (origin_type == kTypeUnknown) {
        origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
      }
    } else {
      // feature map: same value as infer_type, no need to query it a second time
      origin_type = infer_type;
    }
    const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
    const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index);
    const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
    // In graph kernel, we check parameter,
    // the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
    if (is_graph_kernel && IsValueNode<tensor::Tensor>(cur_input)) {
      new_inputs.push_back(cur_input);
    } else if (origin_type != device_type) {
      auto cast =
        AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
      MS_EXCEPTION_IF_NULL(cast);
      cast->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
      new_inputs.push_back(cast);
    } else {
      new_inputs.push_back(cur_input);
    }
  }

  // The inputs changed, so hand back a fresh node rather than mutating the original.
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  CNodePtr new_node = nullptr;
  if (kernel_graph == nullptr) {
    new_node = std::make_shared<CNode>(*cnode);
  } else {
    new_node = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_inputs(new_inputs);
  return new_node;
}
|
|
|
|
// Creates a CNode running the kMemCpyAsyncOpName primitive on `node` inside `graph`.
// The new node shares the abstract and the scope of `node`.
// Raises when `graph`, `node` or the created node is null.
AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto memcpy_prim = std::make_shared<Primitive>(kMemCpyAsyncOpName);
  std::vector<AnfNodePtr> memcpy_inputs;
  memcpy_inputs.push_back(NewValueNode(memcpy_prim));
  memcpy_inputs.push_back(node);
  auto memcpy_node = graph->NewCNode(memcpy_inputs);
  MS_EXCEPTION_IF_NULL(memcpy_node);
  memcpy_node->set_abstract(node->abstract());
  memcpy_node->set_scope(node->scope());
  return memcpy_node;
}
|
|
} // namespace opt
|
|
} // namespace mindspore
|