test=develop

release/1.1
sneaxiy 7 years ago
commit 5a38930660

@ -2,8 +2,8 @@
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html) [![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.0/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html) [![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 1.0.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.0.0) ### Latest PaddlePaddle Release: [Fluid 1.0.1](https://github.com/PaddlePaddle/Paddle/tree/release/1.0.0)
### Install Latest Stable Release: ### Install Latest Stable Release:
``` ```
# Linux CPU # Linux CPU
@ -27,9 +27,9 @@ pip install paddlepaddle
# Linux GPU cuda9cudnn7 # Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7 # Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==0.15.0.post87 pip install paddlepaddle-gpu==1.0.1.post87
# Linux GPU cuda8cudnn5 # Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==0.15.0.post85 pip install paddlepaddle-gpu==1.0.1.post85
# For installation on other platform, refer to http://paddlepaddle.org/ # For installation on other platform, refer to http://paddlepaddle.org/
``` ```

@ -311,6 +311,8 @@ function(cc_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif() endif()
endfunction(cc_test) endfunction(cc_test)
@ -629,6 +631,8 @@ function(py_test TARGET_NAME)
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS} PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS} ${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif() endif()
endfunction() endfunction()

@ -61,12 +61,12 @@ paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None
paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100)) paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None)) paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None))
paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, False)) paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None)) paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None)) paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None)) paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False)) paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
@ -97,8 +97,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)

@ -10,7 +10,7 @@ function(pass_library TARGET DEST)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass ${op_library_DEPS}) cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically. # add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference") if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
message(STATUS "add pass ${TARGET} ${DEST}") message(STATUS "add pass ${TARGET} ${DEST}")
@ -25,13 +25,11 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base) pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base) pass_library(graph_viz_pass base)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
if (WITH_MKLDNN)
pass_library(conv_relu_mkldnn_fuse_pass inference)
endif ()
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference) pass_library(infer_clean_graph_pass inference)
pass_library(fc_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference)
@ -39,6 +37,10 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference)
pass_library(conv_bn_fuse_pass inference) pass_library(conv_bn_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base)
pass_library(conv_relu_mkldnn_fuse_pass inference)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )

@ -262,7 +262,7 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unordered_set<std::string> specified_vars({"data_lod_attention", std::unordered_set<std::string> specified_vars({"data_lod_attention",
"cell_init", "hidden_init", "cell_init", "hidden_init",
"data", "week", "minute"}); "data", "week", "minute"});
int count = 0; size_t count = 0;
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (node->IsVar() && specified_vars.count(node->Name())) { if (node->IsVar() && specified_vars.count(node->Name())) {
++count; ++count;

@ -126,12 +126,21 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
// conv, batch_norm, // conv, batch_norm,
// conv_weight, conv_out, // conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance, // bn_scale, bn_bias, bn_mean, bn_variance,
// bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
// bn_saved_variance
GET_CONV_BN_NODES(conv_bn_pattern); GET_CONV_BN_NODES(conv_bn_pattern);
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+bn fuse";
return;
}
// Create eltwise_y (conv bias) variable // Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc( VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in")); patterns::PDNodeName(name_scope_, "eltwise_y_in"));
eltwise_y_in_desc.SetPersistable(true);
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc); auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
auto* eltwise_y_in_tensor = auto* eltwise_y_in_tensor =
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>(); scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
@ -151,27 +160,59 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
*bn_mean, *bn_variance, eltwise_y_in_tensor, *bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon); epsilon);
// Create an elementwise add node // with MKL-DNN fuse conv+bn into conv with bias
OpDesc desc; // without MKL-DNN fuse conv+bn into conv+elementwise_add
desc.SetInput("X", std::vector<std::string>({conv_out->Name()})); if (fuse_option == FUSE_MKLDNN) {
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()})); auto input_names = conv->Op()->InputNames();
desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()})); bool has_bias = std::find(input_names.begin(), input_names.end(),
desc.SetType("elementwise_add"); "Bias") != input_names.end();
desc.SetAttr("axis", 1); if (has_bias && conv->Op()->Input("Bias").size() > 0) {
bool a = boost::get<bool>(conv->Op()->GetAttr("use_mkldnn")); // reuse existing conv bias node
desc.SetAttr("use_mkldnn", a); auto conv_bias_names = conv->Op()->Input("Bias");
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance, auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
batch_norm, bn_mean_out, bn_variance_out, PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(),
bn_saved_mean, bn_saved_variance}); eltwise_y_in_tensor->dims());
PADDLE_ENFORCE(subgraph.count(conv_input)); auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
IR_NODE_LINK_TO(conv_out, eltwise_op); eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op); } else {
IR_NODE_LINK_TO(eltwise_op, bn_out); // add new conv_bias node
conv->Op()->SetInput(
found_conv_bn_count++; "Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
IR_NODE_LINK_TO(eltwise_y_in_node, conv);
}
conv->Op()->SetOutput("Output",
std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv, bn_out);
found_conv_bn_count++;
} else { // fuse_option == FUSE_NATIVE
// create an elementwise add node.
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(
graph.get(),
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
IR_NODE_LINK_TO(eltwise_op, bn_out);
found_conv_bn_count++;
}
}; };
gpd(graph.get(), handler); gpd(graph.get(), handler);
@ -237,7 +278,6 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out}); bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(eltwise, bn_out); IR_NODE_LINK_TO(eltwise, bn_out);
found_conv_bn_count++; found_conv_bn_count++;

