parent
23fc896bc2
commit
603ba5e01d
@ -0,0 +1,101 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
|
||||||
|
#include <string>
|
||||||
|
#include "paddle/fluid/framework/lod_tensor.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace ir {
|
||||||
|
|
||||||
|
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
|
||||||
|
GraphPatternDetector gpd;
|
||||||
|
auto* pattern = gpd.mutable_pattern();
|
||||||
|
|
||||||
|
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
|
||||||
|
->assert_is_op_input("sequence_conv")
|
||||||
|
->assert_var_not_persistable();
|
||||||
|
patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
|
||||||
|
fuse_pattern(x);
|
||||||
|
|
||||||
|
// Create New OpDesc
|
||||||
|
auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
|
||||||
|
Node* eltadd_bias, Node* relu_out) {
|
||||||
|
OpDesc op_desc;
|
||||||
|
op_desc.SetType("fusion_seqconv_eltadd_relu");
|
||||||
|
op_desc.SetInput("X", {input->Name()});
|
||||||
|
op_desc.SetInput("Filter", {seqconv_weight->Name()});
|
||||||
|
op_desc.SetInput("Bias", {eltadd_bias->Name()});
|
||||||
|
op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
|
||||||
|
op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
|
||||||
|
op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
|
||||||
|
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
|
||||||
|
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
|
||||||
|
const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
|
||||||
|
op_desc.SetOutput("ColMat", {ColMat});
|
||||||
|
op_desc.SetOutput("Out", {relu_out->Name()});
|
||||||
|
scope->Var(ColMat)->GetMutable<LoDTensor>();
|
||||||
|
|
||||||
|
auto* op = graph->CreateOpNode(&op_desc);
|
||||||
|
IR_NODE_LINK_TO(input, op);
|
||||||
|
IR_NODE_LINK_TO(seqconv_weight, op);
|
||||||
|
IR_NODE_LINK_TO(eltadd_bias, op);
|
||||||
|
IR_NODE_LINK_TO(op, relu_out);
|
||||||
|
return op;
|
||||||
|
};
|
||||||
|
|
||||||
|
int fusion_count{0};
|
||||||
|
|
||||||
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
||||||
|
Graph* g) {
|
||||||
|
VLOG(4) << "handle SeqConv EltAdd Relu fuse";
|
||||||
|
GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
|
||||||
|
GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
|
||||||
|
GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
|
||||||
|
GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
|
||||||
|
GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
|
||||||
|
GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
|
||||||
|
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
|
||||||
|
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);
|
||||||
|
|
||||||
|
fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
|
||||||
|
relu_out);
|
||||||
|
std::unordered_set<const Node*> marked_nodes(
|
||||||
|
{seqconv, seqconv_out, eltadd, eltadd_out, relu});
|
||||||
|
GraphSafeRemoveNodes(graph, marked_nodes);
|
||||||
|
++fusion_count;
|
||||||
|
};
|
||||||
|
|
||||||
|
gpd(graph, handler);
|
||||||
|
|
||||||
|
return fusion_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
|
||||||
|
std::unique_ptr<ir::Graph> graph) const {
|
||||||
|
FusePassBase::Init(name_scope_, graph.get());
|
||||||
|
|
||||||
|
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
|
||||||
|
AddStatis(fusion_count);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ir
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
|
||||||
|
paddle::framework::ir::SeqConvEltAddReluFusePass);
|
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||||
|
#include "paddle/fluid/framework/ir/graph.h"
|
||||||
|
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace ir {
|
||||||
|
|
||||||
|
class SeqConvEltAddReluFusePass : public FusePassBase {
|
||||||
|
public:
|
||||||
|
virtual ~SeqConvEltAddReluFusePass() {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||||
|
|
||||||
|
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ir
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
Loading…
Reference in new issue