[MKL-DNN] Add Fully Connected Op for inference only(#15226)
* fuse mul and elementwise add to fc * Reimplement the FC forward operator * Fix FC MKLDNN integration by transposing weights * Add FC MKLDNN Pass test=develop * FC MKLDNN Pass: change memcpy to std::copy * Fix MKLDNN FC handling of mismatch input and weights dims * Lower tolerance for MKL-DNN in resnet50 test test=develop * Adjust FC to support MKLDNN Op placement test=develop * Adjust Placement Op to set use_mkldnn attribute for graph test=develop * MKLDNN FC: fix weights format so that gemm version is called test=develop * FC MKLDNN: Remove tolerance decrease from tester_helper * FC MKL-DNN: Refactor the code, change input reorder to weight reorder * MKL-DNN FC: Introduce operator caching test=develop * FC MKL-DNN: Fix the tensor type in ExpectedKernelType test=develop * FC MKL-DNN: fix style changes test=develop * FC MKL-DNN: fallback to native on non-supported dim sizes test=develop * FC MKLDNN: fix CMake paths test=develop * FC MKLDNN: Refine placement pass graph mkldnn attribute test=develop * Fix Transpiler error for fuse_conv_eltwise test=develop * Fix missing STL includes in files test=develop * FC MKL-DNN: Enable new output size computation Also, refine pass to comply with newest interface. test=develop * FC MKL-DNN: enable only when fc_mkldnn_pass is enabled * FC MKL-DNN: Allow Weights to use oi or io format * FC MKL-DNN: Adjust UT to work with correct dims test=develop * Enable MKL DEBUG for resnet50 analyzer test=develop * FC MKL-DNN: Improve Hashing function test=develop * FC MKL-DNN: Fix shape for fc weights in transpiler * FC MKL-DNN: Update input pointer in re-used fc primitive * Add log for not handling fc fuse for unsupported dims test=develop * FC MKL-DNN: Move transpose from pass to Op Kernel test=develop * FC MKL-DNN: Disable transpose in unit test test=develop * FC MKL-DNN: Remove fc_mkldnn_pass from default list * Correct Flag for fake data analyzer tests test=develop * FC MKL-DNN: Add comment about fc mkldnn pass disablement test=develop * FC MKL-DNN: Disable fc in int8 tests test=developfix_ema
parent
21138eb12a
commit
0c39b97b4e
@ -0,0 +1,77 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
|
||||
PADDLE_ENFORCE(graph);
|
||||
Init("fc_mkldnn_pass", graph);
|
||||
|
||||
auto* scope = param_scope();
|
||||
PADDLE_ENFORCE(scope);
|
||||
|
||||
GraphPatternDetector gpd;
|
||||
auto* x = gpd.mutable_pattern()
|
||||
->NewNode("fc_mkldnn_pass/x")
|
||||
->AsInput()
|
||||
->assert_is_op_input("fc", "Input");
|
||||
patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
|
||||
fc_pattern(x, true /*with bias*/);
|
||||
|
||||
int found_fc_count = 0;
|
||||
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
||||
Graph* g) {
|
||||
VLOG(4) << "Handle FC MKL-DNN pass";
|
||||
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
|
||||
VLOG(3) << "do not perform fc fuse";
|
||||
return;
|
||||
}
|
||||
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
|
||||
|
||||
OpDesc* desc = fc->Op();
|
||||
auto in_size = fc->inputs[0]->Var()->GetShape().size();
|
||||
if (in_size != 2 && in_size != 4) {
|
||||
VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4";
|
||||
return;
|
||||
}
|
||||
desc->SetAttr("use_mkldnn", true);
|
||||
PADDLE_ENFORCE(subgraph.count(x));
|
||||
|
||||
found_fc_count++;
|
||||
};
|
||||
|
||||
gpd(graph, handler);
|
||||
|
||||
AddStatis(found_fc_count);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(fc_mkldnn_pass, paddle::framework::ir::FCMKLDNNPass);
|
||||
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
/*
|
||||
* Transpose weights of FC to comply with MKL-DNN interface
|
||||
*/
|
||||
class FCMKLDNNPass : public FusePassBase {
|
||||
public:
|
||||
virtual ~FCMKLDNNPass() {}
|
||||
|
||||
protected:
|
||||
void ApplyImpl(ir::Graph* graph) const;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue