|
|
|
@ -566,8 +566,8 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::
|
|
|
|
|
// Elemwise-->DepthwiseConvolution
|
|
|
|
|
auto relu = cnode->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(relu);
|
|
|
|
|
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() ||
|
|
|
|
|
AnfAlgo::GetCNodeName() == kReluV2OpName) {
|
|
|
|
|
if (cnode->isa<CNode>() &&
|
|
|
|
|
(AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() || AnfAlgo::GetCNodeName(relu) == kReluV2OpName)) {
|
|
|
|
|
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
|
|
|
|
|
std::unordered_set<AnfNodePtr> record{cnode, relu};
|
|
|
|
|