commit
a314a80cdb
@ -0,0 +1,46 @@
|
||||
# Build xxHash (v0.6.5) from source as an external project and expose it to
# the rest of the build as the imported static library target `xxhash`.
#
# Inputs : THIRD_PARTY_PATH, WITH_STATIC_LIB, WITH_C_API, ANDROID/ANDROID_ABI,
#          EXTERNAL_PROJECT_LOG_ARGS
# Outputs: XXHASH_SOURCE_DIR, XXHASH_INSTALL_DIR, XXHASH_INCLUDE_DIR,
#          XXHASH_LIBRARIES, targets `extern_xxhash` and `xxhash`
include(ExternalProject)

set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash)
set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash)
set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include")

if(WITH_STATIC_LIB)
  set(BUILD_CMD make lib)
else()
  # Upstream's Makefile does not compile the static archive with -fPIC, which
  # breaks linking it into shared libraries; patch the flags in before building.
  set(BUILD_CMD sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib)
endif()

ExternalProject_Add(
  extern_xxhash
  ${EXTERNAL_PROJECT_LOG_ARGS}
  GIT_REPOSITORY "https://github.com/Cyan4973/xxHash"
  GIT_TAG "v0.6.5"
  PREFIX ${XXHASH_SOURCE_DIR}
  DOWNLOAD_NAME "xxhash"
  UPDATE_COMMAND ""
  PATCH_COMMAND ""
  CONFIGURE_COMMAND ""
  BUILD_IN_SOURCE 1
  BUILD_COMMAND ${BUILD_CMD}
  # xxHash's Makefile honors $PREFIX for its install destination.
  INSTALL_COMMAND export PREFIX=${XXHASH_INSTALL_DIR}/ && make install
  TEST_COMMAND ""
)

set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")

# Wrap the built archive in an imported target; the extern_xxhash dependency
# orders the external build before anything that links `xxhash`.
add_library(xxhash STATIC IMPORTED GLOBAL)
set_property(TARGET xxhash PROPERTY IMPORTED_LOCATION ${XXHASH_LIBRARIES})
include_directories(${XXHASH_INCLUDE_DIR})
add_dependencies(xxhash extern_xxhash)

list(APPEND external_project_dependencies xxhash)

if(WITH_C_API)
  install(DIRECTORY ${XXHASH_INCLUDE_DIR} DESTINATION third_party/xxhash)
  if(ANDROID)
    install(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib/${ANDROID_ABI})
  else()
    install(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib)
  endif()
endif()
|
@ -0,0 +1,137 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
template <typename BinaryOperation>
|
||||
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
|
||||
BinaryOperation f) {
|
||||
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
|
||||
LoDTensor vec_y;
|
||||
vec_y.Resize(vec_a.dims());
|
||||
const float* a = vec_a.data<float>();
|
||||
const float* b = vec_b.data<float>();
|
||||
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
|
||||
for (int i = 0; i < vec_a.numel(); i++) {
|
||||
y[i] = f(a[i], b[i]);
|
||||
}
|
||||
return vec_y;
|
||||
}
|
||||
|
||||
// Fuses conv2d + elementwise_add(bias) subgraphs into a single conv2d op that
// carries the bias directly (MKL-DNN path). For each match, either the eltwise
// bias is folded into an existing conv Bias tensor, or a new conv2d op taking
// the eltwise bias as its Bias input replaces the pair.
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // The parameter scope holds the bias tensors we read and mutate below.
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  // Build the pattern: a conv2d Input node feeding the ConvBias pattern
  // (conv2d followed by elementwise_add).
  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bias_pattern(conv_input);
  int found_conv_bias_count = 0;
  // Handler invoked once per matched subgraph; performs the actual rewrite.
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBias fuse";
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_bias_pattern);              // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern);  // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern);  // CONV op
    // bias
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
    // output
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
    // elementwise_add op
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);

    PADDLE_ENFORCE(subgraph.count(conv_input));

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
    if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
      VLOG(3) << "do not perform conv+bias fuse";
      return;
    }

    auto* eltwise_bias_tensor =
        scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();

    // Does the conv op already declare (and actually use) a Bias input?
    auto input_names = conv->Op()->InputNames();
    bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                    input_names.end();
    if (has_bias && conv->Op()->Input("Bias").size() > 0) {
      auto conv_bias_names = conv->Op()->Input("Bias");
      // add eltwise bias to existing conv bias
      PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
      auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
      auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
      PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
      // Fold the eltwise bias into the conv bias in place: b := b + b_eltwise.
      *conv_bias_tensor = tensor_apply_eltwise(
          *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());

      // Rewire conv to write directly to the eltwise output, then drop the
      // now-redundant elementwise_add op and the intermediate conv output.
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({eltwise_out->Name()}));

      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});

      IR_NODE_LINK_TO(conv, eltwise_out);
    } else {
      // take eltwise bias as conv bias
      OpDesc desc;

      desc.SetInput(
          "Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
      desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
      desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
      desc.SetType("conv2d");

      // Preserve every attribute of the original conv op on the replacement.
      for (auto& attr : conv->Op()->GetAttrMap()) {
        desc.SetAttr(attr.first, attr.second);
      }
      auto conv_bias_node = g->CreateOpNode(&desc);

      IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
      IR_NODE_LINK_TO(conv_weight, conv_bias_node);
      IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
      IR_NODE_LINK_TO(conv_bias_node, eltwise_out);

      // Remove the original conv, the elementwise_add, and the intermediate
      // output; the new fused node replaces all three.
      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
    }

    found_conv_bias_count++;
  };
  gpd(graph.get(), handler);
  // Record how many fusions were applied for pass statistics/reporting.
  AddStatis(found_conv_bias_count);
  return graph;
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
              paddle::framework::ir::ConvBiasFusePass);
|
@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
/*
 * Fuse the Conv and Elementwise_add to a ConvBiasOp.
 */
class ConvBiasFusePass : public FusePassBase {
 public:
  virtual ~ConvBiasFusePass() {}

 protected:
  // Scans the graph for conv2d + elementwise_add pairs and fuses each pair
  // into a single conv2d op carrying the bias (see the .cc for the rewrite).
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
  // Scope name used to prefix pattern node names and pass statistics.
  const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue