gpu bert script

pull/1797/head
wilfChen 5 years ago
parent 5aeba82af3
commit 9ce9c21526

@@ -21,6 +21,7 @@ import os
 import argparse
 import numpy
 import mindspore.communication.management as D
+import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.parallel_utils import ParallelMode
@@ -28,6 +29,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
+from mindspore import log as logger
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from src.dataset import create_bert_dataset
 from src.config import cfg, bert_net_cfg
@@ -55,6 +57,8 @@ class LossCallBack(Callback):
 def run_pretrain():
     """pre-train bert_clue"""
     parser = argparse.ArgumentParser(description='bert pre_training')
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
     parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
@@ -74,11 +78,21 @@ def run_pretrain():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     args_opt = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
+    ckpt_save_dir = args_opt.checkpoint_path
     if args_opt.distribute == "true":
-        device_num = args_opt.device_num
+        if args_opt.device_target == 'Ascend':
+            D.init('hccl')
+            device_num = args_opt.device_num
+            rank = args_opt.device_id % device_num
+        else:
+            D.init('nccl')
+            device_num = D.get_group_size()
+            rank = D.get_rank()
+        ckpt_save_dir = args_opt.checkpoint_path + 'ckpt_' + str(rank) + '/'
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                           device_num=device_num)
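
The two init branches above differ in where the rank comes from: on Ascend the script derives it from --device_id, while on GPU the rank and world size are assigned by the MPI launcher and queried back from the framework. A minimal sketch of the NCCL path, as a hypothetical standalone snippet assuming a GPU build of MindSpore launched under mpirun:

import mindspore.communication.management as D
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
D.init('nccl')                   # NCCL backend: ranks are assigned by mpirun
device_num = D.get_group_size()  # world size, i.e. the -n value given to mpirun
rank = D.get_rank()              # this process's id, 0 .. device_num - 1
print('rank {} of {}'.format(rank, device_num))
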
@@ -93,12 +107,15 @@ def run_pretrain():
             auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421])
         else:
             auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
-        D.init()
-        rank = args_opt.device_id % device_num
     else:
         rank = 0
         device_num = 1
+    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
+        logger.warning('GPU only supports fp32 temporarily, running with fp32.')
+        bert_net_cfg.compute_type = mstype.float32
     ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
                                                args_opt.enable_data_sink, args_opt.data_sink_steps,
                                                args_opt.data_dir, args_opt.schema_dir)
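
The fp32 guard added above mutates bert_net_cfg in place before the network is built. A sketch of the same check written as a reusable helper (hypothetical name and signature, not part of this commit):

import mindspore.common.dtype as mstype
from mindspore import log as logger

def enforce_fp32_on_gpu(net_cfg, device_target):
    """Hypothetical helper mirroring the guard above: this version of the
    GPU kernels only supports fp32, so an fp16 compute_type is overridden."""
    if device_target == 'GPU' and net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 temporarily, running with fp32.')
        net_cfg.compute_type = mstype.float32
    return net_cfg
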
@@ -130,7 +147,7 @@ def run_pretrain():
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
-        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', config=config_ck)
+        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
         callback.append(ckpoint_cb)
     if args_opt.checkpoint_path:
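
Passing directory=ckpt_save_dir matters in the distributed case: because ckpt_save_dir ends in 'ckpt_' + str(rank) + '/', every worker writes its checkpoints into its own subdirectory instead of all ranks clobbering the same files. A small illustration with assumed values:

# Hypothetical illustration of the per-rank layout built above.
checkpoint_path = '/tmp/bert/'          # stand-in for --checkpoint_path
for rank in range(4):                   # e.g. after `mpirun -n 4`
    print(checkpoint_path + 'ckpt_' + str(rank) + '/')
# prints /tmp/bert/ckpt_0/ through /tmp/bert/ckpt_3/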

@@ -0,0 +1,44 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
RANK_SIZE=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
mpirun --allow-run-as-root -n $RANK_SIZE \
python run_pretrain.py \
--device_target="GPU" \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
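
mpirun starts RANK_SIZE identical copies of run_pretrain.py; each copy only learns its rank once D.init('nccl') runs. For a quick sanity check outside MindSpore, Open MPI also exposes the rank through its environment (a hypothetical probe, assuming the Open MPI mpirun shown above):

import os

# Open MPI sets these variables for every process it launches.
rank = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
size = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
print('process {} of {}'.format(rank, size))
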

@@ -37,7 +37,7 @@ python run_pretrain.py \
     --enable_lossscale="true" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
-    --data_sink_steps=100 \
+    --data_sink_steps=1 \
     --checkpoint_path="" \
     --save_checkpoint_steps=10000 \
     --save_checkpoint_num=1 \

@@ -0,0 +1,47 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
export CUDA_VISIBLE_DEVICES=$DEVICE_ID
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python run_pretrain.py \
--device_target="GPU" \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
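
Note the standalone script selects the GPU through CUDA_VISIBLE_DEVICES rather than --device_id: with the mask set, the chosen physical GPU is the only one visible to the process and is renumbered as device 0, which matches the argument's default. The GLOG exports route MindSpore's glog output into ms_log instead of stderr. A sketch of the same setup done from Python (hypothetical; the environment variables must be set before mindspore is imported):

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # expose one physical GPU, seen as device 0
os.makedirs('ms_log', exist_ok=True)      # same effect as the script's mkdir -p
os.environ['GLOG_log_dir'] = os.path.join(os.getcwd(), 'ms_log')
os.environ['GLOG_logtostderr'] = '0'      # 0 = write glog files under GLOG_log_dir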