|
|
|
@ -36,9 +36,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
->assert_var_not_persistable();
|
|
|
|
|
patterns::FC fc_pattern(pattern, name_scope);
|
|
|
|
|
|
|
|
|
|
// fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
|
|
|
|
|
auto* fc_out =
|
|
|
|
|
fc_pattern(x, with_fc_bias, /* with_relu */ false)->AsIntermediate();
|
|
|
|
|
auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
|
|
|
|
|
patterns::LSTM lstm_pattern(pattern, name_scope);
|
|
|
|
|
lstm_pattern(fc_out);
|
|
|
|
|
|
|
|
|
@ -58,28 +56,25 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
// Add FC-bias with LSTM-bias and create a new weight
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
|
|
|
|
|
const std::string& new_bias_var = patterns::UniqueKey("NewBias");
|
|
|
|
|
auto* bias_var = scope->Var(new_bias_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(bias_var, platform::errors::InvalidArgument(
|
|
|
|
|
"Bias var ptr cannot be nullptr."));
|
|
|
|
|
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto* lstm_bias_var = scope->FindVar(bias->Name());
|
|
|
|
|
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(lstm_bias_var,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Lstm bias var ptr cannot be nullptr."));
|
|
|
|
|
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
|
|
|
|
|
bias_tensor->Resize(lstm_bias_tensor.dims());
|
|
|
|
|
|
|
|
|
|
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(fc_bias_var,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"FC bias var ptr cannot be nullptr."));
|
|
|
|
|
auto* lstm_bias_tensor =
|
|
|
|
|
lstm_bias_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace());
|
|
|
|
|
auto lstm_bias_data =
|
|
|
|
|
lstm_bias_tensor->mutable_data<float>(platform::CPUPlace());
|
|
|
|
|
auto* fc_bias_data = fc_bias_tensor.data<float>();
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < bias_tensor->numel(); i++) {
|
|
|
|
|
data[i] =
|
|
|
|
|
fc_bias_tensor.data<float>()[i] + lstm_bias_tensor.data<float>()[i];
|
|
|
|
|
for (int i = 0; i < lstm_bias_tensor->numel(); i++) {
|
|
|
|
|
lstm_bias_data[i] += fc_bias_data[i];
|
|
|
|
|
}
|
|
|
|
|
op_desc.SetInput("Bias", {new_bias_var});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("H0", {});
|
|
|
|
@ -114,6 +109,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
IR_NODE_LINK_TO(weight_h, op);
|
|
|
|
|
IR_NODE_LINK_TO(bias, op);
|
|
|
|
|
IR_NODE_LINK_TO(op, hidden);
|
|
|
|
|
IR_NODE_LINK_TO(op, cell);
|
|
|
|
|
IR_NODE_LINK_TO(op, xx);
|
|
|
|
|
|
|
|
|
|
#define IR_NODE(x) \
|
|
|
|
|
VarDesc key_##x(x); \
|
|
|
|
|