!10431 Split unsupport transdata when doing ref
From: @lianliguang Reviewed-by: @zhoufeng54,@chujinjin Signed-off-by: @chujinjinpull/10431/MERGE
commit
dafd26196e
@ -1,66 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef SplitUnsupportedTransData::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
return VectorRef({prim::KPrimTransData, X});
|
||||
}
|
||||
|
||||
const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto ori_trans_data = node->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1"
|
||||
<< ori_trans_data->DebugString() << trace::DumpSourceLines(node);
|
||||
}
|
||||
return SplitTransData(func_graph, ori_trans_data);
|
||||
}
|
||||
AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const {
|
||||
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
|
||||
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
|
||||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
|
||||
return trans_node;
|
||||
}
|
||||
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
|
||||
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
|
||||
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT});
|
||||
std::vector<AnfNodePtr> next_trans_node_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node};
|
||||
auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
|
||||
next_trans_node->set_abstract(trans_node->abstract());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
|
||||
return next_trans_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -1,37 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SplitUnsupportedTransData : public PatternProcessPass {
|
||||
public:
|
||||
explicit SplitUnsupportedTransData(bool multigraph = true)
|
||||
: PatternProcessPass("split_unsupported_transdata", multigraph) {}
|
||||
~SplitUnsupportedTransData() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
Loading…
Reference in new issue