|
|
|
@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
if (with_fc_bias) {
|
|
|
|
|
// Add FC-bias with LSTM-bias and create a new weight
|
|
|
|
|
PADDLE_ENFORCE(scope);
|
|
|
|
|
const std::string& new_bias_var = name_scope + "_bias.new";
|
|
|
|
|
const std::string& new_bias_var = patterns::UniqueKey("NewBias");
|
|
|
|
|
auto* bias_var = scope->Var(new_bias_var);
|
|
|
|
|
PADDLE_ENFORCE(bias_var);
|
|
|
|
|
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
|
|
|
|
@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
|
|
|
|
@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
fc_bias);
|
|
|
|
|
// Remove unneeded nodes.
|
|
|
|
|
std::unordered_set<const Node*> marked_nodes(
|
|
|
|
|
{mul, lstm, elementwise_add});
|
|
|
|
|
{mul, lstm, elementwise_add, fc_bias});
|
|
|
|
|
GraphSafeRemoveNodes(graph, marked_nodes);
|
|
|
|
|
} else {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
|
|
|
|
|