!10431 Split unsupport transdata when doing ref

From: @lianliguang
Reviewed-by: @zhoufeng54,@chujinjin
Signed-off-by: @chujinjin
pull/10431/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit dafd26196e

@ -71,7 +71,6 @@
#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include "backend/optimizer/ascend/format_type/convert_cast_format.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/pass/optimize_dependence.h"
@ -250,7 +249,6 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>());
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
optimizer->AddPassManager(mixed_precision_pm);

@ -256,9 +256,9 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
return trans_node;
}
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
MS_EXCEPTION_IF_NULL(func_graph);
std::string input_format = format;
std::string output_format = format;

@ -94,9 +94,9 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
const bool need_padding, const std::string &op_name);
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type);
CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type);
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select);

@ -26,8 +26,7 @@
namespace mindspore {
namespace opt {
namespace {
session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const {
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
AnfNodePtr cur_node = kernel_with_index.first;
size_t cur_out_index = kernel_with_index.second;
@ -62,8 +61,8 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
return kernel_with_index;
}
void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index,
const size_t input_index) {
void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const size_t output_index, const size_t input_index) const {
// record the ref_pair
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -72,9 +71,10 @@ void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr
kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index);
}
void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item,
const AnfNodePtr &final_node, size_t final_index,
const session::KernelWithIndex &origin_pair) {
void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const AnfNodePtr &get_item, const AnfNodePtr &final_node,
size_t final_index,
const session::KernelWithIndex &origin_pair) const {
// record the ref_pair
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -95,9 +95,10 @@ void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cno
// if get_item is nullptr, the additional node will link to the cnode
// else the additional node will link to the get_item node (the get_item node link to cnode)
AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index,
size_t input_index, const AnfNodePtr &get_item) {
AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item);
CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
size_t output_index, size_t input_index,
const CNodePtr &get_item) const {
CNodePtr final_node = (get_item == nullptr ? cnode : get_item);
bool need_refresh_ref_addr = false;
size_t final_index = output_index;
AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index);
@ -119,6 +120,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
auto kernel_select = std::make_shared<KernelSelect>();
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
final_node = SplitTransdataIfNotSupported(func_graph, final_node);
final_index = 0;
need_refresh_ref_addr = true;
MS_EXCEPTION_IF_NULL(final_node);
@ -148,15 +150,15 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
return final_node;
}
AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) {
CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const {
MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos();
std::vector<AnfNodePtr> make_tuple_inputs;
AbstractBasePtrList abstract_list;
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index);
CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index);
// deal with ref output
if (ref_infos.count(output_index) != 0) {
auto input_index = ref_infos.at(output_index);
@ -167,14 +169,14 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP
make_tuple_inputs.push_back(final_node);
}
MS_EXCEPTION_IF_NULL(func_graph);
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple);
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return make_tuple;
}
AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) {
CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos();
@ -187,7 +189,6 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
}
return nullptr;
}
} // namespace
const BaseRef DealRefTransAndCast::DefinePattern() const {
VarPtr V = std::make_shared<CondVar>(UnVisited);
@ -195,7 +196,7 @@ const BaseRef DealRefTransAndCast::DefinePattern() const {
return VectorRef({V, Xs});
}
void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
auto input_size = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_size; ++i) {
@ -238,5 +239,38 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
}
return nullptr;
}
CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
MS_EXCEPTION_IF_NULL(kernel_info);
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
if (IsFormatInvaild(cnode)) {
return DoSplit(func_graph, cnode);
}
return cnode;
}
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT});
std::vector<AnfNodePtr> next_trans_node_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), cnode};
MS_EXCEPTION_IF_NULL(func_graph);
auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
next_trans_node->set_abstract(cnode->abstract());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
if (IsFormatInvaild(cnode)) {
auto after_split_node = DoSplit(func_graph, cnode);
AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0);
}
if (IsFormatInvaild(next_trans_node)) {
return DoSplit(func_graph, next_trans_node);
}
return next_trans_node;
}
} // namespace opt
} // namespace mindspore

