Compare commits

..

111 Commits

Author SHA1 Message Date
wuhuanzhou 587d99ae44
update compilation with C++14 (#31815)
5 years ago
tianshuo78520a b09c1ce09a
fix whl package push pypi (#31585)
5 years ago
Thunderbrook 393b3bd6b7
fix split core (#31892)
5 years ago
wuhuanzhou 3a95a0bc26
update cmake minimum version to 3.15 (#31807)
5 years ago
taixiurong 52b05baca3
fix some bug in transformer training in xpu (#31918)
5 years ago
Wenyu 5394194e3a
support minus-int idx to LayerList (#31750)
5 years ago
furnace ef8323d49e
[ROCM] Add ROCm support for warpctc op (#31817)
5 years ago
Jiawei Wang 95f808c878
fix stack op grad nullptr (#31962)
5 years ago
liym27 57d4288ad4
[dynamic setitem] Fix bug of dynamic setitem: Decerease axes to do right broadcast (#31960)
5 years ago
石晓伟 0fa6c8a35c
fix a syntax error, test=develop (#31930)
5 years ago
Pei Yang 98e803e04f
map_matmul_to_mul_pass support 3dim (#31958)
5 years ago
wuhuanzhou a37a7f67e1
modify CI recommend information (#31395)
5 years ago
jakpiase 6dca7a1de7
Added int8 kernel for oneDNN LSTM op (#31894)
5 years ago
Pei Yang 14b7e3cf06
[Paddle-TRT] TRT inference support for BERT/Transformer in paddle 2.0 api (#31744)
5 years ago
Zhou Wei 245252b86e
fix bug when dtype of to_tensor is core.VarType (#31931)
5 years ago
Zhen Wang e1f931610e
Fix save/load error in imperative qat UT. (#31937)
5 years ago
Yiqun Liu e50bc2c2a6
Enhance cmake to support specifying CUDA_ARCH_NAME to Ampere. (#31923)
5 years ago
Zhou Wei 04a49b097e
[Custom OP]Remove old custom OP and reduce whl package volume (#31813)
5 years ago
wangguanzhong fe2848686b
add exclusive for test_conv2d_op, test=develop (#31936)
5 years ago
chajchaj 73a6fa3ed0
add deprecated for softmax_with_cross_entropy (#31722)
5 years ago
Shang Zhizhou 8084b7594b
fix batchnorm when inpu dims < 3 (#31933)
5 years ago
zlsh80826 64ee255ffd
[Paddle-TRT] yolobox (#31755)
5 years ago
Aurelius84 c4b60efabd
Fix segment Fault from set_value (#31891)
5 years ago
wuhuanzhou 17030ff28b
fix op benchmark ci error caused by missing test_pr branch, test=document_fix (#31920)
5 years ago
niuliling123 a71d72d921
relu forward and backward with vectortype (#31869)
5 years ago
tianshuo78520a 8829a309fe
Delete cudnn6 code (#31835)
5 years ago
wanghuancoder b48841ba2e
modify API nn.Bilinear's doc (#31889)
5 years ago
liym27 525c32e33c
Fix bug of set_value op:Decerease axes to do right broadcast (#31875)
5 years ago
ronnywang 123949eb48
[ROCM] added a cudnn switch of conv2d for rocm platform (#31836)
5 years ago
Shang Zhizhou 61805d8f0a
fix cmake model path (#31866)
5 years ago
Jiabin Yang 51eb29de18
[CustomOP] Add shape related constructor for Tensor (#31681)
5 years ago
zlsh80826 e3a38d790a
[Paddle-TRT] roi_align_plugin (#31732)
5 years ago
zlsh80826 bfb5cf5567
[Paddle-TRT] trt affine channel converter (#31628)
5 years ago
cc b47478efc2
[dygraph qat] Use layer to calculate output scale (#31861)
5 years ago
lilong12 c3974d0e2a
[3D-parallel] Reformat pipeline parallel (#31786)
5 years ago
zlsh80826 01aa252624
[Paddle-TRT] multiclass nms (#31742)
5 years ago
Wilber 70b67f1029
fix go api bug. (#31857)
5 years ago
tianshuo78520a e804f08559
delete include framework.pb.h (#31859)
5 years ago
Chengmo f58cb01864
【Paddle.Fleet】fix dataset zip py3 bug (#31441)
5 years ago
Kaipeng Deng bf09dcb346
add GPU tensor notice & update default_collate_fn/default_convert_fn. test=develop (#31763)
5 years ago
Chen Weihang 27f2d8df8e
Polish two error messages (#31852)
5 years ago
Zhou Wei 511e204e62
LRScheduler.get_lr should not update lr in LinearWarmup (#31843)
5 years ago
niuliling123 6472d62093
Revert "add relu forward kernel and backward kernel (#31613)" (#31853)
5 years ago
winter-wang e7f28d6c0d
fix runtime crash when rnn model inference, test=develop (#31833)
5 years ago
parap1uie-s 5d89ec36dc
Update pooling.py (#31829)
5 years ago
Huihuang Zheng 649868ffb2
[Dy2stat] Fix the bug that loop_body_func may return single element (#31806)
5 years ago
Wojciech Uss e5f7a834d4
fix cache key in concat oneDNN kernel (#31820)
5 years ago
Aurelius84 f2cfc0f46d
[CustomOp]Avoid raising warning while import paddle (#31804)
5 years ago
cc 84a551380e
[dygraph qat] Refine saving output scale to infer program (#31784)
5 years ago
Chen Weihang 68497e7b39
change trainable to stop_gradient in optimizer (#31823)
5 years ago
ronnywang 270699e647
[ROCM] fix test_matmul_v2_op (#31802)
5 years ago
Zhou Wei 1eb927f935
Restore the third-party library cache for windows (#31811)
5 years ago
Chen Weihang 3f66e7deab
add cmath header for bfloat (#31792)
5 years ago
Feiyu Chan 4046f1303a
add coalesce_tensor into white list when checking re-creation of parameters (#31800)
5 years ago
Zhou Wei a70de87d76
Update windows compiler and CI from VS2015 to VS2017 (#31652)
5 years ago
Wilber f4d9212de2
trt plugin upgrade to pluginv2ext (#31670)
5 years ago
niuliling123 372ac08a17
add relu forward kernel and backward kernel (#31613)
5 years ago
Wojciech Uss 814b38e30f
update scale collection and propagation algorithm (#31783)
5 years ago
tianshuo78520a 513641e153
Delete fast_check_nan_inf (#31788)
5 years ago
Shang Zhizhou 9d04ef7369
fix tensorrt output varible reshape (#31733)
5 years ago
Qi Li 46dd1d4aad
[ROCM] fix reduce_sum nan in ROCM platform, test=develop (#31780)
5 years ago
gongweibao f72d197ec5
fix launch ps ut test=develop (#31771)
5 years ago
Tao Luo 032de0bfd0
update approval (#31782)
5 years ago
zlsh80826 bfced39eb6
[Paddle-TRT] nearest_interp op (#31626)
5 years ago
arlesniak 7ccf6b6030
[oneDNN] Initial bf16 amp integration (#31093)
5 years ago
lilong12 a501a7b0ca
[3D-parallel] add 1f1b scheduler for pipeline (#31566)
5 years ago
guofei ed7956a816
Fix skip_quant in QAT (#31704)
5 years ago
ronnywang 8c19d7aa2f
[ROCM] fix test_conv2d_transpose_op (#31749)
5 years ago
Ouyang Chao a45c8ca69d
fix bug of DepthwiseConvTransposeGradKernel (#31762)
5 years ago
Jacek Czaja 25fc2a1fdb
[oneDNN] Added Elementwise Mul grad fp32/bf16 (#31647)
5 years ago
Chen Weihang 878e117b6d
[CustomOp] Support float16 in custom op (#31725)
5 years ago
ronnywang c9e1d9dc31
[ROCM] fix test_rnn_op (#31735)
5 years ago
zlsh80826 1c67cf0c98
run radix sort of proposals layer on context stream (#31631)
5 years ago
Chen Weihang e429deb0c4
[CustomOp] Support attribute in infershape function (#31713)
5 years ago
Adam Osewski a4a2b77def
[oneDNN] lookup_table op with support for BF16 data type. (#31558)
5 years ago
zlsh80826 c86e771e94
NMS Performance Optimization (#31634)
5 years ago
zlsh80826 50cafa0b0c
remove redundant sync, set collect/dist kernel to context stream, sub_lod memcpy opt (#31641)
5 years ago
cc 1d197f6c97
[dgraph qat] Refine calculating output scale of dygraph qat (#31710)
5 years ago
ronnywang 420527f0d9
[ROCM] fix layer_norm, norm, p_norm, test_sequence_softmax_op, test_math_op_patch_var_base (#31709)
5 years ago
Chen Weihang 87852616aa
[CustomOp] Support complex dtype in custom op (#31657)
5 years ago
zlsh80826 fe241fd02f
[Paddle-TRT] gather converter (#31640)
5 years ago
zlsh80826 4ea3427865
[Paddle-TRT] support batch axis concatenation when using dynamic shape (#31627)
5 years ago
Zhou Wei d4282ea97e
fix multi cuda environment bug (#31694)
5 years ago
Chengmo 09482ddec4
【Paddle.Fleet】Fix one ps gradient clip (#31664)
5 years ago
Kaipeng Deng 740359edaf
remove useless import (#31700)
5 years ago
Zhang Ting 7f50bb7ec1
support NHWC for temporal_shift op (#31642)
5 years ago
liym27 402288ad65
In __getitem__, convert integers to int64 Tensor not int32 to be compatible with Lite(#31658)
5 years ago
Chen Weihang 2fbe9b097a
[CustomOp] Remove Eigen dependencies of float16 (#31669)
5 years ago
cc 19592d2b71
Refine dygraph qat, test=develop (#31680)
5 years ago
Zhou Wei 4c0c55bba1
support Geforce RTX 30+ GPU (#31529)
5 years ago
YUNSHEN XIE cdc5a55ac1
turn off added ut check on windows (#31660)
5 years ago
Qi Li d9b50f664f
[ROCM] update ci scripts and dockefile, test=develop (#31551)
5 years ago
YUNSHEN XIE 1a6e3b04cd
Second optimization of retry method (#31646)
5 years ago
wuhuanzhou 41e9ecfd1f
Optimize compilation with Ninja (#31449)
5 years ago
yiak c1b1ccfbf5
Update tinyformat.h (#31612)
5 years ago
gongweibao 9c624b16d5
Extend unittest time of (#31570)
5 years ago
YUNSHEN XIE 580442ceba
fix wget with no proxy on windows (#31505)
5 years ago
ronnywang da10c5cf8b
[ROCM] fix softmax_with_cross_entropy_op, test=develop (#31629)
5 years ago
LielinJiang 75433126df
Fix summary bug when calaculating output shape (#31549)
5 years ago
ShenLiang c3634c6b0a
fix amp bug of fleet (#31532)
5 years ago
Chen Weihang 027b574a0e
[CustomOp] Remove the dependence of the underlying data types on eigen (#31602)
5 years ago
WangXi 9066b74f58
c_gen_nccl_id add SocketServer to persit server (#31589)
5 years ago
Kaipeng Deng a32e8bf1e7
DataLoader supprot dict str (#31481)
5 years ago
Chen Weihang 30a627aaf3
Normalized function parameter writing (#31588)
5 years ago
Pei Yang cac9635a67
[Paddle-TRT] Fix engine key in trt int8 calibration (#31513)
5 years ago
Shang Zhizhou 50ac7dbfd0
Trt elementwise plugin serialize (#31587)
5 years ago
guofei ef0dd3efed
Support loading parameters from checkpoint to save quantized model (#31419)
5 years ago
whs da9dda5c9b
Make CreateProgramDesc more robust (#31543)
5 years ago
hong 99dcd66508
try to fix imperative orc unitest error; test=develop (#31568)
5 years ago
Qi Li 3d5aa9d10a
[ROCM] fix conv2d and conv3d op, test=develop (#31553)
5 years ago
YUNSHEN XIE f302bb4f8b
help timeout ut debug (#31500)
5 years ago

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
cmake_minimum_required(VERSION 3.10) cmake_minimum_required(VERSION 3.15)
cmake_policy(VERSION 3.10)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
@ -38,11 +39,6 @@ endif()
if (WITH_GPU AND WITH_ASCEND) if (WITH_GPU AND WITH_ASCEND)
message(FATAL_ERROR "Error when compile GPU and ASCEND at the same time") message(FATAL_ERROR "Error when compile GPU and ASCEND at the same time")
endif() endif()
# cmake 3.12, 3.13, 3.14 will append gcc link options to nvcc, and nvcc doesn't recognize them.
if(WITH_GPU AND (${CMAKE_VERSION} VERSION_GREATER_EQUAL 3.12) AND (${CMAKE_VERSION} VERSION_LESS 3.15))
message(FATAL_ERROR "cmake ${CMAKE_VERSION} is not supported when WITH_GPU=ON because of bug https://cmake.org/pipermail/cmake/2018-September/068195.html. "
"You can use cmake 3.16 (recommended), 3.10, 3.11, 3.15 or 3.17. Please refer to the install document: https://cmake.org/install/")
endif()
if(WITH_GPU AND NOT APPLE) if(WITH_GPU AND NOT APPLE)
enable_language(CUDA) enable_language(CUDA)
@ -61,6 +57,7 @@ if(WITH_MUSL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy")
endif() endif()
if(WIN32) if(WIN32)
option(MSVC_STATIC_CRT "use static C Runtime library by default" ON) option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
@ -72,6 +69,13 @@ if(WIN32)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj")
if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /Zc:inline")
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /Zc:inline")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /Zc:inline")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Zc:inline")
endif()
if (MSVC_STATIC_CRT) if (MSVC_STATIC_CRT)
message(STATUS "Use static C runtime time, refer to https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=vs-2019") message(STATUS "Use static C runtime time, refer to https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=vs-2019")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /MTd") set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /MTd")
@ -88,7 +92,7 @@ if(WIN32)
endif() endif()
endforeach(flag_var) endforeach(flag_var)
endif() endif()
# NOTE(Avin0323): Less parallel count result in faster compilation. # NOTE(Avin0323): Less parallel count result in faster compilation.
math(EXPR PROCESS_MAX "${CPU_CORES} * 2 / 3") math(EXPR PROCESS_MAX "${CPU_CORES} * 2 / 3")
# windows build turn off warnings, use parallel compiling. # windows build turn off warnings, use parallel compiling.
@ -116,6 +120,10 @@ if(WIN32)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838")
foreach(flag_var CMAKE_SHARED_LINKER_FLAGS CMAKE_STATIC_LINKER_FLAGS CMAKE_EXE_LINKER_FLAGS CMAKE_LINKER_FLAGS)
set(${flag_var} "${${flag_var}} /ignore:4049 /ignore:4217 /ignore:4006 /ignore:4221")
endforeach(flag_var)
if (WITH_WIN_DUMP_DBG) if (WITH_WIN_DUMP_DBG)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Zi") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Zi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zi") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zi")

@ -74,7 +74,7 @@ endfunction()
# select_nvcc_arch_flags(out_variable) # select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable) function(select_nvcc_arch_flags out_variable)
# List of arch names # List of arch names
set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "Turing" "All" "Manual") set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "Turing" "Ampere" "All" "Manual")
set(archs_name_default "Auto") set(archs_name_default "Auto")
list(APPEND archs_names "Auto") list(APPEND archs_names "Auto")
@ -91,7 +91,7 @@ function(select_nvcc_arch_flags out_variable)
if(${CUDA_ARCH_NAME} STREQUAL "Manual") if(${CUDA_ARCH_NAME} STREQUAL "Manual")
set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported") set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported")
set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for") set(CUDA_ARCH_PTX "" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for")
mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX) mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX)
else() else()
unset(CUDA_ARCH_BIN CACHE) unset(CUDA_ARCH_BIN CACHE)
@ -108,6 +108,8 @@ function(select_nvcc_arch_flags out_variable)
set(cuda_arch_bin "70") set(cuda_arch_bin "70")
elseif(${CUDA_ARCH_NAME} STREQUAL "Turing") elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
set(cuda_arch_bin "75") set(cuda_arch_bin "75")
elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere")
set(cuda_arch_bin "80")
elseif(${CUDA_ARCH_NAME} STREQUAL "All") elseif(${CUDA_ARCH_NAME} STREQUAL "All")
set(cuda_arch_bin ${paddle_known_gpu_archs}) set(cuda_arch_bin ${paddle_known_gpu_archs})
elseif(${CUDA_ARCH_NAME} STREQUAL "Auto") elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
@ -175,14 +177,22 @@ elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0) # CUDA 9.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs9}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) # CUDA 10.x elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) # CUDA 10.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs10}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.x set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.2) # CUDA 11.0/11.1
set(paddle_known_gpu_archs ${paddle_known_gpu_archs11}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs11})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.2+
set(paddle_known_gpu_archs "${paddle_known_gpu_archs11} 86")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
endif() endif()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0) if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0)
@ -198,14 +208,11 @@ select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}")
message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}") message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}")
# Set C++11 support # Set C++14 support
set(CUDA_PROPAGATE_HOST_FLAGS OFF) set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here. # So, don't set these flags here.
if (NOT WIN32) # windows msvc2015 support c++11 natively. set(CMAKE_CUDA_STANDARD 14)
# -std=c++11 -fPIC not recoginize by msvc, -Xcompiler will be added by cmake.
set(CMAKE_CUDA_STANDARD 11)
endif(NOT WIN32)
# (Note) For windows, if delete /W[1-4], /W1 will be added defaultly and conflic with -w # (Note) For windows, if delete /W[1-4], /W1 will be added defaultly and conflic with -w
# So replace /W[1-4] with /W0 # So replace /W[1-4] with /W0

@ -94,7 +94,7 @@ macro(find_cudnn_version cudnn_header_file)
"${CUDNN_MAJOR_VERSION} * 1000 + "${CUDNN_MAJOR_VERSION} * 1000 +
${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}") ${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}")
message(STATUS "Current cuDNN header is ${cudnn_header_file} " message(STATUS "Current cuDNN header is ${cudnn_header_file} "
"Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}. ") "Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}.${CUDNN_PATCHLEVEL_VERSION}. ")
endif() endif()
endif() endif()
endmacro() endmacro()

@ -14,11 +14,15 @@
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
IF(WITH_ROCM)
add_definitions(-DWARPCTC_WITH_HIP)
ENDIF()
SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc) SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc) SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git) set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git)
set(WARPCTC_TAG 95a461eddeabd51099ef059dcfada1117eb1bfb8) set(WARPCTC_TAG c690fc5755abbdbdc98ef78d51ec10a6748a8cd1)
SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE) CACHE PATH "Warp-ctc Directory" FORCE)
@ -49,14 +53,15 @@ ExternalProject_Add(
BUILD_ALWAYS 1 BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_C_FLAGS=$<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} -DCMAKE_C_FLAGS_DEBUG=$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} -DCMAKE_C_FLAGS_RELEASE=$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS=$<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_CXX_FLAGS_RELEASE=$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_DEBUG=$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>
-DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU} -DWITH_GPU=${WITH_GPU}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP} -DWITH_OMP=${USE_OMP}
-DWITH_TORCH=OFF -DWITH_TORCH=OFF
-DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON

@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT)
elseif(WITH_SUNWAY) elseif(WITH_SUNWAY)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE) SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
else() else()
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_02_27.tar.gz" CACHE STRING "" FORCE) SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_03_30.tar.gz" CACHE STRING "" FORCE)
endif() endif()
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")

@ -4,10 +4,10 @@ include(CheckCCompilerFlag)
include(CheckCXXSymbolExists) include(CheckCXXSymbolExists)
include(CheckTypeSize) include(CheckTypeSize)
function(CheckCompilerCXX11Flag) function(CheckCompilerCXX14Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.4)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") message(FATAL_ERROR "Unsupported GCC version. GCC >= 5.4 required.")
elseif(${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER 8.2) elseif(${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER 8.2)
message(WARNING "Found GCC ${CMAKE_CXX_COMPILER_VERSION} which is too high, recommended to use GCC 8.2") message(WARNING "Found GCC ${CMAKE_CXX_COMPILER_VERSION} which is too high, recommended to use GCC 8.2")
endif() endif()
@ -20,23 +20,15 @@ function(CheckCompilerCXX11Flag)
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.") message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
endif() endif()
else() else()
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3) if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.4)
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.") message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.4 required.")
endif() endif()
endif() endif()
endif() endif()
endfunction() endfunction()
CheckCompilerCXX11Flag() CheckCompilerCXX14Flag()
if (WITH_GPU) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
if (${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
endif()
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
endif()
# safe_set_flag # safe_set_flag
# #
# Set a compile flag only if compiler is support # Set a compile flag only if compiler is support

@ -492,10 +492,8 @@ function(nv_library TARGET_NAME)
message(FATAL "Please specify source file or library in nv_library.") message(FATAL "Please specify source file or library in nv_library.")
endif() endif()
endif(nv_library_SRCS) endif(nv_library_SRCS)
if (WIN32 AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) if((CUDA_VERSION GREATER 9.2) AND (CUDA_VERSION LESS 11.0) AND (MSVC_VERSION LESS 1910))
if(${MSVC_VERSION} LESS_EQUAL 1900) set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS})
set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS})
endif()
endif() endif()
endif() endif()
endfunction(nv_library) endfunction(nv_library)
@ -512,7 +510,7 @@ function(nv_binary TARGET_NAME)
add_dependencies(${TARGET_NAME} ${nv_binary_DEPS}) add_dependencies(${TARGET_NAME} ${nv_binary_DEPS})
common_link(${TARGET_NAME}) common_link(${TARGET_NAME})
endif() endif()
if (WIN32 AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) if((CUDA_VERSION GREATER 9.2) AND (CUDA_VERSION LESS 11.0) AND (MSVC_VERSION LESS 1910))
set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS}) set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS})
endif() endif()
endif() endif()
@ -539,7 +537,7 @@ function(nv_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
if (WIN32 AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) if((CUDA_VERSION GREATER 9.2) AND (CUDA_VERSION LESS 11.0) AND (MSVC_VERSION LESS 1910))
set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS}) set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS})
endif() endif()
endif() endif()

@ -192,6 +192,15 @@ include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/* SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/*
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/) DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex64.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/float16.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
# CAPI inference library for only inference # CAPI inference library for only inference
set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING

@ -18,6 +18,10 @@ if(NOT WIN32)
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -DNDEBUG") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -DNDEBUG")
set(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -DNDEBUG") set(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -DNDEBUG")
else() else()
# It has not been used now, it can specify CUDA compile flag manualy,
# its use is to remvoe /Zi to reduce GPU static library size. But it's dangerous
# because CUDA will update by nvidia, then error will occur.
# Now, it's used in CUDA:[10.0, 10.2]
set(WIN_PROPS ${CMAKE_SOURCE_DIR}/cmake/paddle_win.props) set(WIN_PROPS ${CMAKE_SOURCE_DIR}/cmake/paddle_win.props)
endif() endif()

@ -15,7 +15,7 @@
<Warning>InheritFromHost</Warning> <Warning>InheritFromHost</Warning>
<BaseCommandLineTemplate>-ccbin "%(VCBinDir)" -x cu [GenerateRelocatableDeviceCode] [Include] [RequiredIncludes] [InterleaveSourceInPTX] [GPUDebugInfo] [GenerateLineInfo] [Keep] [KeepDir] [MaxRegCount] [PtxAsOptionV] [TargetMachinePlatform] [NvccCompilation] [CudaRuntime] [AdditionalOptions]</BaseCommandLineTemplate> <BaseCommandLineTemplate>-ccbin "%(VCBinDir)" -x cu [GenerateRelocatableDeviceCode] [Include] [RequiredIncludes] [InterleaveSourceInPTX] [GPUDebugInfo] [GenerateLineInfo] [Keep] [KeepDir] [MaxRegCount] [PtxAsOptionV] [TargetMachinePlatform] [NvccCompilation] [CudaRuntime] [AdditionalOptions]</BaseCommandLineTemplate>
<BuildCommandLineTemplate>--use-local-env --cl-version $(CudaClVersion)</BuildCommandLineTemplate> <BuildCommandLineTemplate>--use-local-env $(CudaClVersion)</BuildCommandLineTemplate>
<BuildDynamicCommandLineTemplate>[CodeGeneration]</BuildDynamicCommandLineTemplate> <BuildDynamicCommandLineTemplate>[CodeGeneration]</BuildDynamicCommandLineTemplate>
<CleanCommandLineTemplate>-clean</CleanCommandLineTemplate> <CleanCommandLineTemplate>-clean</CleanCommandLineTemplate>
<!-- <HostCommandLineTemplate>-Xcompiler &quot;/EHsc [Warning] /nologo [Optimization] $(CudaForceSynchronousPdbWrites) /Zi [RuntimeChecks] [Runtime] [TypeInfo]&quot;</HostCommandLineTemplate> --> <!-- <HostCommandLineTemplate>-Xcompiler &quot;/EHsc [Warning] /nologo [Optimization] $(CudaForceSynchronousPdbWrites) /Zi [RuntimeChecks] [Runtime] [TypeInfo]&quot;</HostCommandLineTemplate> -->

@ -50,6 +50,7 @@ output_data := value.Interface().([][]float32)
运行 运行
```bash ```bash
go mod init github.com/paddlepaddle
export LD_LIBRARY_PATH=`pwd`/paddle_c/paddle/lib:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=`pwd`/paddle_c/paddle/lib:$LD_LIBRARY_PATH
go run ./demo/mobilenet.go go run ./demo/mobilenet.go
``` ```

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
package main package main
import "../paddle" import "github.com/paddlepaddle/paddle"
import "strings" import "strings"
import "io/ioutil" import "io/ioutil"
import "strconv" import "strconv"

@ -15,7 +15,7 @@
package paddle package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include // #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c // #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h> // #include <stdbool.h>
// #include <paddle_c_api.h> // #include <paddle_c_api.h>
import "C" import "C"

@ -15,7 +15,7 @@
package paddle package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include // #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c // #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h> // #include <stdbool.h>
// #include <stdlib.h> // #include <stdlib.h>
// #include <paddle_c_api.h> // #include <paddle_c_api.h>

@ -15,7 +15,7 @@
package paddle package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include // #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c // #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h> // #include <stdbool.h>
// #include "paddle_c_api.h" // #include "paddle_c_api.h"
import "C" import "C"
@ -88,7 +88,7 @@ func (predictor *Predictor) GetInputNames() []string {
} }
func (predictor *Predictor) GetOutputNames() []string { func (predictor *Predictor) GetOutputNames() []string {
names := make([]string, predictor.GetInputNum()) names := make([]string, predictor.GetOutputNum())
for i := 0; i < len(names); i++ { for i := 0; i < len(names); i++ {
names[i] = predictor.GetOutputName(i) names[i] = predictor.GetOutputName(i)
} }

@ -15,7 +15,7 @@
package paddle package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include // #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c // #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h> // #include <stdbool.h>
// #include <stdlib.h> // #include <stdlib.h>
// #include <string.h> // #include <string.h>
@ -209,7 +209,7 @@ func DecodeTensor(r *bytes.Reader, shape []int32, t reflect.Type, ptr reflect.Va
value := reflect.Indirect(ptr) value := reflect.Indirect(ptr)
value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0]))) value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0])))
if len(shape) == 1 && value.Len() > 0 { if len(shape) == 1 && value.Len() > 0 {
switch value.Index(1).Kind() { switch value.Index(0).Kind() {
case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32: case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
binary.Read(r, Endian(), value.Interface()) binary.Read(r, Endian(), value.Interface())
return return

@ -47,6 +47,22 @@ namespace paddle {
} \ } \
}() }()
#define PD_DISPATCH_FLOATING_AND_HALF_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT16, paddle::float16, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
::paddle::ToString(__dtype__), "`"); \
} \
}()
///////// Integral Dispatch Marco /////////// ///////// Integral Dispatch Marco ///////////
#define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
@ -68,6 +84,22 @@ namespace paddle {
} \ } \
}() }()
///////// Complex Dispatch Marco ///////////
#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating and Integral Dispatch Marco /////////// ///////// Floating and Integral Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
@ -93,6 +125,55 @@ namespace paddle {
} \ } \
}() }()
///////// Floating and Complex Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating, Integral and Complex Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
// TODO(chenweihang): Add more Marcos in the future if needed // TODO(chenweihang): Add more Marcos in the future if needed
} // namespace paddle } // namespace paddle

@ -16,10 +16,17 @@ limitations under the License. */
#include <cstdint> #include <cstdint>
#include <string> #include <string>
#include "complex128.h" // NOLINT
#include "complex64.h" // NOLINT
#include "ext_exception.h" // NOLINT #include "ext_exception.h" // NOLINT
#include "float16.h" // NOLINT
namespace paddle { namespace paddle {
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
using float16 = paddle::platform::float16;
enum class DataType { enum class DataType {
BOOL, BOOL,
INT8, INT8,
@ -27,8 +34,11 @@ enum class DataType {
INT16, INT16,
INT32, INT32,
INT64, INT64,
FLOAT16,
FLOAT32, FLOAT32,
FLOAT64, FLOAT64,
COMPLEX64,
COMPLEX128,
// TODO(JiabinYang) support more data types if needed. // TODO(JiabinYang) support more data types if needed.
}; };
@ -46,24 +56,33 @@ inline std::string ToString(DataType dtype) {
return "int32_t"; return "int32_t";
case DataType::INT64: case DataType::INT64:
return "int64_t"; return "int64_t";
case DataType::FLOAT16:
return "float16";
case DataType::FLOAT32: case DataType::FLOAT32:
return "float"; return "float";
case DataType::FLOAT64: case DataType::FLOAT64:
return "double"; return "double";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
default: default:
PD_THROW("Unsupported paddle enum data type."); PD_THROW("Unsupported paddle enum data type.");
} }
} }
#define PD_FOR_EACH_DATA_TYPE(_) \ #define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \ _(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \ _(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \ _(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \ _(int16_t, DataType::INT16) \
_(int, DataType::INT32) \ _(int, DataType::INT32) \
_(int64_t, DataType::INT64) \ _(int64_t, DataType::INT64) \
_(float, DataType::FLOAT32) \ _(float16, DataType::FLOAT16) \
_(double, DataType::FLOAT64) _(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64) \
_(complex64, DataType::COMPLEX64) \
_(complex128, DataType::COMPLEX128)
template <paddle::DataType T> template <paddle::DataType T>
struct DataTypeToCPPType; struct DataTypeToCPPType;

File diff suppressed because it is too large Load Diff

@ -52,6 +52,9 @@ class PD_DLL_DECL Tensor {
/// \brief Construct a Tensor on target Place for CustomOp. /// \brief Construct a Tensor on target Place for CustomOp.
/// Generally it's only used for user to create Tensor. /// Generally it's only used for user to create Tensor.
explicit Tensor(const PlaceType& place); explicit Tensor(const PlaceType& place);
/// \brief Construct a Tensor on target Place with shape for CustomOp.
/// Generally it's only used for user to create Tensor.
Tensor(const PlaceType& place, const std::vector<int64_t>& shape);
/// \brief Reset the shape of the tensor. /// \brief Reset the shape of the tensor.
/// Generally it's only used for the input tensor. /// Generally it's only used for the input tensor.
/// Reshape must be called before calling /// Reshape must be called before calling

@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/extension/include/ext_tensor.h" #include "paddle/fluid/extension/include/ext_tensor.h"
#include <utility> #include <utility>
#include "paddle/fluid/framework/custom_tensor_utils.h" #include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
@ -97,13 +102,32 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
void Tensor::reshape(const std::vector<int64_t> &shape) { void Tensor::reshape(const std::vector<int64_t> &shape) {
GET_CASTED_TENSOR GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape)); auto new_dim = framework::make_ddim(shape);
if (tensor->numel() != framework::product(new_dim)) {
LOG(WARNING) << "Custom Op: Calling reshape to a new shape which is bigger "
"or smaller"
<< "than original shape will not change your tensor's memory "
"Please call"
<< "paddle::Tensor::mutable_data<T>() after to reallocate "
"your tensor's size."
<< std::endl;
}
tensor->Resize(new_dim);
} }
Tensor::Tensor(const PlaceType &place) Tensor::Tensor(const PlaceType &place)
: tensor_(std::make_shared<framework::LoDTensor>()), : tensor_(std::make_shared<framework::LoDTensor>()),
place_(place), place_(place),
stream_(StreamWrapper()) {} stream_(StreamWrapper()) {}
Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
: tensor_(std::make_shared<framework::LoDTensor>()),
place_(place),
stream_(StreamWrapper()) {
GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape));
}
template <typename T> template <typename T>
T *Tensor::mutable_data(const PlaceType &place) { T *Tensor::mutable_data(const PlaceType &place) {
place_ = place; place_ = place;
@ -162,6 +186,12 @@ DataType Tensor::type() const {
return DataType::FLOAT64; return DataType::FLOAT64;
} else if (type == framework::proto::VarType::BOOL) { } else if (type == framework::proto::VarType::BOOL) {
return DataType::BOOL; return DataType::BOOL;
} else if (type == framework::proto::VarType::COMPLEX64) {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
} else if (type == framework::proto::VarType::FP16) {
return DataType::FLOAT16;
} }
// TODO(JiabinYang) Support more dtype here // TODO(JiabinYang) Support more dtype here
return DataType::FLOAT32; return DataType::FLOAT32;
@ -217,6 +247,12 @@ template PD_DLL_DECL Tensor
Tensor::copy_to<int16_t>(const PlaceType &target_place) const; Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor template PD_DLL_DECL Tensor
Tensor::copy_to<bool>(const PlaceType &target_place) const; Tensor::copy_to<bool>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
template PD_DLL_DECL float *Tensor::data<float>() const; template PD_DLL_DECL float *Tensor::data<float>() const;
template PD_DLL_DECL double *Tensor::data<double>() const; template PD_DLL_DECL double *Tensor::data<double>() const;
@ -226,6 +262,12 @@ template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const; template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const; template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL bool *Tensor::data<bool>() const; template PD_DLL_DECL bool *Tensor::data<bool>() const;
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
template PD_DLL_DECL float *Tensor::mutable_data<float>(); template PD_DLL_DECL float *Tensor::mutable_data<float>();
template PD_DLL_DECL double *Tensor::mutable_data<double>(); template PD_DLL_DECL double *Tensor::mutable_data<double>();
@ -235,6 +277,12 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(); template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(); template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(); template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place); template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place);
template PD_DLL_DECL double *Tensor::mutable_data<double>( template PD_DLL_DECL double *Tensor::mutable_data<double>(
@ -250,6 +298,12 @@ template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>( template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place); const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place); template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
std::vector<int64_t> Tensor::shape() const { std::vector<int64_t> Tensor::shape() const {
GET_CASTED_TENSOR GET_CASTED_TENSOR
@ -310,6 +364,21 @@ Tensor Tensor::cast(const DataType &target_type) const {
framework::VisitDataType( framework::VisitDataType(
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx)); dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break; break;
case framework::proto::VarType::COMPLEX64:
framework::VisitDataType(
dst_type,
CastDataType<paddle::platform::complex64>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX128:
framework::VisitDataType(dst_type,
CastDataType<paddle::platform::complex128>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP16:
framework::VisitDataType(
dst_type,
CastDataType<paddle::platform::float16>(*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang) Support more dtype here // TODO(JiabinYang) Support more dtype here
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(

@ -346,57 +346,25 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file(commit.h.in commit.h) configure_file(commit.h.in commit.h)
# Adapt to custom op mechanism: Include the header files related to the data type
# to avoid exposing the path of the underlying file
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../extension/include)
cc_library(custom_tensor SRCS ../extension/src/ext_tensor.cc DEPS lod_tensor memory enforce) cc_library(custom_tensor SRCS ../extension/src/ext_tensor.cc DEPS lod_tensor memory enforce)
cc_library(op_meta_info SRCS ../extension/src/ext_op_meta_info.cc DEPS custom_tensor) cc_library(op_meta_info SRCS ../extension/src/ext_op_meta_info.cc DEPS custom_tensor)
cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info) cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info)
cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog) cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../extension/include)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator) set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator)
cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES}) cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})
# Old custom op extension mechanism related, will be removed in 2.1.0
cc_library(paddle_framework_shared
SHARED SRCS executor.cc operator.cc
${CMAKE_CURRENT_SOURCE_DIR}/c/c_api.cc
${CMAKE_SOURCE_DIR}/paddle/fluid/imperative/layer.cc
DEPS ${FLUID_FRAMEWORK_MODULES})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
set_target_properties(paddle_framework_shared PROPERTIES OUTPUT_NAME paddle_framework)
target_link_libraries(paddle_framework_shared ${os_dependency_modules})
if (LINUX)
set(FLUID_FRAMEWORK_SHARED_LIB
${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_framework.so
CACHE INTERNAL "Fluid framework lib")
endif()
if (WIN32)
if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
set(paddle_framework_lib_path ${CMAKE_CURRENT_BINARY_DIR})
else()
set(paddle_framework_lib_path ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE})
endif()
set(FLUID_FRAMEWORK_IMPORT_LIB
${paddle_framework_lib_path}/paddle_framework.lib
CACHE INTERNAL "Fluid framework lib")
set(FLUID_FRAMEWORK_SHARED_LIB
${paddle_framework_lib_path}/paddle_framework.dll
CACHE INTERNAL "Fluid framework dll")
endif()
if(APPLE)
set(FLUID_FRAMEWORK_SHARED_LIB
${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_framework.dylib
CACHE INTERNAL "Fluid framework lib")
endif()
if(WITH_TESTING AND TEST selected_rows_test) if(WITH_TESTING AND TEST selected_rows_test)
set_tests_properties(selected_rows_test PROPERTIES TIMEOUT 120) set_tests_properties(selected_rows_test PROPERTIES TIMEOUT 120)
endif() endif()
# New custom op extension mechanism related ##### 2.0 New custom op extension mechanism related #####
# if not deps `layer`, will cause: undefined symbol: _ZN6paddle10imperative7VarBase9name_set_ # if not deps `layer`, will cause: undefined symbol: _ZN6paddle10imperative7VarBase9name_set_
set(PADDLE_CUSTOM_OP_MODULES custom_tensor op_meta_info custom_operator layer) set(PADDLE_CUSTOM_OP_MODULES custom_tensor op_meta_info custom_operator layer)

@ -1,53 +0,0 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/c/c_api.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
extern "C" {
paddle::framework::OpInfoMap &PD_GetOpInfoMap() {
return paddle::framework::OpInfoMap::Instance();
}
void PD_InitDevicesPool(paddle::platform::DeviceContextPool *pool) {
paddle::platform::DeviceContextPool::SetPool(pool);
}
std::vector<std::string> PD_GetGradOpDescStrs(
const paddle::framework::OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> *grad_to_var,
const std::vector<paddle::framework::BlockDesc *> &grad_block) {
auto &op_info = PD_GetOpInfoMap().Get(op_desc.Type());
std::vector<std::string> ret;
if (op_info.grad_op_maker_) {
auto grad_op_descs =
op_info.grad_op_maker_(op_desc, no_grad_set, grad_to_var, grad_block);
size_t op_num = grad_op_descs.size();
ret.resize(op_num);
for (size_t i = 0; i < op_num; ++i) {
PADDLE_ENFORCE_EQ(
grad_op_descs[i]->Proto()->SerializePartialToString(&ret[i]), true,
paddle::platform::errors::Unavailable(
"Cannot serialize operator desc message."));
}
}
return ret;
}
} // end extern "C"

@ -1,55 +0,0 @@
/* copyright (c) 2019 paddlepaddle authors. all rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
class OpInfoMap;
} // namespace framework
namespace platform {
class DeviceContextPool;
} // namespace platform
} // namespace paddle
#ifdef __cplusplus
extern "C" {
#endif
// C-API to get global OpInfo map.
paddle::framework::OpInfoMap &PD_GetOpInfoMap();
// C-API to init global DeviceContextPool from outside.
void PD_InitDevicesPool(paddle::platform::DeviceContextPool *pool);
// C-API to serialize the grad op protocol message to a binary string.
std::vector<std::string> PD_GetGradOpDescStrs(
const paddle::framework::OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> *grad_to_var,
const std::vector<paddle::framework::BlockDesc *> &grad_block);
#ifdef __cplusplus
}
#endif

@ -28,7 +28,6 @@ limitations under the License. */
#include "paddle/fluid/extension/include/ext_tensor.h" #include "paddle/fluid/extension/include/ext_tensor.h"
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/custom_tensor_utils.h" #include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_meta_info_helper.h" #include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
@ -178,7 +177,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
"Unsupported `%s` type value as custom attribute now. " "Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, " "Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, " "`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, " "`std::vector<float>`, `std::vector<int64_t>`, "
"`std::vector<std::string>`, Please check whether " "`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.", "the attribute data type and data type string are matched.",
attr_type_str)); attr_type_str));
@ -327,7 +326,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
"Unsupported `%s` type value as custom attribute now. " "Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, " "Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, " "`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, " "`std::vector<float>`, `std::vector<int64_t>`, "
"`std::vector<std::string>`, Please check whether " "`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.", "the attribute data type and data type string are matched.",
attr_type_str)); attr_type_str));
@ -581,7 +580,7 @@ void RegisterOperatorWithMetaInfo(
ctx->ShareDim(op_inputs[0], op_outputs[0]); ctx->ShareDim(op_inputs[0], op_outputs[0]);
}; };
} else { } else {
info.infer_shape_ = [op_inputs, op_outputs, info.infer_shape_ = [op_inputs, op_outputs, op_attrs,
infer_shape_func](InferShapeContext* ctx) { infer_shape_func](InferShapeContext* ctx) {
std::vector<std::vector<int64_t>> input_shapes; std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes; std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;
@ -606,8 +605,50 @@ void RegisterOperatorWithMetaInfo(
} }
} }
std::vector<boost::any> custom_attrs;
for (auto& attr_str : op_attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx->Attrs().Get<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx->Attrs().Get<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx->Attrs().Get<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx->Attrs().Get<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx->Attrs().Get<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
// NOTE(chenweihang): InferShape can't support std::vector<int64_t>
// attr type, because the input type is std::vector<int64_t>, only
// can use one rule to parse std::vector<int64_t> parameter
continue;
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<std::string>`, "
"Please check whether the attribute data type and "
"data type string are matched.",
attr_type_str));
}
}
VLOG(1) << "Custom Operator: InferShape - calc output ddim."; VLOG(1) << "Custom Operator: InferShape - calc output ddim.";
auto output_shapes = infer_shape_func(input_shapes, vec_input_shapes); auto output_shapes =
infer_shape_func(input_shapes, vec_input_shapes, custom_attrs);
VLOG(1) << "Custom Operator: InferShape - set output ddim."; VLOG(1) << "Custom Operator: InferShape - set output ddim.";
for (size_t i = 0; i < op_outputs.size(); ++i) { for (size_t i = 0; i < op_outputs.size(); ++i) {
@ -757,10 +798,39 @@ void RegisterOperatorWithMetaInfo(
return new CustomOperator(type, inputs, outputs, attrs); return new CustomOperator(type, inputs, outputs, attrs);
}; };
// Grad InferShape (gradient's shape is same with forward input default) // Grad InferShape
grad_info.infer_shape_ = [grad_op_outputs](InferShapeContext* ctx) { grad_info.infer_shape_ = [grad_op_inputs,
grad_op_outputs](InferShapeContext* ctx) {
// 1. if forward input exists, gradient's shape is same with forward input
// default
// [Suitable for most situations]
// 2. if forward input not exists, and only contains one grad input and
// output,
// use grad input shape as grad output shape
// [Suitable for the situation that forward input is not used as
// backward input]
// TODO(chenweihang): support set grad op infershape func if needed
for (auto& out_name : grad_op_outputs) { for (auto& out_name : grad_op_outputs) {
ctx->ShareDim(detail::NoGrad(out_name), out_name); auto fwd_name = detail::NoGrad(out_name);
if (detail::IsDuplicableVar(fwd_name)) {
// Duplicable forward var must as backward input
ctx->ShareDim(fwd_name, out_name);
} else {
if (ctx->HasInput(fwd_name)) {
ctx->ShareDim(fwd_name, out_name);
} else {
PADDLE_ENFORCE_EQ(
grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL,
true,
platform::errors::Unavailable(
"Custom grad operator infershape error. "
"If a custom grad operator contains only one input and "
"only one output, the input shape will be directly set to "
"the output shape. Otherwise, Please set the forward input "
"as the grad operator's input."));
ctx->ShareDim(grad_op_inputs[0], out_name);
}
}
} }
}; };

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save