parent
2ee4fdade6
commit
dfa7b79d75
@ -0,0 +1,96 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include "kernel/kernel_fusion.h"
|
||||||
|
#include "debug/anf_ir_dump.h"
|
||||||
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "operator/ops.h"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||||
|
const session::KernelGraph &kernel_graph,
|
||||||
|
FusedNodeRecord *candidate_fusion) {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||||
|
auto manager = kernel_graph.manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
auto getitem = relu_input->cast<CNodePtr>();
|
||||||
|
auto bnupdate = getitem->input(1);
|
||||||
|
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
|
||||||
|
std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
|
||||||
|
for (auto out_getitem : manager->node_users()[bnupdate]) {
|
||||||
|
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
|
||||||
|
auto input2 = out_getitem_ptr->input(2);
|
||||||
|
auto output_idx = GetValue<int>(GetValueNode(input2));
|
||||||
|
output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());
|
||||||
|
}
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
|
||||||
|
std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
|
||||||
|
candidate_fusion->push_back(record);
|
||||||
|
SetRecordFusionId(record);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Scans the graph in topological order for TBE element-wise ReLU / ReLUV2 kernels whose first
 * input is a TupleGetItem, and hands each one to MatchBnupdateRelu.
 *
 * Nodes that are not real CNode kernels, already carry a fusion id, or are the Return primitive
 * are skipped.
 *
 * @param kernel_graph     Graph to scan.
 * @param candidate_fusion Output list of fusion records; must not be null.
 */
