vgg16 support imagenet dataset on Ascend

pull/5924/head
CaoJian 5 years ago
parent 75045e3e2a
commit 41e6ceaa72

@ -14,9 +14,9 @@
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]"
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]"
exit 1
fi
@ -32,6 +32,19 @@ then
exit 1
fi
dataset_type='cifar10'
if [ $# == 3 ]
then
if [ $3 != "cifar10" ] && [ $3 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
dataset_type=$3
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$1
@ -45,8 +58,8 @@ do
cp *.py ./train_parallel$i
cp -r src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
env > env.log
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log &
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log &
cd ..
done
done

@ -139,5 +139,8 @@ def vgg16(num_classes=1000, args=None, phase="train"):
>>> vgg16(num_classes=1000, args=args)
"""
if args is None:
from .config import cifar_cfg
args = cifar_cfg
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
return net

Loading…
Cancel
Save