fix bert scripts to adapt the new concept of repeatcount in minddata

pull/3269/head
chenhaozhe 5 years ago
parent ad651f38bf
commit 6fdf380923
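
With the new MindData behaviour, create_bert_dataset no longer resizes or repeats the dataset; run_pretrain.py now derives the number of sink iterations itself and passes sink_size to model.train. Below is a minimal sketch of that arithmetic in plain Python; the numeric values are assumed purely for illustration and stand in for args_opt and the batched dataset size used in the scripts.

```python
# Sketch of the repeat-count arithmetic used by the updated run_pretrain.py.
# All concrete numbers here are illustrative assumptions, not real config values.
epoch_size = 40          # args_opt.epoch_size
dataset_size = 12800     # ds.get_dataset_size() after batching: batches per epoch
data_sink_steps = 100    # args_opt.data_sink_steps
train_steps = 0          # args_opt.train_steps; <= 0 means "train for epoch_size epochs"

# Each sink iteration consumes data_sink_steps batches, so the count passed to
# model.train() is the total batch budget divided by the sink size.
new_repeat_count = epoch_size * dataset_size // data_sink_steps
if train_steps > 0:
    new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)
else:
    train_steps = epoch_size * dataset_size  # also used as decay_steps for the LR schedule

print(new_repeat_count, train_steps)  # -> 5120 512000
```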

@@ -64,7 +64,6 @@ def run_pretrain():
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
-    context.set_context(variable_memory_max_size="30GB")
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
@@ -99,47 +98,49 @@ def run_pretrain():
         logger.warning('Gpu only support fp32 temporarily, run with fp32.')
         bert_net_cfg.compute_type = mstype.float32
-    ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle,
-                             args_opt.enable_data_sink, args_opt.data_sink_steps,
-                             args_opt.data_dir, args_opt.schema_dir)
-    new_repeat_count = args_opt.epoch_size
+    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
+    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
     if args_opt.train_steps > 0:
-        new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
-    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
+        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+    else:
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
     if cfg.optimizer == 'Lamb':
         lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                        end_learning_rate=cfg.Lamb.end_learning_rate,
                                        warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.Lamb.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params}]
+                        {'params': other_params},
+                        {'order_params': params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
+        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                              momentum=cfg.Momentum.momentum)
     elif cfg.optimizer == 'AdamWeightDecay':
         lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                        end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                        warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.AdamWeightDecay.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0}]
+                        {'params': other_params, 'weight_decay': 0.0},
+                        {'order_params': params}]
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                          format(cfg.optimizer))
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
@@ -148,19 +149,22 @@ def run_pretrain():
     if args_opt.load_checkpoint_path:
         param_dict = load_checkpoint(args_opt.load_checkpoint_path)
-        load_param_into_net(netwithloss, param_dict)
+        load_param_into_net(net_with_loss, param_dict)
     if args_opt.enable_lossscale == "true":
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
+        net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                            scale_update_cell=update_cell)
     else:
-        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
-    model = Model(netwithgrads)
-    model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
+        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
+    model = Model(net_with_grads)
+    model.train(new_repeat_count, ds, callbacks=callback,
+                dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)


 if __name__ == '__main__':
     numpy.random.seed(0)
     run_pretrain()

@@ -23,11 +23,9 @@ from mindspore import log as logger
 from .config import bert_net_cfg


-def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
-                        data_sink_steps=1, data_dir=None, schema_dir=None):
+def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
     """create train dataset"""
     # apply repeat operations
-    repeat_count = epoch_size
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:
@@ -40,11 +38,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                               num_shards=device_num, shard_id=rank, shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
     print('origin dataset size: ', ori_dataset_size)
-    new_size = ori_dataset_size
-    if enable_data_sink == "true":
-        new_size = data_sink_steps * bert_net_cfg.batch_size
-    ds.set_dataset_size(new_size)
-    new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -56,7 +49,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeat count: {}".format(ds.get_repeat_count()))
-    return ds, new_repeat_count
+    return ds


 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
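
Note on the removal above: the old create_bert_dataset shrank the dataset to data_sink_steps * batch_size rows via set_dataset_size and derived the repeat count from the size ratio; the new code keeps the dataset at its natural one-epoch size and lets run_pretrain.py do the division. A quick sketch (assumed numbers, illustration only) showing that the two routes agree up to integer flooring:

```python
# Old vs. new repeat-count computation; numbers are assumed for illustration.
epoch_size = 40
rows = 409_600           # ori_dataset_size, measured before batching
batch_size = 32          # bert_net_cfg.batch_size
data_sink_steps = 100

# Old scheme: dataset resized to data_sink_steps * batch_size rows inside
# create_bert_dataset, repeat count derived from the size ratio.
old_repeat = int(epoch_size * rows // (data_sink_steps * batch_size))

# New scheme: dataset kept at one epoch of batches (drop_remainder=True),
# repeat count computed by the caller from get_dataset_size().
batches = rows // batch_size
new_repeat = epoch_size * batches // data_sink_steps

print(old_repeat, new_repeat)  # -> 5120 5120
```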

@@ -18,6 +18,7 @@ Functional Cells used in Bert finetune and evaluation.
 """
 import os
+import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
