fix transdata's dst format && src format is unmatched with build info when transdata has been spilted

pull/14481/head
LianLiguang 4 years ago
parent 36dbb2690e
commit 9c8d016d66

@ -99,7 +99,7 @@
#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" #include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h" #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h" #include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
@ -254,7 +254,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); mixed_precision_pm->AddPass(std::make_shared<DealRefAndSpiltUnSupportedTransdata>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" #include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <memory> #include <memory>
@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const { session::KernelWithIndex DealRefAndSpiltUnSupportedTransdata::FindRefOriginNode(const AnfNodePtr &node) const {
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
AnfNodePtr cur_node = kernel_with_index.first; AnfNodePtr cur_node = kernel_with_index.first;
size_t cur_out_index = kernel_with_index.second; size_t cur_out_index = kernel_with_index.second;
@ -61,8 +61,9 @@ session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr
return kernel_with_index; return kernel_with_index;
} }
void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, void DealRefAndSpiltUnSupportedTransdata::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph,
const size_t output_index, const size_t input_index) const { const CNodePtr &cnode, const size_t output_index,
const size_t input_index) const {
// record the ref_pair // record the ref_pair
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
@ -71,9 +72,9 @@ void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_g
kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index);
} }
void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, void DealRefAndSpiltUnSupportedTransdata::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const AnfNodePtr &get_item, const AnfNodePtr &final_node, const AnfNodePtr &get_item,
size_t final_index, const AnfNodePtr &final_node, size_t final_index,
const session::KernelWithIndex &origin_pair) const { const session::KernelWithIndex &origin_pair) const {
// record the ref_pair // record the ref_pair
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); auto kernel_graph = func_graph->cast<KernelGraphPtr>();
@ -95,8 +96,9 @@ void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph
// if get_item is nullptr, the additional node will link to the cnode // if get_item is nullptr, the additional node will link to the cnode
// else the additional node will link to the get_item node (the get_item node link to cnode) // else the additional node will link to the get_item node (the get_item node link to cnode)
CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CNodePtr DealRefAndSpiltUnSupportedTransdata::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph,
size_t output_index, size_t input_index, const CNodePtr &cnode, size_t output_index,
size_t input_index,
const CNodePtr &get_item) const { const CNodePtr &get_item) const {
CNodePtr final_node = (get_item == nullptr ? cnode : get_item); CNodePtr final_node = (get_item == nullptr ? cnode : get_item);
bool need_refresh_ref_addr = false; bool need_refresh_ref_addr = false;
@ -149,8 +151,9 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_
return final_node; return final_node;
} }
CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
const CNodePtr &cnode, const FuncGraphPtr &func_graph) const { const CNodePtr &cnode,
const FuncGraphPtr &func_graph) const {
std::vector<AnfNodePtr> depend_nodes; std::vector<AnfNodePtr> depend_nodes;
if (get_item != nullptr) { if (get_item != nullptr) {
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node}; depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
@ -159,8 +162,8 @@ CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNo
} }
return func_graph->NewCNode(depend_nodes); return func_graph->NewCNode(depend_nodes);
} }
CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
const std::shared_ptr<kernel::OpInfo> &op_info) const { const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos(); auto ref_infos = op_info->ref_infos();
std::vector<AnfNodePtr> make_tuple_inputs; std::vector<AnfNodePtr> make_tuple_inputs;
@ -185,7 +188,7 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_
return make_tuple; return make_tuple;
} }
CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const { const std::shared_ptr<kernel::OpInfo> &op_info) const {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
@ -200,13 +203,14 @@ CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph,
return nullptr; return nullptr;
} }
const BaseRef DealRefTransAndCast::DefinePattern() const { const BaseRef DealRefAndSpiltUnSupportedTransdata::DefinePattern() const {
VarPtr V = std::make_shared<CondVar>(UnVisited); VarPtr V = std::make_shared<CondVar>(UnVisited);
VarPtr Xs = std::make_shared<SeqVar>(); VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({V, Xs}); return VectorRef({V, Xs});
} }
void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { void DealRefAndSpiltUnSupportedTransdata::DealBroadCastAsRef(const FuncGraphPtr &func_graph,
const CNodePtr &cnode) const {
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
auto input_size = AnfAlgo::GetInputTensorNum(cnode); auto input_size = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_size; ++i) { for (size_t i = 0; i < input_size; ++i) {
@ -219,7 +223,7 @@ void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, con
} }
} }
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const { const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {
return nullptr; return nullptr;
@ -250,11 +254,12 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
return nullptr; return nullptr;
} }
CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, CNodePtr DealRefAndSpiltUnSupportedTransdata::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
const CNodePtr &cnode) const { const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode); auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
// When the input and output format is only one special format just need to be splited into transpose and transdata
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
if (IsFormatInvaild(cnode)) { if (IsFormatInvaild(cnode)) {
@ -262,6 +267,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f
} }
return cnode; return cnode;
} }
// When input and output format are all special format
// the node should be splited to two transdata connected by default format
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
@ -273,6 +280,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f
next_trans_node->set_abstract(cnode->abstract()); next_trans_node->set_abstract(cnode->abstract());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
RefreshKernelBuildInfo(AnfAlgo::GetInputFormat(cnode, 0), kOpFormat_DEFAULT, cnode);
RefreshKernelBuildInfo(kOpFormat_DEFAULT, AnfAlgo::GetOutputFormat(next_trans_node, 0), next_trans_node);
if (IsFormatInvaild(cnode)) { if (IsFormatInvaild(cnode)) {
auto after_split_node = DoSplit(func_graph, cnode); auto after_split_node = DoSplit(func_graph, cnode);
AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0); AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0);

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_
#include <memory> #include <memory>
#include "ir/anf.h" #include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/optimizer.h"
@ -25,10 +25,11 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class DealRefTransAndCast : public TransDataSplit { class DealRefAndSpiltUnSupportedTransdata : public TransDataSplit {
public: public:
explicit DealRefTransAndCast(bool multigraph = true) : TransDataSplit(multigraph, "deal_ref_trans_and_cast") {} explicit DealRefAndSpiltUnSupportedTransdata(bool multigraph = true)
~DealRefTransAndCast() override = default; : TransDataSplit(multigraph, "deal_ref_and_transdata_spilt") {}
~DealRefAndSpiltUnSupportedTransdata() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
@ -52,4 +53,4 @@ class DealRefTransAndCast : public TransDataSplit {
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_
Loading…
Cancel
Save