void BnupdateEltwiseFusionPass::MatchBnupdateOpNamePattern(const session::KernelGraph &kernel_graph,
                                                           FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  // The allocator is optional at construction time, so verify it before the first dereference.
  MS_EXCEPTION_IF_NULL(fusion_id_allocator);
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
  for (auto &node : node_list) {
    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (AnfAlgo::GetKernelType(cnode) != KernelType::TBE_KERNEL ||
        AnfAlgo::GetFusionType(cnode) != kernel::FusionType::ELEMWISE) {
      continue;
    }
    auto eltwise_input = cnode->input(1);
    MS_EXCEPTION_IF_NULL(eltwise_input);
    if ((AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) &&
        eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
      MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * Runs the BNTrainingUpdate + eltwise matcher over `kernel_graph`.
 *
 * @return true when at least one fusion record was collected, false otherwise (including when the
 *         return node has no input beyond index 0).
 */
bool BnupdateEltwiseFusionPass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) {
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto return_node = kernel_graph.get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  const bool has_output = return_node->inputs().size() > 1;
  if (!has_output) {
    return false;
  }

  MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
  FusedNodeRecord candidate_fusion;
  MatchBnupdateOpNamePattern(kernel_graph, &candidate_fusion);

  const bool matched = !candidate_fusion.empty();
  if (matched) {
    MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
  }
  return matched;
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,50 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "pre_activate/common/pass.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
#include "device/kernel_info.h"
|
||||||
|
#include "kernel/kernel.h"
|
||||||
|
#include "session/kernel_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
||||||
|
|
||||||
|
class BnupdateEltwiseFusionPass : public FusionBasePass {
|
||||||
|
public:
|
||||||
|
BnupdateEltwiseFusionPass() : FusionBasePass("BnupdateEltwiseFusionPass") {}
|
||||||
|
explicit BnupdateEltwiseFusionPass(FusionIdAllocator *idAllocator)
|
||||||
|
: FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {}
|
||||||
|
~BnupdateEltwiseFusionPass() override = default;
|
||||||
|
bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
||||||
|
FusedNodeRecord *candidate_fusion);
|
||||||
|
void MatchBnupdateOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
|
@ -0,0 +1,107 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include "kernel/kernel_fusion.h"
|
||||||
|
#include "debug/anf_ir_dump.h"
|
||||||
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "operator/ops.h"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnode,
|
||||||
|
const session::KernelGraph &kernel_graph,
|
||||||
|
FusedNodeRecord *candidate_fusion, bool is_order) {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||||
|
auto manager = kernel_graph.manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
if (is_order) {
|
||||||
|
// DepthwiseConvolution--->Elemwise
|
||||||
|
auto depthwise_conv = cnode->input(1);
|
||||||
|
MS_EXCEPTION_IF_NULL(depthwise_conv);
|
||||||
|
if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
|
||||||
|
std::vector<int> output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())};
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
|
||||||
|
std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
|
||||||
|
candidate_fusion->push_back(record);
|
||||||
|
SetRecordFusionId(record);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Elemwise-->DepthwiseConvolution
|
||||||
|
auto relu = cnode->input(1);
|
||||||
|
MS_EXCEPTION_IF_NULL(relu);
|
||||||
|
if (cnode->isa<CNode>() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
|
||||||
|
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
|
||||||
|
std::unordered_set<AnfNodePtr> record{cnode, relu};
|
||||||
|
candidate_fusion->push_back(record);
|
||||||
|
SetRecordFusionId(record);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Scans the graph in topological order and dispatches candidates to MatchDepthwiseConvRelu:
 *  - a TBE element-wise ReLU / ReLUV2 fed by DepthwiseConv2dNative (is_order = true);
 *  - a node that is not a TBE element-wise kernel and is named DepthwiseConv2dNative
 *    (is_order = false).
 *
 * Nodes that are not real CNode kernels, already carry a fusion id, or are the Return primitive
 * are skipped.
 *
 * @param kernel_graph     Graph to scan.
 * @param candidate_fusion Output list of fusion records; must not be null.
 */
void DepthwiseConvEltwiseFusionPass::MatchDepthwiseOpNamePattern(const session::KernelGraph &kernel_graph,
                                                                 FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  // The allocator is optional at construction time, so verify it before the first dereference.
  MS_EXCEPTION_IF_NULL(fusion_id_allocator);
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
  for (auto &node : node_list) {
    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
      auto eltwise_input = cnode->input(1);
      MS_EXCEPTION_IF_NULL(eltwise_input);
      if ((AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) &&
          eltwise_input->isa<CNode>() &&
          AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
        MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
      }
    } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
      MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * Runs the DepthwiseConv2dNative + eltwise matcher over `kernel_graph`.
 *
 * @return true when at least one fusion record was collected, false otherwise (including when the
 *         return node has no input beyond index 0).
 */
bool DepthwiseConvEltwiseFusionPass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) {
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto return_node = kernel_graph.get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  if (!(return_node->inputs().size() > 1)) {
    return false;
  }

  MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
  FusedNodeRecord candidate_fusion;
  MatchDepthwiseOpNamePattern(kernel_graph, &candidate_fusion);

  const bool found_any = !candidate_fusion.empty();
  if (found_any) {
    MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
  }
  return found_any;
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,50 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "pre_activate/common/pass.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
#include "device/kernel_info.h"
|
||||||
|
#include "kernel/kernel.h"
|
||||||
|
#include "session/kernel_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
||||||
|
|
||||||
|
class DepthwiseConvEltwiseFusionPass : public FusionBasePass {
|
||||||
|
public:
|
||||||
|
DepthwiseConvEltwiseFusionPass() : FusionBasePass("DepthwiseConvEltwiseFusionPass") {}
|
||||||
|
explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocator *idAllocator)
|
||||||
|
: FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {}
|
||||||
|
~DepthwiseConvEltwiseFusionPass() override = default;
|
||||||
|
bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
||||||
|
FusedNodeRecord *candidate_fusion, bool is_order);
|
||||||
|
void MatchDepthwiseOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
|
@ -0,0 +1,38 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h"
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <memory>
|
||||||
|
#include "debug/anf_ir_dump.h"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
void FusionBasePass::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) {
|
||||||
|
auto id = fusion_id_allocator->AllocateFusionId();
|
||||||
|
for (auto node : record) {
|
||||||
|
fusion_id_allocator->SetFusionId(node, id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool FusionBasePass::Run(const FuncGraphPtr &graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
return MatchUBFusionPattern(*kernel_graph);
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,50 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "pre_activate/common/pass.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
#include "device/kernel_info.h"
|
||||||
|
#include "kernel/kernel.h"
|
||||||
|
#include "session/kernel_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
||||||
|
|
||||||
|
class FusionBasePass : public Pass {
|
||||||
|
public:
|
||||||
|
explicit FusionBasePass(const std::string &name) : Pass(name) {}
|
||||||
|
FusionBasePass(const std::string &name, FusionIdAllocator *idAllocator)
|
||||||
|
: Pass(name), fusion_id_allocator(idAllocator) {}
|
||||||
|
~FusionBasePass() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &graph) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph) = 0;
|
||||||
|
void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record);
|
||||||
|
FusionIdAllocator *fusion_id_allocator;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
|
@ -0,0 +1,245 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.h"
|
||||||
|
|
||||||
|
#include <tuple>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <deque>
|
||||||
|
#include <memory>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "kernel/kernel_fusion.h"
|
||||||
|
#include "debug/anf_ir_dump.h"
|
||||||
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
// Inclusive bounds on the number of nodes a record must hold to be kept as a fusion candidate.
constexpr int8_t MAX_PATTERN_SIZE = 7;
constexpr int8_t MIN_PATTERN_SIZE = 2;
// Minimum inputs().size() a node needs before its input 1 is followed.
constexpr int8_t ELTWISE_INPUT_SIZE = 2;
// An eltwise node with at most this many users may extend a non-empty record.
constexpr int8_t ELTWISE_USE = 1;
// A node with at least this many users stops the walk.
constexpr int8_t MULTI_ELTWISE_USE = 2;
// Size limit for the first (eltwise) walk and for the record after it.
constexpr int8_t MAX_MULTI_ELTWISE_SIZE = 4;
// Size limit for the second walk that follows a COMMREDUCE / SEGMENT match.
constexpr int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3;
// Attribute key string; not referenced in the visible part of this file.
constexpr auto kOpAttrFusionId = "fusion_id";
|
||||||
|
|
||||||
|
// Returns true when `node` is a TBE element-wise kernel that may join `record`:
// either the record is still empty, or the node has at most ELTWISE_USE users.
bool CheckEltWiseNode(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *record, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(record);
  MS_EXCEPTION_IF_NULL(node);
  // Only the count is needed, so read it in place rather than copying the whole user set.
  const auto user_count = manager->node_users()[node].size();
  return (AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
          AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE &&
          (user_count <= ELTWISE_USE || record->empty()));
}
|
||||||
|
|
||||||
|
// Common method to check for predecessors and successors in a fusion pattern
|
||||||
|
std::tuple<bool, CNodePtr> FindPredAndSuccEltWiseNodes(const int8_t &max_size, FuncGraphManager *manager,
|
||||||
|
std::unordered_set<AnfNodePtr> *visited_set,
|
||||||
|
std::deque<AnfNodePtr> *todo,
|
||||||
|
std::unordered_set<AnfNodePtr> *record, const CNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
MS_EXCEPTION_IF_NULL(visited_set);
|
||||||
|
MS_EXCEPTION_IF_NULL(todo);
|
||||||
|
MS_EXCEPTION_IF_NULL(record);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
|
||||||
|
CNodePtr new_node = node;
|
||||||
|
if (new_node->inputs().size() < ELTWISE_INPUT_SIZE) {
|
||||||
|
return std::make_tuple(false, new_node);
|
||||||
|
}
|
||||||
|
int8_t index = 1;
|
||||||
|
auto &users = manager->node_users();
|
||||||
|
while (CheckEltWiseNode(manager, record, new_node)) {
|
||||||
|
(void)record->insert(new_node);
|
||||||
|
(void)visited_set->insert(new_node);
|
||||||
|
(void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end());
|
||||||
|
|
||||||
|
auto cnode = new_node->input(1);
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (!cnode->isa<CNode>()) {
|
||||||
|
return std::make_tuple(false, new_node);
|
||||||
|
}
|
||||||
|
new_node = cnode->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(new_node);
|
||||||
|
|
||||||
|
if (!AnfAlgo::IsRealKernel(new_node) || new_node->inputs().size() < ELTWISE_INPUT_SIZE ||
|
||||||
|
users[(new_node)].size() >= MULTI_ELTWISE_USE || visited_set->find(new_node) != visited_set->end()) {
|
||||||
|
return std::make_tuple(false, new_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (index >= max_size) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
return std::make_tuple(true, new_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds `node` to the pattern and steps to the node feeding its input 1.
// `node` is rejected up front (nothing recorded) when it has MULTI_ELTWISE_USE or more users.
// Otherwise it is inserted into `record` and `visited_set`, and its inputs from index 1 on are
// appended to `todo`.
// Result:
//   {false, node}: `node` was rejected, has fewer than 2 inputs, or its input 1 is not a CNode.
//   {false, pred}: input 1 is not a real kernel, has too many users, or was already visited.
//   {true, pred}:  input 1 may be examined further by the caller.
std::tuple<bool, CNodePtr> MatchGeneralPattern(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *record,
                                               std::unordered_set<AnfNodePtr> *visited_set,
                                               std::deque<AnfNodePtr> *todo, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(record);
  MS_EXCEPTION_IF_NULL(visited_set);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(todo);

  auto &users = manager->node_users();
  CNodePtr new_node = node;
  // Builds the failure result around whichever node is current at the time of the call.
  const auto reject = [&new_node]() { return std::make_tuple(false, new_node); };

  if (users[(new_node)].size() >= MULTI_ELTWISE_USE) {
    return reject();
  }

  (void)record->insert(node);
  (void)visited_set->insert(node);
  (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end());

  if (node->inputs().size() < 2) {
    return reject();
  }

  // Only the first real input is followed here; the remaining ones were queued in `todo` above.
  auto cnode = node->input(1);
  MS_EXCEPTION_IF_NULL(cnode);
  if (!cnode->isa<CNode>()) {
    return reject();
  }
  new_node = cnode->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(new_node);

  if (!AnfAlgo::IsRealKernel(new_node) || users[(new_node)].size() >= MULTI_ELTWISE_USE ||
      visited_set->count(new_node) != 0) {
    return reject();
  }
  return std::make_tuple(true, new_node);
}
|
||||||
|
|
||||||
|
// Builds one fusion record starting from `node` and returns the node at which matching stopped.
// Steps:
//   1. Collect a chain of eltwise nodes (bounded by MAX_MULTI_ELTWISE_SIZE).
//   2. Depending on the fusion type of the node that ended the chain:
//        COMMREDUCE / SEGMENT: try MatchGeneralPattern, then continue with step 3;
//        ELEMWISE:             return it unchanged;
//        CONVLUTION:           add it to `record`, then handle it like any other type;
//        any other type:       mark it visited, queue its inputs in `todo`, return it.
//   3. Collect a further eltwise chain (bounded by MAX_PURE_BUFFER_SUCC_SIZE).
CNodePtr FindFusionAnfNode(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *visited_set,
                           std::unordered_set<AnfNodePtr> *record, std::deque<AnfNodePtr> *todo, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(visited_set);
  MS_EXCEPTION_IF_NULL(record);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(todo);

  // find fusion pattern predecessor nodes
  bool matched = false;
  CNodePtr new_node;
  std::tie(matched, new_node) =
    FindPredAndSuccEltWiseNodes(MAX_MULTI_ELTWISE_SIZE, manager, visited_set, todo, record, node);
  auto node_use_size = manager->node_users()[new_node].size();
  if (!matched || (record->size() > 1 && node_use_size > 1) || record->size() >= MAX_MULTI_ELTWISE_SIZE ||
      AnfAlgo::GetKernelType(new_node) != KernelType::TBE_KERNEL) {
    return new_node;
  }

  // key of fusion predecessor
  const auto node_fusion_type = AnfAlgo::GetFusionType(new_node);
  const bool is_general_type =
    node_fusion_type == kernel::FusionType::COMMREDUCE || node_fusion_type == kernel::FusionType::SEGMENT;
  if (!is_general_type) {
    if (node_fusion_type == kernel::FusionType::ELEMWISE) {
      return new_node;
    }
    if (node_fusion_type == kernel::FusionType::CONVLUTION) {
      // A convolution joins the record and is then treated like every remaining type.
      (void)record->insert(new_node);
    }
    (void)visited_set->insert(new_node);
    if (new_node != nullptr) {
      (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end());
    }
    return new_node;
  }

  std::tie(matched, new_node) = MatchGeneralPattern(manager, record, visited_set, todo, new_node);
  if (!matched) {
    return new_node;
  }

  // find fusion pattern successor nodes
  return std::get<1>(
    FindPredAndSuccEltWiseNodes(MAX_PURE_BUFFER_SUCC_SIZE, manager, visited_set, todo, record, new_node));
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/**
 * Worklist traversal that starts at input 1 of the return node and groups kernels into fusion
 * records by fusion type.
 *
 * Every record holding between MIN_PATTERN_SIZE and MAX_PATTERN_SIZE nodes is appended to
 * `candidate_fusion` and receives a fusion id.
 *
 * @param kernel_graph     Graph to scan.
 * @param candidate_fusion Output list of fusion records; must not be null.
 */
void FusionTypeFusionPass::MatchFusionTypePattern(const session::KernelGraph &kernel_graph,
                                                  FusedNodeRecord *candidate_fusion) {
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(candidate_fusion);

  auto return_node = kernel_graph.get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  if (return_node->inputs().size() <= 1) {
    return;
  }

  std::unordered_set<AnfNodePtr> visited_set;
  std::deque<AnfNodePtr> todo{return_node->input(1)};

  while (!todo.empty()) {
    auto node = todo.front();
    MS_EXCEPTION_IF_NULL(node);
    todo.pop_front();

    if (visited_set.count(node) != 0 || fusion_id_allocator->HasFusionIdAttr(node)) {
      continue;
    }

    // Only fuse real cnode; for any other CNode just keep walking through its inputs.
    if (!AnfAlgo::IsRealCNodeKernel(node)) {
      if (auto pass_through = node->cast<CNodePtr>()) {
        (void)todo.insert(todo.end(), pass_through->inputs().begin() + 1, pass_through->inputs().end());
      }
      continue;
    }

    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);

    // cnode maybe updated: the matcher returns the node at which it stopped.
    std::unordered_set<AnfNodePtr> record;
    cnode = FindFusionAnfNode(manager.get(), &visited_set, &record, &todo, cnode);

    const auto record_size = record.size();
    if (record_size >= MIN_PATTERN_SIZE && record_size <= MAX_PATTERN_SIZE) {
      candidate_fusion->push_back(record);
      SetRecordFusionId(record);
    }
    if (record.count(cnode) == 0) {
      todo.push_back(cnode);
    }
    // no node matched
    if (record.empty()) {
      (void)visited_set.insert(node);
    }
    (void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  }
}
|
||||||
|
|
||||||
|
/**
 * Runs the fusion-type based matcher over `kernel_graph`.
 *
 * @return true when at least one fusion record was collected, false otherwise (including when the
 *         return node has no input beyond index 0).
 */
bool FusionTypeFusionPass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) {
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto return_node = kernel_graph.get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  const auto return_input_count = return_node->inputs().size();
  if (return_input_count <= 1) {
    return false;
  }

  MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
  FusedNodeRecord candidate_fusion;
  MatchFusionTypePattern(kernel_graph, &candidate_fusion);

  if (!candidate_fusion.empty()) {
    MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
    return true;
  }
  return false;
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,47 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_TYPE_FUSION_PASS_H_
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_TYPE_FUSION_PASS_H_
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "pre_activate/common/pass.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
#include "device/kernel_info.h"
|
||||||
|
#include "kernel/kernel.h"
|
||||||
|
#include "session/kernel_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
||||||
|
|
||||||
|
class FusionTypeFusionPass : public FusionBasePass {
|
||||||
|
public:
|
||||||
|
FusionTypeFusionPass() : FusionBasePass("FusionTypeFusionPass") {}
|
||||||
|
explicit FusionTypeFusionPass(FusionIdAllocator *idAllocator) : FusionBasePass("FusionTypeFusionPass", idAllocator) {}
|
||||||
|
~FusionTypeFusionPass() override = default;
|
||||||
|
bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_TYPE_FUSION_PASS_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,50 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_TBE_BUFFER_FUSION_H_
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_TBE_BUFFER_FUSION_H_
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "pre_activate/common/pass.h"
|
||||||
|
#include "pre_activate/common/fusion_id_allocator.h"
|
||||||
|
#include "device/kernel_info.h"
|
||||||
|
#include "kernel/kernel.h"
|
||||||
|
#include "session/kernel_graph.h"
|
||||||
|
#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
||||||
|
|
||||||
|
class TbeBufferFusion : public Pass {
|
||||||
|
public:
|
||||||
|
TbeBufferFusion() : Pass("TbeBufferFusion") {}
|
||||||
|
~TbeBufferFusion() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &graph) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
||||||
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const;
|
||||||
|
bool ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id,
|
||||||
|
const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const;
|
||||||
|
bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_TBE_BUFFER_FUSION_H_
|
Loading…
Reference in new issue