|
|
|
@ -14,9 +14,9 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
|
|
|
|
|
# Argument-count guard: print the usage line and exit with status 1 when the
# script is started with an unsupported number of arguments.
# NOTE(review): two variants of the `if` line and of the usage `echo` appear
# below; they look like the before/after sides of a diff hunk, so only one of
# each is expected in the real script -- confirm against the actual file.
# Variant 1: accepts exactly two arguments.
if [ $# != 2 ]
|
|
|
|
|
# Variant 2: accepts two or three arguments (the test is true, and the usage
# error is raised, only when the count is neither 2 nor 3).
if [ $# != 2 ] && [ $# != 3 ]
|
|
|
|
|
then
|
|
|
|
|
# Usage text matching variant 1 (two positional arguments).
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]"
|
|
|
|
|
# Usage text matching variant 2 (adds the dataset selector as a third argument).
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]"
|
|
|
|
|
exit 1
|
|
|
|
|
fi
|
|
|
|
|
|
|
|
|
@ -32,6 +32,19 @@ then
|
|
|
|
|
exit 1
|
|
|
|
|
fi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Return 0 when $1 names a dataset this script accepts, 1 otherwise.
# A quoted `case` match is used instead of `[ $x != ... ]` so that an empty
# value, or one containing spaces, is compared as a single string rather than
# being word-split into a malformed test expression.
is_supported_dataset()
{
    case "$1" in
        cifar10|imagenet2012)
            return 0
            ;;
        *)
            return 1
            ;;
    esac
}

# Dataset name passed on to train.py; cifar10 is used when the optional third
# argument is omitted.
dataset_type='cifar10'
# -eq is the POSIX numeric comparison. The usage text starts this script with
# `sh`, and `==` inside `[ ]` is a bash extension: shells such as dash report
# an error for it, the test counts as false, and the requested dataset would
# be silently ignored.
if [ $# -eq 3 ]
then
    if ! is_supported_dataset "$3"
    then
        echo "error: the selected dataset is neither cifar10 nor imagenet2012"
        exit 1
    fi
    dataset_type=$3
fi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Settings exported into the environment of every process started after this
# point. DEVICE_NUM and RANK_SIZE are both fixed at 8.
export DEVICE_NUM=8
export RANK_SIZE=8
# The rank table path is the script's first argument. It is quoted because
# some `sh` implementations treat `export` as an ordinary command and
# word-split an unquoted expansion, which breaks paths containing spaces.
export RANK_TABLE_FILE="$1"
|
|
|
|
@ -45,8 +58,8 @@ do
|
|
|
|
|
cp *.py ./train_parallel$i
|
|
|
|
|
cp -r src ./train_parallel$i
|
|
|
|
|
cd ./train_parallel$i || exit
|
|
|
|
|
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
|
|
|
|
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
|
|
|
|
|
env > env.log
|
|
|
|
|
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log &
|
|
|
|
|
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log &
|
|
|
|
|
cd ..
|
|
|
|
|
done
|
|
|
|
|
done
|
|
|
|
|