You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
121 lines
6.2 KiB
121 lines
6.2 KiB
# Generated header that collects one USE_PASS(...) line per registered IR
# pass. It is rewritten from scratch on every configure run; pass_library()
# appends to it afterwards.
set(pass_file "${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h")
file(WRITE ${pass_file}
     "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n"
     "\#pragma once\n"
     "\#include \"paddle/fluid/framework/ir/pass.h\"\n")

# Usage: pass_library(<target> inference) appends USE_PASS(<target>) to
# paddle_inference_pass.h.
# Remove the cached pass list first so it is rebuilt by the pass_library()
# calls of this configure run instead of growing across runs.
unset(INFER_IR_PASSES CACHE)
|
|
# pass_library(<TARGET> <DEST> [<subdir>] [DEPS <dep>...])
#
# Declares the IR pass <TARGET> as a cc_library built from <TARGET>.cc and,
# for selected destinations, registers it in the generated pass header.
#
#   TARGET    Library name; the source file is <TARGET>.cc.
#   DEST      Destination group. For "base" and "inference" a
#             USE_PASS(<TARGET>) line is appended to ${pass_file} and <TARGET>
#             is added to the INFER_IR_PASSES cache list. Any other value only
#             declares the library.
#   <subdir>  Optional first positional argument: directory prefix prepended
#             to <TARGET>.cc (for example "mkldnn").
#   DEPS      Extra dependencies, listed after graph_pattern_detector, pass
#             and fuse_pass_base.
#   SRCS      Accepted as a keyword, but its values are not used.
#
# Example: pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
function(pass_library TARGET DEST)
  set(options "")
  set(oneValueArgs "")
  set(multiValueArgs SRCS DEPS)
  cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  # The sub-directory is the first argument that was NOT consumed by a
  # keyword. Reading ARGN[0] directly would mistake the keyword itself for a
  # directory in a call such as pass_library(foo inference DEPS bar).
  set(targetPrefix "")
  list(LENGTH pass_library_UNPARSED_ARGUMENTS numPositionalArgs)
  if(numPositionalArgs GREATER 0)
    list(GET pass_library_UNPARSED_ARGUMENTS 0 targetPrefix)
  endif()

  if(targetPrefix)
    set(passSource "${targetPrefix}/${TARGET}.cc")
  else()
    set(passSource "${TARGET}.cc")
  endif()
  cc_library(${TARGET} SRCS ${passSource} DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})

  # Add more DEST values here (such as train or dist) to collect their
  # USE_PASS lines into a file automatically.
  # DEST is quoted so that an empty value cannot break the condition and the
  # expanded text is compared as a string.
  if("${DEST}" STREQUAL "base" OR "${DEST}" STREQUAL "inference")
    message(STATUS "add pass ${TARGET} ${DEST}")
    file(APPEND "${pass_file}" "USE_PASS(${TARGET});\n")
    set(INFER_IR_PASSES ${INFER_IR_PASSES} ${TARGET} CACHE INTERNAL "")
  endif()
endfunction()
|
|
|
|
|
|
# Core IR libraries. Each one depends only on libraries declared above it
# (plus proto_desc and pretty_log, which are not declared in this file).
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector
           SRCS graph_pattern_detector.cc
           DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
