fix transformer gpu

pull/13630/head
panfengfeng 4 years ago
parent 1965ecb9a1
commit 25d18bc889

run_standalone_train.sh
@@ -15,7 +15,7 @@
 # ============================================================================
 echo "=============================================================================================================="
-echo "Please run the scipt as: "
+echo "Please run the script as: "
 echo "sh run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH"
 echo "for example: sh run_standalone_train.sh Ascend 0 52 /path/ende-l128-mindrecord00"
 echo "It is better to use absolute path."
@@ -31,17 +31,36 @@ DEVICE_ID=$2
 EPOCH_SIZE=$3
 DATA_PATH=$4
-python train.py \
-    --distribute="false" \
-    --epoch_size=$EPOCH_SIZE \
-    --device_target=$DEVICE_TARGET \
-    --device_id=$DEVICE_ID \
-    --enable_save_ckpt="true" \
-    --enable_lossscale="true" \
-    --do_shuffle="true" \
-    --checkpoint_path="" \
-    --save_checkpoint_steps=2500 \
-    --save_checkpoint_num=30 \
-    --data_path=$DATA_PATH \
-    --bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 &
+if [ $DEVICE_TARGET == 'Ascend' ];then
+    python train.py \
+        --distribute="false" \
+        --epoch_size=$EPOCH_SIZE \
+        --device_target=$DEVICE_TARGET \
+        --device_id=$DEVICE_ID \
+        --enable_save_ckpt="true" \
+        --enable_lossscale="true" \
+        --do_shuffle="true" \
+        --checkpoint_path="" \
+        --save_checkpoint_steps=2500 \
+        --save_checkpoint_num=30 \
+        --data_path=$DATA_PATH \
+        --bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 &
+elif [ $DEVICE_TARGET == 'GPU' ];then
+    export CUDA_VISIBLE_DEVICES="$2"
+    python train.py \
+        --distribute="false" \
+        --epoch_size=$EPOCH_SIZE \
+        --device_target=$DEVICE_TARGET \
+        --enable_save_ckpt="true" \
+        --enable_lossscale="true" \
+        --do_shuffle="true" \
+        --checkpoint_path="" \
+        --save_checkpoint_steps=2500 \
+        --save_checkpoint_num=30 \
+        --data_path=$DATA_PATH \
+        --bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 &
+else
+    echo "Not supported device target."
+fi
 cd ..
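
With this change the launcher dispatches on DEVICE_TARGET: the Ascend branch keeps the original --device_id flag, while the GPU branch instead exports the second positional argument as CUDA_VISIBLE_DEVICES and omits --device_id. Following the usage banner above, a GPU run would be invoked as, for example: sh run_standalone_train.sh GPU 0 52 /path/ende-l128-mindrecord00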

train.py
@@ -122,7 +122,10 @@ def run_transformer_train():
     """
     parser = argparse_init()
     args, _ = parser.parse_known_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
+    if args.device_target == "Ascend":
+        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
+    else:
+        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
     if args.distribute == "true":
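
The train.py hunk mirrors the script-level dispatch: an explicit device_id is passed to context.set_context only on Ascend, while on GPU the visible card is already pinned by the CUDA_VISIBLE_DEVICES export above. A minimal standalone sketch of the pattern, assuming MindSpore's context API (the set_device helper name is ours, for illustration only, not from this commit):

from mindspore import context

def set_device(device_target, device_id=0):
    # Hypothetical helper mirroring the guard added in this hunk.
    if device_target == "Ascend":
        # Ascend selects the chip through an explicit device_id.
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=device_target,
                            device_id=device_id)
    else:
        # On GPU the card is chosen by the launcher via
        # CUDA_VISIBLE_DEVICES, so no device_id is passed here.
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=device_target)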
