|
|
|
@ -28,7 +28,6 @@ from src.model_thor import Model
|
|
|
|
|
from src.utils import LossCallBack, BertLearningRate
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
import mindspore.communication.management as D
|
|
|
|
|
from mindspore.communication.management import get_rank
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
|
|
|
@ -41,55 +40,8 @@ from mindspore.common import set_seed
|
|
|
|
|
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_pretrain():
|
|
|
|
|
"""pre-train bert_clue"""
|
|
|
|
|
parser = argparse.ArgumentParser(description='bert pre_training')
|
|
|
|
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
|
|
|
|
help='device where the code will be implemented. (Default: Ascend)')
|
|
|
|
|
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
|
|
|
|
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
|
|
|
|
|
parser.add_argument("--device_id", type=int, default=4, help="Device id, default is 0.")
|
|
|
|
|
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
|
|
|
|
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
|
|
|
|
|
parser.add_argument("--enable_lossscale", type=str, default="false", help="Use lossscale or not, default is not.")
|
|
|
|
|
parser.add_argument("--do_shuffle", type=str, default="false", help="Enable shuffle for dataset, default is true.")
|
|
|
|
|
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
|
|
|
|
parser.add_argument("--data_sink_steps", type=int, default="100", help="Sink steps for each epoch, default is 1.")
|
|
|
|
|
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
|
|
|
|
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
|
|
|
|
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
|
|
|
|
|
"default is 1000.")
|
|
|
|
|
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
|
|
|
|
|
"meaning run all steps according to epoch number.")
|
|
|
|
|
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
|
|
|
|
|
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
|
|
|
|
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
|
|
|
|
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
if args_opt.distribute == "true":
|
|
|
|
|
from src.thor_for_bert_arg import THOR
|
|
|
|
|
else:
|
|
|
|
|
from src.thor_for_bert import THOR
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
|
|
|
|
device_id=args_opt.device_id, save_graphs=False)
|
|
|
|
|
context.set_context(reserve_class_name_in_scope=False)
|
|
|
|
|
context.set_context(variable_memory_max_size="30GB")
|
|
|
|
|
context.set_context(max_call_depth=3000)
|
|
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path
|
|
|
|
|
if args_opt.distribute == "true":
|
|
|
|
|
if args_opt.device_target == 'Ascend':
|
|
|
|
|
D.init()
|
|
|
|
|
device_num = args_opt.device_num
|
|
|
|
|
rank = args_opt.device_id % device_num
|
|
|
|
|
else:
|
|
|
|
|
D.init()
|
|
|
|
|
device_num = D.get_group_size()
|
|
|
|
|
rank = D.get_rank()
|
|
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
|
|
|
|
|
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
|
|
|
|
device_num=device_num)
|
|
|
|
|
def _set_bert_all_reduce_split():
|
|
|
|
|
"""set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
|
|
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|
|
|
|
if bert_net_cfg.num_hidden_layers == 12:
|
|
|
|
|
if bert_net_cfg.use_relative_positions:
|
|
|
|
@ -113,31 +65,17 @@ def run_pretrain():
|
|
|
|
|
"hccl_world_groupsum1")
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397],
|
|
|
|
|
"hccl_world_groupsum3")
|
|
|
|
|
else:
|
|
|
|
|
rank = 0
|
|
|
|
|
device_num = 1
|
|
|
|
|
|
|
|
|
|
if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
|
|
|
|
|
logger.warning('Gpu only support fp32 temporarily, run with fp32.')
|
|
|
|
|
bert_net_cfg.compute_type = mstype.float32
|
|
|
|
|
|
|
|
|
|
ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
|
|
|
|
|
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
|
|
|
|
|
|
|
|
|
|
new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
|
|
|
|
|
if args_opt.train_steps > 0:
|
|
|
|
|
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
|
|
|
|
|
else:
|
|
|
|
|
args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
|
|
|
|
|
logger.info("train steps: {}".format(args_opt.train_steps))
|
|
|
|
|
|
|
|
|
|
def _get_optimizer(args_opt, network):
|
|
|
|
|
"""get bert optimizer, support Lamb, Momentum, AdamWeightDecay and Thor."""
|
|
|
|
|
if cfg.optimizer == 'Lamb':
|
|
|
|
|
lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
|
|
|
|
|
end_learning_rate=cfg.Lamb.end_learning_rate,
|
|
|
|
|
warmup_steps=cfg.Lamb.warmup_steps,
|
|
|
|
|
decay_steps=args_opt.train_steps,
|
|
|
|
|
power=cfg.Lamb.power)
|
|
|
|
|
params = net_with_loss.trainable_params()
|
|
|
|
|
params = network.trainable_params()
|
|
|
|
|
decay_params = list(filter(cfg.Lamb.decay_filter, params))
|
|
|
|
|
other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
|
|
|
|
|
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
|
|
|
|
@ -145,7 +83,7 @@ def run_pretrain():
|
|
|
|
|
{'order_params': params}]
|
|
|
|
|
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
|
|
|
|
|
elif cfg.optimizer == 'Momentum':
|
|
|
|
|
optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
|
|
|
|
|
optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
|
|
|
|
|
momentum=cfg.Momentum.momentum)
|
|
|
|
|
elif cfg.optimizer == 'AdamWeightDecay':
|
|
|
|
|
lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
|
|
|
|
@ -153,7 +91,7 @@ def run_pretrain():
|
|
|
|
|
warmup_steps=cfg.AdamWeightDecay.warmup_steps,
|
|
|
|
|
decay_steps=args_opt.train_steps,
|
|
|
|
|
power=cfg.AdamWeightDecay.power)
|
|
|
|
|
params = net_with_loss.trainable_params()
|
|
|
|
|
params = network.trainable_params()
|
|
|
|
|
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
|
|
|
|
|
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
|
|
|
|
|
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
|
|
|
|
@ -162,16 +100,83 @@ def run_pretrain():
|
|
|
|
|
|
|
|
|
|
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
|
|
|
|
elif cfg.optimizer == "Thor":
|
|
|
|
|
if args_opt.distribute == "true":
|
|
|
|
|
from src.thor_for_bert_arg import THOR
|
|
|
|
|
else:
|
|
|
|
|
from src.thor_for_bert import THOR
|
|
|
|
|
lr = get_bert_lr()
|
|
|
|
|
damping = get_bert_damping()
|
|
|
|
|
optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
|
|
|
|
|
filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
|
|
|
|
|
filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
|
|
|
|
|
optimizer = THOR(filter(lambda x: x.requires_grad, network.get_parameters()), lr, cfg.Thor.momentum,
|
|
|
|
|
filter(lambda x: 'matrix_A' in x.name, network.get_parameters()),
|
|
|
|
|
filter(lambda x: 'matrix_G' in x.name, network.get_parameters()),
|
|
|
|
|
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
|
|
|
|
|
bert_net_cfg.batch_size, damping)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
|
|
|
|
|
format(cfg.optimizer))
|
|
|
|
|
return optimizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_pretrain():
|
|
|
|
|
"""pre-train bert_clue"""
|
|
|
|
|
parser = argparse.ArgumentParser(description='bert pre_training')
|
|
|
|
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
|
|
|
|
help='device where the code will be implemented. (Default: Ascend)')
|
|
|
|
|
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
|
|
|
|
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
|
|
|
|
|
parser.add_argument("--device_id", type=int, default=4, help="Device id, default is 0.")
|
|
|
|
|
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
|
|
|
|
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
|
|
|
|
|
parser.add_argument("--enable_lossscale", type=str, default="false", help="Use lossscale or not, default is not.")
|
|
|
|
|
parser.add_argument("--do_shuffle", type=str, default="false", help="Enable shuffle for dataset, default is true.")
|
|
|
|
|
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
|
|
|
|
parser.add_argument("--data_sink_steps", type=int, default="100", help="Sink steps for each epoch, default is 1.")
|
|
|
|
|
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
|
|
|
|
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
|
|
|
|
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
|
|
|
|
|
"default is 1000.")
|
|
|
|
|
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
|
|
|
|
|
"meaning run all steps according to epoch number.")
|
|
|
|
|
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
|
|
|
|
|
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
|
|
|
|
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
|
|
|
|
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
|
|
|
|
device_id=args_opt.device_id, save_graphs=False)
|
|
|
|
|
context.set_context(reserve_class_name_in_scope=False)
|
|
|
|
|
context.set_context(variable_memory_max_size="30GB")
|
|
|
|
|
context.set_context(max_call_depth=3000)
|
|
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path
|
|
|
|
|
if args_opt.distribute == "true":
|
|
|
|
|
D.init()
|
|
|
|
|
device_num = D.get_group_size()
|
|
|
|
|
rank = D.get_rank()
|
|
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
|
|
|
|
|
_set_bert_all_reduce_split()
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
|
|
|
|
device_num=device_num)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
rank = 0
|
|
|
|
|
device_num = 1
|
|
|
|
|
|
|
|
|
|
if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
|
|
|
|
|
logger.warning('Gpu only support fp32 temporarily, run with fp32.')
|
|
|
|
|
bert_net_cfg.compute_type = mstype.float32
|
|
|
|
|
|
|
|
|
|
ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
|
|
|
|
|
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
|
|
|
|
|
|
|
|
|
|
new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
|
|
|
|
|
if args_opt.train_steps > 0:
|
|
|
|
|
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
|
|
|
|
|
else:
|
|
|
|
|
args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
|
|
|
|
|
logger.info("train steps: {}".format(args_opt.train_steps))
|
|
|
|
|
|
|
|
|
|
optimizer = _get_optimizer(args_opt, net_with_loss)
|
|
|
|
|
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
|
|
|
|
|
if args_opt.enable_save_ckpt == "true" and rank == 0:
|
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
|
|
|
|
|