Enhance fc_fuse_pass to enable fusing relu to fc_op (#19733)
* Refine the codes related to fc op.
* Add GPU implementation for fc functor.
* Apply fc_fuse_pass in GPU inference. test=develop
* Change the cmake for fc op.
* Change PADDLE_ENFORCE to PADDLE_ENFORCE_EQ.
* Add an attribute to set the activation type in fc_op.
* Enhance the unittest of fc_op. test=develop
* Remove the declaration of FCOpGrad back to the header file. test=develop
* Set default value for newly added arguments in test_fc_op. test=develop
* Enhance fc_fuse_pass to enable fusing relu.
* Allow print the shapes of var_desc in graph. test=develop
* Enhance fc_fuse_pass_tester.
* Remove the use of PADDLE_ENFORCE. test=develop
* Correct the number of ops after fusing. test=develop
* Fix a typo. test=develop
* Set activation_type to null when there is no relu in fc. test=develop
* Refine fc_fuse_pass's codes.
* Enable the set of shape for tensor.
* Refine repeated_fc_relu_pass and add unittest. test=develop
parent b5a5d93bbe
commit c67c8758cb
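The commit message above mentions adding an activation_type attribute to fc_op so that fc_fuse_pass can fold a trailing relu into the fused op. Below is a minimal, illustrative sketch (not part of this commit's diff) of how such a fused fc op could be described through the framework's ProgramDesc/OpDesc API; the helper name BuildFusedFCWithRelu and the variable names are hypothetical.

// Illustrative sketch only: shows the shape of a fused fc op once
// fc_fuse_pass has absorbed a trailing relu. Not taken from the diff.
#include <string>

#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace framework {

void BuildFusedFCWithRelu(ProgramDesc* prog) {
  BlockDesc* block = prog->MutableBlock(0);
  // Declare the variables the op refers to (shapes omitted for brevity).
  for (auto* name : {"x", "w", "bias", "out"}) {
    block->Var(name);
  }
  OpDesc* fc = block->AppendOp();
  fc->SetType("fc");
  fc->SetInput("Input", {"x"});
  fc->SetInput("W", {"w"});
  fc->SetInput("Bias", {"bias"});
  fc->SetOutput("Out", {"out"});
  fc->SetAttr("in_num_col_dims", 1);
  // The attribute introduced by this change: records the fused activation
  // instead of keeping a separate relu op in the graph.
  fc->SetAttr("activation_type", std::string("relu"));
}

}  // namespace framework
}  // namespace paddle

When no relu follows the fc pattern, activation_type is left empty, which is also how the tester below marks the last fc in its chain.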
@@ -0,0 +1,71 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

void TestMain(int num_fc) {
  //  inputs                                 operator    output
  //  -------------------------------------------------------------
  //  (x, filters, bias_0)                   conv2d   -> conv2d_out
  //  (conv2d_out, fc_weights_0, fc_bias_0)  fc       -> fc_out_0
  //  (fc_out_0, fc_weights_1, fc_bias_1)    fc       -> fc_out_1
  //  ...
  Layers layers;
  VarDesc* x = layers.data("x");
  VarDesc* filters = layers.data("filters", {}, true);
  VarDesc* bias_0 = layers.data("bias_0", {}, true);
  VarDesc* conv2d_out = layers.conv2d(x, filters, bias_0);
  VarDesc* fc_in = conv2d_out;
  for (int i = 0; i < num_fc; ++i) {
    VarDesc* weights_i =
        layers.data("fc_weights_" + std::to_string(i), {}, true);
    VarDesc* bias_i = layers.data("fc_bias_" + std::to_string(i), {}, true);
    std::string activation_type = i < (num_fc - 1) ? "relu" : "";
    VarDesc* fc_out = layers.fc(fc_in, weights_i, bias_i, 1, activation_type);
    fc_in = fc_out;
  }

  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  auto pass = PassRegistry::Instance().Get("repeated_fc_relu_fuse_pass");
  int num_nodes_before = graph->Nodes().size();
  int num_fc_nodes_before = GetNumOpNodes(graph, "fc");
  VLOG(3) << DebugString(graph);

  graph.reset(pass->Apply(graph.release()));
  int num_nodes_after = graph->Nodes().size();
  int num_fused_nodes_after = GetNumOpNodes(graph, "fusion_repeated_fc_relu");
  VLOG(3) << DebugString(graph);

  // Delete (num_fc_nodes_before - 1) fc ops
  PADDLE_ENFORCE_EQ(num_nodes_before - (num_fc_nodes_before - 1) + 1,
                    num_nodes_after);
  PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
}

TEST(RepeatedFCReluFusePass, basic_3) { TestMain(3); }

TEST(RepeatedFCReluFusePass, basic_9) { TestMain(9); }

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(repeated_fc_relu_fuse_pass);