commit f638f91020

@@ -0,0 +1,213 @@
#!/usr/bin/env python

from paddle.trainer_config_helpers import *

height = 224
width = 224
num_class = 1000
batch_size = get_config_arg('batch_size', int, 64)
layer_num = get_config_arg("layer_num", int, 50)
is_test = get_config_arg("is_test", bool, False)

args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
define_py_data_sources2(
    "train.list", None, module="provider", obj="process", args=args)

settings(
    batch_size=batch_size,
    learning_rate=0.01 / batch_size,
    learning_method=MomentumOptimizer(0.9),
    regularization=L2Regularization(0.0005 * batch_size))


####################### Network Configuration #############
def conv_bn_layer(name,
                  input,
                  filter_size,
                  num_filters,
                  stride,
                  padding,
                  channels=None,
                  active_type=ReluActivation()):
    """
    A wrapper for a convolution layer followed by batch normalization.
    Note:
        The conv layer itself has no activation; the activation is applied
        by the batch normalization layer.
    """

    tmp = img_conv_layer(
        name=name + "_conv",
        input=input,
        filter_size=filter_size,
        num_channels=channels,
        num_filters=num_filters,
        stride=stride,
        padding=padding,
        act=LinearActivation(),
        bias_attr=False)
    return batch_norm_layer(
        name=name + "_bn", input=tmp, act=active_type, use_global_stats=is_test)


def bottleneck_block(name, input, num_filters1, num_filters2):
    """
    A wrapper for a bottleneck building block in ResNet.
    The last conv_bn_layer has no activation; the addto layer applies
    the ReLU activation.
    """
    last_name = conv_bn_layer(
        name=name + '_branch2a',
        input=input,
        filter_size=1,
        num_filters=num_filters1,
        stride=1,
        padding=0)
    last_name = conv_bn_layer(
        name=name + '_branch2b',
        input=last_name,
        filter_size=3,
        num_filters=num_filters1,
        stride=1,
        padding=1)
    last_name = conv_bn_layer(
        name=name + '_branch2c',
        input=last_name,
        filter_size=1,
        num_filters=num_filters2,
        stride=1,
        padding=0,
        active_type=LinearActivation())

    return addto_layer(
        name=name + "_addto", input=[input, last_name], act=ReluActivation())


def mid_projection(name, input, num_filters1, num_filters2, stride=2):
    """
    A wrapper for the middle projection block in ResNet.
    Projection shortcuts are used for increasing dimensions; the other
    shortcuts are identity.
    branch1: the projection shortcut used to increase dimensions; it has
        no activation.
    branch2x: the bottleneck building block, whose shortcut is the identity.
    """
    # stride = 2
    branch1 = conv_bn_layer(
        name=name + '_branch1',
        input=input,
        filter_size=1,
        num_filters=num_filters2,
        stride=stride,
        padding=0,
        active_type=LinearActivation())

    last_name = conv_bn_layer(
        name=name + '_branch2a',
        input=input,
        filter_size=1,
        num_filters=num_filters1,
        stride=stride,
        padding=0)
    last_name = conv_bn_layer(
        name=name + '_branch2b',
        input=last_name,
        filter_size=3,
        num_filters=num_filters1,
        stride=1,
        padding=1)

    last_name = conv_bn_layer(
        name=name + '_branch2c',
        input=last_name,
        filter_size=1,
        num_filters=num_filters2,
        stride=1,
        padding=0,
        active_type=LinearActivation())

    return addto_layer(
        name=name + "_addto", input=[branch1, last_name], act=ReluActivation())


img = data_layer(name='image', size=height * width * 3)


def deep_res_net(res2_num=3, res3_num=4, res4_num=6, res5_num=3):
    """
    A wrapper for the 50-, 101- and 152-layer ResNet networks.
    res2_num: number of blocks stacked in conv2_x
    res3_num: number of blocks stacked in conv3_x
    res4_num: number of blocks stacked in conv4_x
    res5_num: number of blocks stacked in conv5_x
    """
    # For ImageNet
    # conv1: 112x112
    tmp = conv_bn_layer(
        "conv1",
        input=img,
        filter_size=7,
        channels=3,
        num_filters=64,
        stride=2,
        padding=3)
    tmp = img_pool_layer(name="pool1", input=tmp, pool_size=3, stride=2)

    # conv2_x: 56x56
    tmp = mid_projection(
        name="res2_1", input=tmp, num_filters1=64, num_filters2=256, stride=1)
    for i in xrange(2, res2_num + 1, 1):
        tmp = bottleneck_block(
            name="res2_" + str(i), input=tmp, num_filters1=64, num_filters2=256)

    # conv3_x: 28x28
    tmp = mid_projection(
        name="res3_1", input=tmp, num_filters1=128, num_filters2=512)
    for i in xrange(2, res3_num + 1, 1):
        tmp = bottleneck_block(
            name="res3_" + str(i),
            input=tmp,
            num_filters1=128,
            num_filters2=512)

    # conv4_x: 14x14
    tmp = mid_projection(
        name="res4_1", input=tmp, num_filters1=256, num_filters2=1024)
    for i in xrange(2, res4_num + 1, 1):
        tmp = bottleneck_block(
            name="res4_" + str(i),
            input=tmp,
            num_filters1=256,
            num_filters2=1024)

    # conv5_x: 7x7
    tmp = mid_projection(
        name="res5_1", input=tmp, num_filters1=512, num_filters2=2048)
    for i in xrange(2, res5_num + 1, 1):
        tmp = bottleneck_block(
            name="res5_" + str(i),
            input=tmp,
            num_filters1=512,
            num_filters2=2048)

    tmp = img_pool_layer(
        name='avgpool',
        input=tmp,
        pool_size=7,
        stride=1,
        pool_type=AvgPooling())

    return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())


if layer_num == 50:
    resnet = deep_res_net(3, 4, 6, 3)
elif layer_num == 101:
    resnet = deep_res_net(3, 4, 23, 3)
elif layer_num == 152:
    resnet = deep_res_net(3, 8, 36, 3)
else:
    print("Wrong layer number.")

lbl = data_layer(name="label", size=num_class)
loss = cross_entropy(name='loss', input=resnet, label=lbl)
inputs(img, lbl)
outputs(loss)
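
The config above pulls its data through define_py_data_sources2("train.list", None, module="provider", obj="process", args=args), but the provider module itself is not part of this diff. The following is only a sketch of what a synthetic-data benchmark provider could look like; the names and details in it (initHook, the random data) are assumptions, not the actual file.

from paddle.trainer.PyDataProvider2 import *

import numpy as np


def initHook(settings, height, width, color, num_class, **kwargs):
    # Mirror the args dict passed in from the trainer config above.
    settings.data_size = height * width * (3 if color else 1)
    settings.num_class = num_class
    settings.input_types = [dense_vector(settings.data_size), integer_value(num_class)]


@provider(init_hook=initHook)
def process(settings, file_name):
    # Synthetic data is enough for a speed benchmark; no real images are read.
    for _ in xrange(1024):
        img = np.random.rand(settings.data_size).astype('float32').tolist()
        lab = int(np.random.randint(0, settings.num_class))
        yield img, lab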

@@ -0,0 +1,188 @@
if(NOT WITH_GPU)
  return()
endif()

set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")

######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
#   detect_installed_gpus(out_variable)
function(detect_installed_gpus out_variable)
  if(NOT CUDA_gpu_detect_output)
    set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)

    file(WRITE ${cufile} ""
      "#include <cstdio>\n"
      "int main() {\n"
      "  int count = 0;\n"
      "  if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
      "  if (count == 0) return -1;\n"
      "  for (int device = 0; device < count; ++device) {\n"
      "    cudaDeviceProp prop;\n"
      "    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
      "      std::printf(\"%d.%d \", prop.major, prop.minor);\n"
      "  }\n"
      "  return 0;\n"
      "}\n")

    execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "-ccbin=${CUDA_HOST_COMPILER}"
                    "--run" "${cufile}"
                    WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
                    RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
                    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)

    if(nvcc_res EQUAL 0)
      # Only keep the last line of nvcc_out.
      string(REGEX REPLACE ";" "\\\\;" nvcc_out "${nvcc_out}")
      string(REGEX REPLACE "\n" ";" nvcc_out "${nvcc_out}")
      list(GET nvcc_out -1 nvcc_out)
      string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
      set(CUDA_gpu_detect_output ${nvcc_out} CACHE INTERNAL "Returned GPU architectures from detect_installed_gpus tool" FORCE)
    endif()
  endif()

  if(NOT CUDA_gpu_detect_output)
    message(STATUS "Automatic GPU detection failed. Building for all known architectures.")
    set(${out_variable} ${paddle_known_gpu_archs} PARENT_SCOPE)
  else()
    set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE)
  endif()
endfunction()
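
A concrete trace of the detection path, stated as an assumption about a typical machine rather than anything in this commit: on a host with two compute-capability 6.1 cards, the probe program prints `6.1 6.1 `, only the last line of the nvcc output is kept, and `CUDA_gpu_detect_output` is cached as `6.1 6.1`; `select_nvcc_arch_flags` below later strips the dots and removes duplicates, leaving a single `61` entry.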


########################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME
# Usage:
#   select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable)
  # List of arch names
  set(archs_names "Kepler" "Maxwell" "Pascal" "All" "Manual")
  set(archs_name_default "All")
  if(NOT CMAKE_CROSSCOMPILING)
    list(APPEND archs_names "Auto")
  endif()

  # Set the CUDA_ARCH_NAME strings (so it will be seen as a drop-down in the CMake GUI).
  set(CUDA_ARCH_NAME ${archs_name_default} CACHE STRING "Select target NVIDIA GPU architecture.")
  set_property(CACHE CUDA_ARCH_NAME PROPERTY STRINGS "" ${archs_names})
  mark_as_advanced(CUDA_ARCH_NAME)

  # Verify the CUDA_ARCH_NAME value.
  if(NOT ";${archs_names};" MATCHES ";${CUDA_ARCH_NAME};")
    string(REPLACE ";" ", " archs_names "${archs_names}")
    message(FATAL_ERROR "Only ${archs_names} architecture names are supported.")
  endif()

  if(${CUDA_ARCH_NAME} STREQUAL "Manual")
    set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported")
    set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for")
    mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX)
  else()
    unset(CUDA_ARCH_BIN CACHE)
    unset(CUDA_ARCH_PTX CACHE)
  endif()

  if(${CUDA_ARCH_NAME} STREQUAL "Kepler")
    set(cuda_arch_bin "30 35")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
    set(cuda_arch_bin "50")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
    set(cuda_arch_bin "60 61")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
    set(cuda_arch_bin "70")
  elseif(${CUDA_ARCH_NAME} STREQUAL "All")
    set(cuda_arch_bin ${paddle_known_gpu_archs})
  elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
    detect_installed_gpus(cuda_arch_bin)
  else()  # (${CUDA_ARCH_NAME} STREQUAL "Manual")
    set(cuda_arch_bin ${CUDA_ARCH_BIN})
  endif()

  # Remove dots and convert to lists.
  string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX REPLACE "\\." "" cuda_arch_ptx "${CUDA_ARCH_PTX}")
  string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")
  list(REMOVE_DUPLICATES cuda_arch_bin)
  list(REMOVE_DUPLICATES cuda_arch_ptx)

  set(nvcc_flags "")
  set(nvcc_archs_readable "")

  # Tell NVCC to add binaries for the specified GPUs.
  foreach(arch ${cuda_arch_bin})
    if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
      # User explicitly specified PTX for the concrete BIN.
      list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
      list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
    else()
      # User did not explicitly specify PTX for the concrete BIN, so assume PTX=BIN.
      list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
      list(APPEND nvcc_archs_readable sm_${arch})
    endif()
  endforeach()

  # Tell NVCC to add PTX intermediate code for the specified architectures.
  foreach(arch ${cuda_arch_ptx})
    list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
    list(APPEND nvcc_archs_readable compute_${arch})
  endforeach()

  string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
  set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
  set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()
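
To make the selection logic concrete, here is the flag set it produces for one setting (derived by tracing the function above, not taken from build output): with CUDA_ARCH_NAME=Pascal, cuda_arch_bin is "60 61" and CUDA_ARCH_PTX is unset, so the function returns

    -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61

and the readable form reported by the later status message is `sm_60 sm_61`.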

message(STATUS "CUDA detected: " ${CUDA_VERSION})
if(${CUDA_VERSION} LESS 7.0)
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
elseif(${CUDA_VERSION} LESS 8.0) # CUDA 7.x
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
elseif(${CUDA_VERSION} LESS 9.0) # CUDA 8.x
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
  # CUDA 8 may complain that sm_20 is no longer supported. Suppress the
  # warning for now.
  list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
endif()

include_directories(${CUDA_INCLUDE_DIRS})
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
  list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
endif(NOT WITH_DSO)

# Set the nvcc arch flags.
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA_readable}")

# Set C++11 support. Host flags are not propagated to nvcc, so -std=c++11
# is added to CUDA_NVCC_FLAGS explicitly below.
set(CUDA_PROPAGATE_HOST_FLAGS OFF)

# Release/Debug flags (such as -O3, -g, -DNDEBUG) are already set by CMake,
# so don't set them here; the per-configuration flags are appended below.
list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
# Set --expt-relaxed-constexpr to suppress Eigen warnings.
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")

if(CMAKE_BUILD_TYPE STREQUAL "Debug")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
endif()

mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)
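
With the pieces above in place, the architecture choice can be driven from the configure line; for example (an assumed invocation, with the project's other options omitted), passing `-DWITH_GPU=ON -DCUDA_ARCH_NAME=Manual -DCUDA_ARCH_BIN=61 -DCUDA_ARCH_PTX=61` to cmake builds SASS for sm_61 plus PTX for compute_61 only, instead of the full `paddle_known_gpu_archs` list used by the default `All` setting.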

@@ -0,0 +1,36 @@
=====================
Data Reader Interface
=====================


DataTypes
=========

.. automodule:: paddle.v2.data_type
    :members:
    :noindex:

DataFeeder
==========

.. automodule:: paddle.v2.data_feeder
    :members:
    :noindex:

Reader
======

.. automodule:: paddle.v2.reader
    :members:
    :noindex:

.. automodule:: paddle.v2.reader.creator
    :members:
    :noindex:

minibatch
=========

.. automodule:: paddle.v2.minibatch
    :members:
    :noindex:
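
The reader interface documented above is built from plain Python callables that return generators. A minimal sketch of how the pieces compose, assuming an installed paddle.v2 package (the toy reader and its sizes are made up for illustration):

import paddle.v2 as paddle


def toy_reader():
    # An item-level reader: a callable that returns a generator of single examples.
    for i in range(100):
        yield [float(i) / 100.0], i % 10  # (dense feature, integer label)


# Shuffle with a bounded buffer, then group single items into minibatches of 16.
batched_reader = paddle.batch(paddle.reader.shuffle(toy_reader, buf_size=50), batch_size=16)

for batch in batched_reader():
    # Each batch is a list of up to 16 (feature, label) tuples,
    # ready to be converted by a DataFeeder.
    pass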

Some files were not shown because too many files have changed in this diff.