|
|
|
@ -93,6 +93,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
|
|
|
|
|
bool use_gpu = Has("use_gpu") ? Get<bool>("use_gpu") : false;
|
|
|
|
|
bool use_fc_padding =
|
|
|
|
|
Has("use_fc_padding") ? Get<bool>("use_fc_padding") : true;
|
|
|
|
|
const std::string& w_name = patterns::UniqueKey(w->Name());
|
|
|
|
|
VarDesc w_key(w_name);
|
|
|
|
|
w_key.SetPersistable(true);
|
|
|
|
|
auto* w_node = g->CreateVarNode(&w_key);
|
|
|
|
|
if (!use_gpu && use_fc_padding) {
|
|
|
|
|
auto* scope = param_scope();
|
|
|
|
|
auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
|
|
|
|
@ -102,20 +106,25 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
|
|
|
|
|
int w_h = weight_dims[0];
|
|
|
|
|
int w_w = weight_dims[1];
|
|
|
|
|
if (w_h % 128 == 0 && w_w % 128 == 0) {
|
|
|
|
|
auto* w_var = scope->Var(w_name);
|
|
|
|
|
auto* w_tensor = w_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
auto* weight_data_tmp = new float[weight_num];
|
|
|
|
|
for (int i = 0; i < w_h; i++) {
|
|
|
|
|
memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w,
|
|
|
|
|
w_w * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4});
|
|
|
|
|
w_tensor->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4});
|
|
|
|
|
auto* weight_data_new =
|
|
|
|
|
weight->mutable_data<float>(platform::CPUPlace());
|
|
|
|
|
w_tensor->mutable_data<float>(platform::CPUPlace());
|
|
|
|
|
for (int i = 0; i < w_h; i++) {
|
|
|
|
|
memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w,
|
|
|
|
|
w_w * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
delete[] weight_data_tmp;
|
|
|
|
|
desc.SetInput("W", {w_name});
|
|
|
|
|
desc.SetAttr("padding_weights", true);
|
|
|
|
|
desc.Flush();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -147,7 +156,12 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(subgraph.at(x), fc_node);
|
|
|
|
|
if (desc.GetAttrIfExists<bool>("padding_weights")) {
|
|
|
|
|
IR_NODE_LINK_TO(w_node, fc_node);
|
|
|
|
|
} else {
|
|
|
|
|
GraphSafeRemoveNodes(g, {w_node});
|
|
|
|
|
IR_NODE_LINK_TO(w, fc_node);
|
|
|
|
|
}
|
|
|
|
|
IR_NODE_LINK_TO(bias, fc_node);
|
|
|
|
|
if (with_relu) {
|
|
|
|
|
IR_NODE_LINK_TO(fc_node, relu_out);
|
|
|
|
|