diff --git a/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh b/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh index 6357b39518..ba47943917 100644 --- a/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh +++ b/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh @@ -14,10 +14,10 @@ # limitations under the License. # ============================================================================ -if [ $# -lt 2 ] +if [ $# -lt 3 ] then - echo "Usage:\n \ - sh run_train.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]\n \ + echo "Usage: \ + sh run_train.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [cifar10|imagenet]\ " exit 1 fi @@ -42,10 +42,23 @@ cd ../train || exit export CUDA_VISIBLE_DEVICES="$2" + +dataset_type='cifar10' +if [ $# == 3 ] +then + if [ $3 != "cifar10" ] && [ $3 != "imagenet" ] + then + echo "error: the selected dataset is neither cifar10 nor imagenet" + exit 1 + fi + dataset_type=$3 +fi + + if [ $1 -gt 1 ] then mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \ - python3 ${BASEPATH}/../train.py > train.log 2>&1 & + python3 ${BASEPATH}/../train.py --dataset_name=$dataset_type > train.log 2>&1 & else - python3 ${BASEPATH}/../train.py > train.log 2>&1 & + python3 ${BASEPATH}/../train.py --dataset_name=$dataset_type > train.log 2>&1 & fi