@ -16,20 +16,37 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_
#include <memory>
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ir_fission/transdata_split.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
class DealRefTransAndCast : public PatternProcessPass {
class DealRefTransAndCast : public TransDataSplit {
public:
explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {}
explicit DealRefTransAndCast(bool multigraph = true) : TransDataSplit(multigraph, "deal_ref_trans_and_cast") {}
~DealRefTransAndCast() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const;
CNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const;
CNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index,
size_t input_index, const CNodePtr &get_item) const;
void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item,
const AnfNodePtr &final_node, size_t final_index,
const session::KernelWithIndex &origin_pair) const;
void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index,
const size_t input_index) const;
session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) const;
};
} // namespace opt
} // namespace mindspore

@ -31,15 +31,6 @@ const BaseRef InsertTransOp::DefinePattern() const {
return VectorRef({V, Xs});
}
bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem});
auto iter = std::find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) {
return true;
}
return false;
}
const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {

@ -1,66 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include <vector>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace opt {
const BaseRef SplitUnsupportedTransData::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
return VectorRef({prim::KPrimTransData, X});
}
const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
return nullptr;
}
auto ori_trans_data = node->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) {
return nullptr;
}
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data);
MS_EXCEPTION_IF_NULL(kernel_info);
if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) {
MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1"
<< ori_trans_data->DebugString() << trace::DumpSourceLines(node);
}
return SplitTransData(func_graph, ori_trans_data);
}
AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const {
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
return trans_node;
}
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT});
std::vector<AnfNodePtr> next_trans_node_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node};
auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
next_trans_node->set_abstract(trans_node->abstract());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
return next_trans_node;
}
} // namespace opt
} // namespace mindspore

@ -1,37 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class SplitUnsupportedTransData : public PatternProcessPass {
public:
explicit SplitUnsupportedTransData(bool multigraph = true)
: PatternProcessPass("split_unsupported_transdata", multigraph) {}
~SplitUnsupportedTransData() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H

@ -27,22 +27,20 @@ const std::set<std::pair<string, string>> invalid_formats_pair = {
{kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
{kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = false;
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
if (IsFormatInvaild(node)) {
TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
changed = DoSplit(func_graph, node);
}
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
if (IsFormatInvaild(node)) {
TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
return DoSplit(func_graph, node);
}
}
return changed;
return nullptr;
}
bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) {
bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
@ -52,8 +50,14 @@ bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) {
return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
}
const BaseRef TransDataSplit::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
return VectorRef({prim::KPrimTransData, X});
}
// transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW)
bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
@ -63,9 +67,9 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
auto input_format = AnfAlgo::GetInputFormat(node, 0);
auto output_format = AnfAlgo::GetOutputFormat(node, 0);
AnfNodePtr new_transdata_node = nullptr;
AnfNodePtr new_transpose_node = nullptr;
AnfNodePtr new_replace_node = nullptr;
CNodePtr new_transdata_node = nullptr;
CNodePtr new_transpose_node = nullptr;
CNodePtr new_replace_node = nullptr;
auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
// if output_format=default transdata need split transdata->transpose else transpose->transdata
if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
@ -96,16 +100,8 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
new_transdata_node->set_abstract(node->abstract());
new_replace_node = new_transdata_node;
}
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(func_graph);
if (!manager->Replace(node, new_replace_node)) {
MS_LOG(EXCEPTION) << "Manager replace node failed"
<< " trace: " << trace::DumpSourceLines(node);
}
MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success.";
return true;
return new_replace_node;
}
} // namespace opt
} // namespace mindspore

@ -29,15 +29,17 @@
namespace mindspore {
namespace opt {
class TransDataSplit : public Pass {
class TransDataSplit : public PatternProcessPass {
public:
TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared<KernelSelect>()) {}
explicit TransDataSplit(bool multigraph = true, const string &name = "trans_data_split")
: PatternProcessPass(name, multigraph), kernel_select_(std::make_shared<KernelSelect>()) {}
~TransDataSplit() override = default;
bool Run(const FuncGraphPtr &graph) override;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
private:
bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
bool IsFormatInvaild(const AnfNodePtr &node);
protected:
CNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
bool IsFormatInvaild(const AnfNodePtr &node) const;
KernelSelectPtr kernel_select_;
};
} // namespace opt

@ -481,13 +481,13 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
return true;
}
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
auto idx = NewValueNode(SizeToLong(output_idx));
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
idx->set_abstract(abstract_scalar);
AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_getitem);
tuple_getitem->set_scope(node->scope());
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);

@ -169,7 +169,7 @@ void HideNopNode(session::KernelGraph *const graph);
void RemoveNopNode(session::KernelGraph *const graph);
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);

Loading…
Cancel
Save