|
|
|
|
# IR passes whose sources live in this directory. The call order is kept on
# purpose: it is the order of the USE_PASS lines in the generated header and
# of the entries in INFER_IR_PASSES.
pass_library(graph_to_program_pass               base)
pass_library(graph_viz_pass                      base)
pass_library(lock_free_optimize_pass             base)
pass_library(cpu_quantize_placement_pass         base)
pass_library(cpu_quantize_pass                   inference)
pass_library(cpu_quantize_squash_pass            inference)
pass_library(fc_fuse_pass                        inference)
pass_library(attention_lstm_fuse_pass            inference)
pass_library(infer_clean_graph_pass              inference)
pass_library(fc_lstm_fuse_pass                   inference)
pass_library(embedding_fc_lstm_fuse_pass         inference)
pass_library(fc_gru_fuse_pass                    inference)
pass_library(seq_concat_fc_fuse_pass             inference)
pass_library(multi_batch_merge_pass              base)
pass_library(conv_bn_fuse_pass                   inference)
pass_library(seqconv_eltadd_relu_fuse_pass       inference)
pass_library(seqpool_concat_fuse_pass            inference)
pass_library(repeated_fc_relu_fuse_pass          inference)
pass_library(squared_mat_sub_fuse_pass           inference)
pass_library(is_test_pass                        base)
pass_library(conv_elementwise_add_act_fuse_pass  inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)
pass_library(conv_elementwise_add_fuse_pass      inference)
pass_library(conv_affine_channel_fuse_pass       inference)
pass_library(transpose_flatten_concat_fuse_pass  inference)
pass_library(identity_scale_op_clean_pass        base)
pass_library(sync_batch_norm_pass                base)
pass_library(runtime_context_cache_pass          base)
|
|
|
|
# A model may contain several transpose-flatten structures whose outputs all
# feed a single concat op; the transpose_flatten<N>_concat_fuse_pass passes
# detect that pattern, where <N> is the number of such structures. Only
# N = 3..6 is registered because those counts are the common ones in models.
# NOTE(review): no library is declared for these names in this file, only the
# USE_PASS lines are emitted; presumably the passes are defined alongside
# transpose_flatten_concat_fuse_pass - confirm.
foreach(index RANGE 3 6)
  file(APPEND ${pass_file}
       "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
endforeach()
|
|
|
|
# MKL-DNN specific passes; their sources are taken from the mkldnn/
# sub-directory (third argument) and they exist only in MKL-DNN builds.
if(WITH_MKLDNN)
  pass_library(mkldnn_placement_pass                 base      mkldnn)
  pass_library(depthwise_conv_mkldnn_pass            base      mkldnn)
  pass_library(conv_bias_mkldnn_fuse_pass            inference mkldnn)
  pass_library(conv_relu_mkldnn_fuse_pass            inference mkldnn)
  pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
endif()
|
|
|
|
# Passes declared directly with cc_library instead of pass_library(), so this
# file emits no USE_PASS line for them and they are not added to
# INFER_IR_PASSES.
cc_library(fuse_elewise_add_act_pass
           SRCS fuse_elewise_add_act_pass.cc
           DEPS pass graph_pattern_detector)
cc_library(fuse_relu_depthwise_conv_pass
           SRCS fuse_relu_depthwise_conv_pass.cc
           DEPS pass graph_pattern_detector)

# Publish the pass library list through the cache.
# NOTE(review): PASS_LIBRARY is not set anywhere in the visible part of this
# file - confirm it is defined elsewhere, otherwise GLOB_PASS_LIB is cached
# as an empty value.
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")

cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
|
|
|
|
# Unit tests for the core graph libraries.
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test
        SRCS graph_helper_test.cc
        DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test
        SRCS graph_to_program_pass_test.cc
        DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector
        SRCS graph_pattern_detector_tester.cc
        DEPS graph_pattern_detector)

# Unit tests for individual passes built on every platform.
cc_test(test_fc_fuse_pass
        SRCS fc_fuse_pass_tester.cc
        DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass
        SRCS seqpool_concat_fuse_pass_tester.cc
        DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_cpu_quantize_placement_pass
        SRCS cpu_quantize_placement_pass_tester.cc
        DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass
        SRCS cpu_quantize_pass_tester.cc
        DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass
        SRCS cpu_quantize_squash_pass_tester.cc
        DEPS cpu_quantize_squash_pass naive_executor)
|
|
# The sync_batch_norm pass test is skipped when configuring for Windows.
if(NOT WIN32)
  cc_test(test_sync_batch_norm_pass
          SRCS sync_batch_norm_pass_tester.cc
          DEPS sync_batch_norm_pass)
endif()
|
|
# Tests for the MKL-DNN passes; sources are in mkldnn/ and the targets they
# depend on are only declared when WITH_MKLDNN is enabled.
if(WITH_MKLDNN)
  cc_test(test_depthwise_conv_mkldnn_pass
          SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc
          DEPS depthwise_conv_mkldnn_pass)
  cc_test(test_conv_bias_mkldnn_fuse_pass
          SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
          DEPS conv_bias_mkldnn_fuse_pass naive_executor)
  cc_test(test_conv_relu_mkldnn_fuse_pass
          SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc
          DEPS conv_relu_mkldnn_fuse_pass)
  cc_test(test_conv_elementwise_add_mkldnn_fuse_pass
          SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
          DEPS conv_elementwise_add_mkldnn_fuse_pass)
  cc_test(test_mkldnn_placement_pass
          SRCS mkldnn/mkldnn_placement_pass_tester.cc
          DEPS mkldnn_placement_pass)
endif()
|