|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -278,9 +279,48 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
// update weights and biases
|
|
|
|
|
float epsilon =
|
|
|
|
|
BOOST_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
|
|
|
|
|
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
|
|
|
|
|
*bn_mean, *bn_variance, eltwise_y_in_tensor,
|
|
|
|
|
epsilon, conv_type());
|
|
|
|
|
|
|
|
|
|
// if bias is an input to other ops as well then we cannot overwrite it
|
|
|
|
|
// so we create separate elementwise Y in nodes
|
|
|
|
|
if (eltwise_y_in->outputs.size() > 1) {
|
|
|
|
|
// Make a copy of eltwise Y input tensor
|
|
|
|
|
// Create eltwise_y (conv bias) variable
|
|
|
|
|
VarDesc eltwise_y_in_desc(patterns::PDNodeName(
|
|
|
|
|
name_scope_, "eltwise_y_in" + std::to_string(found_conv_bn_count)));
|
|
|
|
|
eltwise_y_in_desc.SetShape(
|
|
|
|
|
framework::vectorize(eltwise_y_in_tensor->dims()));
|
|
|
|
|
eltwise_y_in_desc.SetDataType(eltwise_y_in_tensor->type());
|
|
|
|
|
eltwise_y_in_desc.SetLoDLevel(eltwise_y_in->Var()->GetLoDLevel());
|
|
|
|
|
eltwise_y_in_desc.SetPersistable(true);
|
|
|
|
|
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
|
|
|
|
|
auto* eltwise_y_in_tensor_ex =
|
|
|
|
|
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
|
|
|
|
|
|
|
|
|
|
// Initialize eltwise_y
|
|
|
|
|
TensorCopy(*eltwise_y_in_tensor, platform::CPUPlace(),
|
|
|
|
|
eltwise_y_in_tensor_ex);
|
|
|
|
|
|
|
|
|
|
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
|
|
|
|
|
*bn_mean, *bn_variance, eltwise_y_in_tensor_ex,
|
|
|
|
|
epsilon, conv_type());
|
|
|
|
|
// Set new var
|
|
|
|
|
eltwise->Op()->RenameInput(eltwise_y_in->Name(),
|
|
|
|
|
eltwise_y_in_node->Name());
|
|
|
|
|
// Link new bias node to eltwise
|
|
|
|
|
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise);
|
|
|
|
|
// unlink original bias from eltwise_op
|
|
|
|
|
eltwise_y_in->outputs.erase(
|
|
|
|
|
std::remove_if(eltwise_y_in->outputs.begin(),
|
|
|
|
|
eltwise_y_in->outputs.end(),
|
|
|
|
|
[&](Node*& n) {
|
|
|
|
|
return n->id() == eltwise->id() ? true : false;
|
|
|
|
|
}),
|
|
|
|
|
eltwise_y_in->outputs.end());
|
|
|
|
|
} else {
|
|
|
|
|
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
|
|
|
|
|
*bn_mean, *bn_variance, eltwise_y_in_tensor,
|
|
|
|
|
epsilon, conv_type());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Update the elementwise_add node
|
|
|
|
|
eltwise->Op()->SetAttr("axis", 1);
|
|
|
|
|