@ -46,6 +46,12 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
FuseOptions fuse_option = FindFuseOption(*conv, *relu);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+relu fuse";
return;
}
// Transform Conv node into ConvReLU node. // Transform Conv node into ConvReLU node.
OpDesc* desc = conv->Op(); OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()})); desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));

@ -20,17 +20,19 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) { const std::vector<std::string>& outputs, bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp(); auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type); op->SetType(type);
if (type == "conv2d") { if (type == "conv2d") {
op->SetAttr("use_mkldnn", true); op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]}); op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]}); op->SetInput("Bias", {inputs[2]});
} else if (type == "relu") { } else if (type == "relu") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs); op->SetInput("X", inputs);
} }
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
@ -43,7 +45,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
ProgramDesc BuildProgramDesc() { ProgramDesc BuildProgramDesc() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) { std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS); var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") { if (v == "weights" || v == "bias") {
@ -51,14 +54,24 @@ ProgramDesc BuildProgramDesc() {
} }
} }
SetOp(&prog, "OP0", std::vector<std::string>({"a"}), SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"})); std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"b"}), SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"})); std::vector<std::string>({"c"}));
SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}), // conv+relu, both with MKL-DNN
std::vector<std::string>({"f"})); SetOp(&prog, "conv2d", "conv1",
SetOp(&prog, "relu", std::vector<std::string>({"f"}), std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"g"})); std::vector<std::string>({"f"}), true);
SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}), true);
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}));
// conv+relu, only one with MKL-DNN
SetOp(&prog, "conv2d", "conv2",
std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}), true);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}));
return prog; return prog;
} }
@ -88,10 +101,16 @@ TEST(ConvReLUFusePass, basic) {
auto* op = node->Op(); auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn")); ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn"))); EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("fuse_relu")); // check if only "conv1" convolution is fused
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu")); auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (fuse_relu) { if (op_name == "conv1") {
++conv_relu_count; ASSERT_TRUE(op->HasAttr("fuse_relu"));
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
if (fuse_relu) {
++conv_relu_count;
}
} else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_relu"));
} }
} }
} }

@ -0,0 +1,62 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
void FusePassBase::Init(const std::string& repr, Graph* graph) const {
repr_ = repr;
graph_ = graph;
}
Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
void FusePassBase::AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_);
PADDLE_ENFORCE(!repr_.empty());
if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
}
auto& info =
graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
info[repr_] = count_of_fused;
}
FuseOptions FusePassBase::FindFuseOption(const Node& node1,
const Node& node2) const {
#ifdef PADDLE_WITH_MKLDNN
bool node1_mkldnn = node1.Op()->HasAttr("use_mkldnn") &&
boost::get<bool>(node1.Op()->GetAttr("use_mkldnn"));
bool node2_mkldnn = node2.Op()->HasAttr("use_mkldnn") &&
boost::get<bool>(node2.Op()->GetAttr("use_mkldnn"));
if (node1_mkldnn && node2_mkldnn)
return FUSE_MKLDNN;
else if (!node1_mkldnn && !node2_mkldnn)
return FUSE_NATIVE;
else
return DO_NOT_FUSE;
#else
return FUSE_NATIVE;
#endif
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -25,32 +25,24 @@ namespace ir {
static const char kParamScopeAttr[] = "__param_scope__"; static const char kParamScopeAttr[] = "__param_scope__";
static const char kFuseStatisAttr[] = "__fuse_statis__"; static const char kFuseStatisAttr[] = "__fuse_statis__";
enum FuseOptions {
DO_NOT_FUSE, // fusing will not be done
FUSE_NATIVE, // fusing will be done without MKL-DNN
FUSE_MKLDNN // fusing will be done with MKL-DNN
};
class FusePassBase : public Pass { class FusePassBase : public Pass {
public: public:
void Init(const std::string& repr, Graph* graph) const { void Init(const std::string& repr, Graph* graph) const;
repr_ = repr; Scope* param_scope() const;
graph_ = graph; void AddStatis(int count_of_fused) const;
}
Scope* param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
void AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_);
PADDLE_ENFORCE(!repr_.empty());
if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
}
auto& info =
graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
info[repr_] = count_of_fused;
}
virtual ~FusePassBase() {} virtual ~FusePassBase() {}
protected: protected:
virtual FuseOptions FindFuseOption(const Node& node1,
const Node& node2) const;
mutable Graph* graph_; mutable Graph* graph_;
mutable std::string repr_; mutable std::string repr_;
}; };

