|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
|
|
|
|
|
#include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <memory>
|
|
|
|
@ -26,7 +26,7 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const {
|
|
|
|
|
session::KernelWithIndex DealRefAndSpiltUnSupportedTransdata::FindRefOriginNode(const AnfNodePtr &node) const {
|
|
|
|
|
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
|
|
|
|
|
AnfNodePtr cur_node = kernel_with_index.first;
|
|
|
|
|
size_t cur_out_index = kernel_with_index.second;
|
|
|
|
@ -61,8 +61,9 @@ session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr
|
|
|
|
|
return kernel_with_index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|
|
|
|
const size_t output_index, const size_t input_index) const {
|
|
|
|
|
void DealRefAndSpiltUnSupportedTransdata::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph,
|
|
|
|
|
const CNodePtr &cnode, const size_t output_index,
|
|
|
|
|
const size_t input_index) const {
|
|
|
|
|
// record the ref_pair
|
|
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
@ -71,10 +72,10 @@ void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_g
|
|
|
|
|
kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|
|
|
|
const AnfNodePtr &get_item, const AnfNodePtr &final_node,
|
|
|
|
|
size_t final_index,
|
|
|
|
|
const session::KernelWithIndex &origin_pair) const {
|
|
|
|
|
void DealRefAndSpiltUnSupportedTransdata::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|
|
|
|
const AnfNodePtr &get_item,
|
|
|
|
|
const AnfNodePtr &final_node, size_t final_index,
|
|
|
|
|
const session::KernelWithIndex &origin_pair) const {
|
|
|
|
|
// record the ref_pair
|
|
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
@ -95,9 +96,10 @@ void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph
|
|
|
|
|
|
|
|
|
|
// if get_item is nullptr, the additional node will link to the cnode
|
|
|
|
|
// else the additional node will link to the get_item node (the get_item node link to cnode)
|
|
|
|
|
CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|
|
|
|
size_t output_index, size_t input_index,
|
|
|
|
|
const CNodePtr &get_item) const {
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph,
|
|
|
|
|
const CNodePtr &cnode, size_t output_index,
|
|
|
|
|
size_t input_index,
|
|
|
|
|
const CNodePtr &get_item) const {
|
|
|
|
|
CNodePtr final_node = (get_item == nullptr ? cnode : get_item);
|
|
|
|
|
bool need_refresh_ref_addr = false;
|
|
|
|
|
size_t final_index = output_index;
|
|
|
|
@ -149,8 +151,9 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_
|
|
|
|
|
return final_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
|
|
|
|
|
const CNodePtr &cnode, const FuncGraphPtr &func_graph) const {
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
|
|
|
|
|
const CNodePtr &cnode,
|
|
|
|
|
const FuncGraphPtr &func_graph) const {
|
|
|
|
|
std::vector<AnfNodePtr> depend_nodes;
|
|
|
|
|
if (get_item != nullptr) {
|
|
|
|
|
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
|
|
|
|
@ -159,8 +162,8 @@ CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNo
|
|
|
|
|
}
|
|
|
|
|
return func_graph->NewCNode(depend_nodes);
|
|
|
|
|
}
|
|
|
|
|
CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|
|
|
|
const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
|
|
|
|
|
const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
auto ref_infos = op_info->ref_infos();
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs;
|
|
|
|
@ -185,8 +188,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_
|
|
|
|
|
return make_tuple;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|
|
|
|
const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|
|
|
|
const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
auto ref_infos = op_info->ref_infos();
|
|
|
|
@ -200,13 +203,14 @@ CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph,
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const BaseRef DealRefTransAndCast::DefinePattern() const {
|
|
|
|
|
const BaseRef DealRefAndSpiltUnSupportedTransdata::DefinePattern() const {
|
|
|
|
|
VarPtr V = std::make_shared<CondVar>(UnVisited);
|
|
|
|
|
VarPtr Xs = std::make_shared<SeqVar>();
|
|
|
|
|
return VectorRef({V, Xs});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
|
|
|
|
|
void DealRefAndSpiltUnSupportedTransdata::DealBroadCastAsRef(const FuncGraphPtr &func_graph,
|
|
|
|
|
const CNodePtr &cnode) const {
|
|
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
|
|
|
|
|
auto input_size = AnfAlgo::GetInputTensorNum(cnode);
|
|
|
|
|
for (size_t i = 0; i < input_size; ++i) {
|
|
|
|
@ -219,8 +223,8 @@ void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, con
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
|
|
|
|
const EquivPtr &) const {
|
|
|
|
|
const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
|
|
|
|
const EquivPtr &) const {
|
|
|
|
|
if (node == nullptr || !node->isa<CNode>()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
@ -250,11 +254,12 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
|
|
|
|
|
const CNodePtr &cnode) const {
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
|
|
|
|
|
const CNodePtr &cnode) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
// When the input and output format is only one special format just need to be splited into transpose and transdata
|
|
|
|
|
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
|
|
|
|
|
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
|
|
|
|
|
if (IsFormatInvaild(cnode)) {
|
|
|
|
@ -262,6 +267,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f
|
|
|
|
|
}
|
|
|
|
|
return cnode;
|
|
|
|
|
}
|
|
|
|
|
// When input and output format are all special format
|
|
|
|
|
// the node should be splited to two transdata connected by default format
|
|
|
|
|
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
|
|
|
|
|
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
|
|
|
|
|
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
|
|
|
|
@ -273,6 +280,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f
|
|
|
|
|
next_trans_node->set_abstract(cnode->abstract());
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get());
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
|
|
|
|
|
RefreshKernelBuildInfo(AnfAlgo::GetInputFormat(cnode, 0), kOpFormat_DEFAULT, cnode);
|
|
|
|
|
RefreshKernelBuildInfo(kOpFormat_DEFAULT, AnfAlgo::GetOutputFormat(next_trans_node, 0), next_trans_node);
|
|
|
|
|
if (IsFormatInvaild(cnode)) {
|
|
|
|
|
auto after_split_node = DoSplit(func_graph, cnode);
|
|
|
|
|
AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0);
|