py_test and test_image_classification_train support argument (#5934)

* py_test support argument, test_image_classification_train support argument

* use REMOVE_ITEM to rm item from list in cmake
release/0.11.0
Qiao Longfei 7 years ago committed by GitHub
parent 9de8ce171c
commit c9a96575d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -459,11 +459,11 @@ function(py_test TARGET_NAME)
if(WITH_TESTING)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python
${PYTHON_EXECUTABLE} ${py_test_SRCS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()

@ -1,5 +1,11 @@
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS test_image_classification_train)
py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet)
py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg)
# default test
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()

@ -1,7 +1,9 @@
from __future__ import print_function
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import sys
def resnet_cifar10(input, depth=32):
@ -80,11 +82,18 @@ data_shape = [3, 32, 32]
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Add neural network config
# option 1. resnet
# net = resnet_cifar10(images, 32)
# option 2. vgg
net = vgg16_bn_drop(images)
net_type = "vgg"
if len(sys.argv) >= 2:
net_type = sys.argv[1]
if net_type == "vgg":
print("train vgg net")
net = vgg16_bn_drop(images)
elif net_type == "resnet":
print("train resnet")
net = resnet_cifar10(images, 32)
else:
raise ValueError("%s network is not supported" % net_type)
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)

Loading…
Cancel
Save