@ -259,6 +259,8 @@ GraphPatternDetector::DetectPatterns() {
return result; return result;
} }
// TODO(Superjomn) enhance the function as it marks unique unique as duplicates
// see https://github.com/PaddlePaddle/Paddle/issues/13550
void GraphPatternDetector::UniquePatterns( void GraphPatternDetector::UniquePatterns(
std::vector<GraphPatternDetector::subgraph_t> *subgraphs) { std::vector<GraphPatternDetector::subgraph_t> *subgraphs) {
if (subgraphs->empty()) return; if (subgraphs->empty()) return;

@ -0,0 +1,37 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Aplies MKL-DNN placement strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) {
n->Op()->SetAttr("use_mkldnn", true);
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(mkldnn_placement_pass,
paddle::framework::ir::MKLDNNPlacementPass);

@ -0,0 +1,31 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class MKLDNNPlacementPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -85,10 +85,6 @@ class CompileTimeInferShapeContext : public InferShapeContext {
VLOG(3) << "input " << in << " is not LodTensor"; VLOG(3) << "input " << in << " is not LodTensor";
return; return;
} }
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
out_var->SetLoDLevel(in_var->GetLoDLevel()); out_var->SetLoDLevel(in_var->GetLoDLevel());
} }

@ -126,7 +126,7 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
std::vector<std::string> feed_target_names; std::vector<std::string> feed_target_names;
for (auto *op : global_block.AllOps()) { for (auto *op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) { if (op->Type() == kFeedOpType) {
int col = boost::get<int>(op->GetAttr("col")); size_t col = boost::get<int>(op->GetAttr("col"));
if (col >= feed_target_names.size()) { if (col >= feed_target_names.size()) {
feed_target_names.resize(col + 1); feed_target_names.resize(col + 1);
} }
@ -143,7 +143,7 @@ const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
std::vector<std::string> fetch_target_names; std::vector<std::string> fetch_target_names;
for (auto *op : global_block.AllOps()) { for (auto *op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) { if (op->Type() == kFetchOpType) {
int col = boost::get<int>(op->GetAttr("col")); size_t col = boost::get<int>(op->GetAttr("col"));
if (col >= fetch_target_names.size()) { if (col >= fetch_target_names.size()) {
fetch_target_names.resize(col + 1); fetch_target_names.resize(col + 1);
} }

@ -39,7 +39,7 @@ TEST(READER, decorate_chain) {
{ {
auto endpoints = root->GetEndPoints(); auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U); ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(end_point1.get()), 0); ASSERT_NE(endpoints.count(end_point1.get()), 0UL);
ASSERT_NE(endpoints.count(end_point2.get()), 0); ASSERT_NE(endpoints.count(end_point2.get()), 0);
} }

@ -91,7 +91,7 @@ TEST(SelectedRows, SparseTable) {
ASSERT_TRUE(table.HasKey(10)); ASSERT_TRUE(table.HasKey(10));
ASSERT_TRUE(table.HasKey(8)); ASSERT_TRUE(table.HasKey(8));
ASSERT_TRUE(table.HasKey(6)); ASSERT_TRUE(table.HasKey(6));
ASSERT_EQ(table.rows().size(), 3); ASSERT_EQ(table.rows().size(), 3UL);
framework::Tensor ids; framework::Tensor ids;
ids.Resize(framework::make_ddim({4})); ids.Resize(framework::make_ddim({4}));

@ -101,7 +101,11 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) { void Analyzer::Run(Argument* argument) {
std::vector<std::string> passes; std::vector<std::string> passes;
for (auto& pass : all_ir_passes_) { if (use_mkldnn_) {
VLOG(3) << "Adding MKL-DNN placement pass";
passes.push_back("mkldnn_placement_pass");
}
for (auto& pass : ir_passes_) {
if (!disabled_ir_passes_.count(pass)) { if (!disabled_ir_passes_.count(pass)) {
passes.push_back(pass); passes.push_back(pass);
passes.push_back("graph_viz_pass"); // add graphviz for debug. passes.push_back("graph_viz_pass"); // add graphviz for debug.
@ -117,11 +121,26 @@ void Analyzer::Run(Argument* argument) {
} }
} }
Analyzer& Analyzer::IncludeAllIrPasses() {
ir_passes_ = all_ir_passes_;
return *this;
}
Analyzer& Analyzer::DisableIrPasses(const std::vector<std::string>& passes) { Analyzer& Analyzer::DisableIrPasses(const std::vector<std::string>& passes) {
disabled_ir_passes_.insert(passes.begin(), passes.end()); disabled_ir_passes_.insert(passes.begin(), passes.end());
return *this; return *this;
} }
Analyzer& Analyzer::IncludeIrPasses(const std::vector<std::string>& passes) {
ir_passes_ = passes;
return *this;
}
Analyzer& Analyzer::SetUseMkldnn(bool use_mkldnn) {
use_mkldnn_ = use_mkldnn;
return *this;
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle

@ -54,6 +54,9 @@ class Analyzer : public OrderedRegistry<PassManager> {
void Run(Argument* argument); void Run(Argument* argument);
Analyzer& DisableIrPasses(const std::vector<std::string>& passes); Analyzer& DisableIrPasses(const std::vector<std::string>& passes);
Analyzer& IncludeIrPasses(const std::vector<std::string>& passes);
Analyzer& IncludeAllIrPasses();
Analyzer& SetUseMkldnn(bool use_mkldnn);
DISABLE_COPY_AND_ASSIGN(Analyzer); DISABLE_COPY_AND_ASSIGN(Analyzer);
@ -81,6 +84,9 @@ class Analyzer : public OrderedRegistry<PassManager> {
}}; }};
std::unordered_set<std::string> disabled_ir_passes_; std::unordered_set<std::string> disabled_ir_passes_;
// Ir passes to run
std::vector<std::string> ir_passes_;
bool use_mkldnn_;
}; };
} // namespace analysis } // namespace analysis

@ -225,10 +225,24 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_.origin_program_desc.reset( argument_.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto())); new ProgramDesc(*inference_program_->Proto()));
PADDLE_ENFORCE(
config_.ir_mode == contrib::AnalysisConfig::IrPassMode::kExclude, switch (config_.ir_mode) {
"Only kExclude is supported yet."); case contrib::AnalysisConfig::IrPassMode::kExclude:
Analyzer().DisableIrPasses(config_.ir_passes).Run(&argument_); Analyzer()
.IncludeAllIrPasses()
.SetUseMkldnn(config_._use_mkldnn)
.DisableIrPasses(config_.ir_passes)
.Run(&argument_);
break;
case contrib::AnalysisConfig::IrPassMode::kInclude:
Analyzer()
.SetUseMkldnn(config_._use_mkldnn)
.IncludeIrPasses(config_.ir_passes)
.Run(&argument_);
break;
default:
LOG(ERROR) << "Only kExclude and kInclude modes are supoorted yet.";
}
CHECK(argument_.transformed_program_desc); CHECK(argument_.transformed_program_desc);
VLOG(5) << "to prepare executor"; VLOG(5) << "to prepare executor";

@ -259,10 +259,17 @@ struct AnalysisConfig : public NativeConfig {
kExclude // Specify the disabled passes in `ir_passes`. kExclude // Specify the disabled passes in `ir_passes`.
}; };
void SetIncludeMode() {
ir_mode = IrPassMode::kInclude;
// this pass has to be run at the beginning of all fuse passes
ir_passes = {"infer_clean_graph_pass"};
}
// Determine whether to perform graph optimization. // Determine whether to perform graph optimization.
bool enable_ir_optim = true; bool enable_ir_optim = true;
// Manually determine the IR passes to run. // Manually determine the IR passes to run.
IrPassMode ir_mode{IrPassMode::kExclude}; IrPassMode ir_mode{IrPassMode::kExclude};
// passes to be excluded/included
std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"}; std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
// NOT stable yet. // NOT stable yet.

@ -52,9 +52,10 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
} }
// Easy for profiling independently. // Easy for profiling independently.
TEST(Analyzer_resnet50, profile) { void profile(bool use_mkldnn = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg._use_mkldnn = use_mkldnn;
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
@ -69,6 +70,11 @@ TEST(Analyzer_resnet50, profile) {
} }
} }
TEST(Analyzer_resnet50, profile) { profile(); }
#ifndef PADDLE_WITH_MKLDNN
TEST(Analyzer_resnet50, profile_mkldnn) { profile(true /* use_mkldnn */); }
#endif
// Check the fuse status // Check the fuse status
TEST(Analyzer_resnet50, fuse_statis) { TEST(Analyzer_resnet50, fuse_statis) {
AnalysisConfig cfg; AnalysisConfig cfg;
@ -82,15 +88,21 @@ TEST(Analyzer_resnet50, fuse_statis) {
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_resnet50, compare) { void compare(bool use_mkldnn = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg._use_mkldnn = use_mkldnn;
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all); SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all); CompareNativeAndAnalysis(cfg, input_slots_all);
} }
TEST(Analyzer_resnet50, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_resnet50, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle

@ -59,9 +59,6 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->specify_input_name = true; cfg->specify_input_name = true;
// TODO(TJ): fix fusion gru // TODO(TJ): fix fusion gru
cfg->ir_passes.push_back("fc_gru_fuse_pass"); cfg->ir_passes.push_back("fc_gru_fuse_pass");
#ifdef PADDLE_WITH_MKLDNN
cfg->_use_mkldnn = true;
#endif
} }
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
@ -84,9 +81,10 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
// Easy for profiling independently. // Easy for profiling independently.
// ocr, mobilenet and se_resnext50 // ocr, mobilenet and se_resnext50
TEST(Analyzer_vis, profile) { void profile(bool use_mkldnn = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg._use_mkldnn = use_mkldnn;
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
@ -108,6 +106,12 @@ TEST(Analyzer_vis, profile) {
} }
} }
TEST(Analyzer_vis, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_vis, profile_mkldnn) { profile(true /* use_mkldnn */); }
#endif
// Check the fuse status // Check the fuse status
TEST(Analyzer_vis, fuse_statis) { TEST(Analyzer_vis, fuse_statis) {
AnalysisConfig cfg; AnalysisConfig cfg;
@ -118,15 +122,21 @@ TEST(Analyzer_vis, fuse_statis) {
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_vis, compare) { void compare(bool use_mkldnn = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg._use_mkldnn = use_mkldnn;
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all); SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all); CompareNativeAndAnalysis(cfg, input_slots_all);
} }
TEST(Analyzer_vis, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_vis, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle

@ -163,7 +163,8 @@ void TestPrediction(const AnalysisConfig &config,
const std::vector<std::vector<PaddleTensor>> &inputs, const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads, std::vector<PaddleTensor> *outputs, int num_threads,
bool use_analysis = FLAGS_use_analysis) { bool use_analysis = FLAGS_use_analysis) {
LOG(INFO) << "use_analysis: " << use_analysis; LOG(INFO) << "use_analysis: " << use_analysis
<< ", use_mkldnn: " << config._use_mkldnn;
if (num_threads == 1) { if (num_threads == 1) {
TestOneThreadPrediction(config, inputs, outputs, use_analysis); TestOneThreadPrediction(config, inputs, outputs, use_analysis);
} else { } else {
@ -175,6 +176,7 @@ void TestPrediction(const AnalysisConfig &config,
void CompareNativeAndAnalysis( void CompareNativeAndAnalysis(
const AnalysisConfig &config, const AnalysisConfig &config,
const std::vector<std::vector<PaddleTensor>> &inputs) { const std::vector<std::vector<PaddleTensor>> &inputs) {
LOG(INFO) << "use_mkldnn: " << config._use_mkldnn;
std::vector<PaddleTensor> native_outputs, analysis_outputs; std::vector<PaddleTensor> native_outputs, analysis_outputs;
TestOneThreadPrediction(config, inputs, &native_outputs, false); TestOneThreadPrediction(config, inputs, &native_outputs, false);
TestOneThreadPrediction(config, inputs, &analysis_outputs, true); TestOneThreadPrediction(config, inputs, &analysis_outputs, true);

@ -229,7 +229,7 @@ TEST(BlockingQueue, speed_test_mode) {
q1.Receive(&b); q1.Receive(&b);
EXPECT_EQ(b, i); EXPECT_EQ(b, i);
} }
EXPECT_EQ(q1.Size(), 0); EXPECT_EQ(q1.Size(), 0UL);
BlockingQueue<size_t> q2(queue_size, true); BlockingQueue<size_t> q2(queue_size, true);
for (size_t i = 0; i < queue_size; ++i) { for (size_t i = 0; i < queue_size; ++i) {

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save