transformer bucket batch modification

pull/6057/head
yuchaojie 4 years ago
parent a0e3fd6bf3
commit fa1247a85e

@@ -54,10 +54,10 @@ After dataset preparation, you can start training and evaluation as follows:
```bash
# run training example
sh scripts/run_standalone_train_ascend.sh 0 52 /path/ende-l128-mindrecord00
sh scripts/run_standalone_train_ascend.sh 0 52 /path/ende-l128-mindrecord
# run distributed training example
sh scripts/run_distribute_train_ascend.sh 8 52 /path/newstest2014-l128-mindrecord rank_table.json
sh scripts/run_distribute_train_ascend.sh 8 52 /path/ende-l128-mindrecord rank_table.json
# run evaluation example
python eval.py > eval.log 2>&1 &
@@ -104,6 +104,7 @@ usage: train.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [
[--enable_data_sink ENABLE_DATA_SINK] [--save_checkpoint_steps N]
[--save_checkpoint_num N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--data_path DATA_PATH]
[--bucket_boundaries BUCKET_LENGTH]
options:
--distribute pre_training by several devices: "true"(training by more than 1 device) | "false", default is "false"
@@ -119,6 +120,7 @@ options:
--save_checkpoint_num number for saving checkpoint files: N, default is 30
--save_checkpoint_path path to save checkpoint files: PATH, default is "./checkpoint/"
--data_path path to dataset file: PATH, default is ""
--bucket_boundaries sequence lengths for different buckets: LIST, default is [16, 32, 48, 64, 128]
```
### Running Options
@@ -179,13 +181,13 @@ Parameters for learning rate:
``` bash
paste train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.all
python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128
python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128 --bucket [16, 32, 48, 64, 128]
```
- Convert the original data to mindrecord for evaluation:
``` bash
paste newstest2014.tok.bpe.32000.en newstest2014.tok.bpe.32000.de > test.all
python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True
python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True --bucket [128]
```
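With bucketing enabled, the conversion step is expected to produce one MindRecord file per bucket rather than a single file. The sketch below is only illustrative: the `"_<bucket>_00"` suffix is taken from `src/dataset.py` later in this change, which opens files named that way; the exact shard suffix written by `create_data.py` may differ.

```python
# Illustrative only: list the per-bucket file names that the bucketed data
# loader (src/dataset.py) expects to find next to the --output_file prefix.
output_file = "/path/ende-l128-mindrecord"   # same value as --output_file above
buckets = [16, 32, 48, 64, 128]
for bucket_len in buckets:
    print(output_file + "_" + str(bucket_len) + "_00")
```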

@@ -51,20 +51,29 @@ class SampleInstance():
        return self.__str__()

def write_instance_to_file(writer, instance, tokenizer, max_seq_length):
def write_instance_to_file(writer, instance, tokenizer, max_seq_length, bucket):
    """Create files from `SampleInstance`s."""
    def _find_bucket_length(num):
        assert num <= bucket[-1]
        for index in range(1, len(bucket)):
            if bucket[index - 1] < num <= bucket[index]:
                return bucket[index]
        return bucket[0]

    def _convert_ids_and_mask(input_tokens):
        input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
        input_mask = [1] * len(input_ids)
        assert len(input_ids) <= max_seq_length
        while len(input_ids) < max_seq_length:
        seq_max_bucket_length = _find_bucket_length(len(input_ids))
        while len(input_ids) < seq_max_bucket_length:
            input_ids.append(0)
            input_mask.append(0)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(input_ids) == seq_max_bucket_length
        assert len(input_mask) == seq_max_bucket_length
        return input_ids, input_mask
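The bucket lookup and padding above can be exercised in isolation. Below is a minimal, MindSpore-free sketch with made-up token ids: `find_bucket_length` mirrors the nested `_find_bucket_length` helper, and padding now stops at the chosen bucket length instead of `max_seq_length`.

```python
# Minimal sketch of the bucketing logic (hypothetical token ids, no tokenizer).
def find_bucket_length(num, bucket):
    assert num <= bucket[-1]
    for index in range(1, len(bucket)):
        if bucket[index - 1] < num <= bucket[index]:
            return bucket[index]
    return bucket[0]

bucket = [16, 32, 48, 64, 128]
input_ids = list(range(1, 19))          # 18 token ids -> falls into the 32 bucket
input_mask = [1] * len(input_ids)
seq_max_bucket_length = find_bucket_length(len(input_ids), bucket)
while len(input_ids) < seq_max_bucket_length:
    input_ids.append(0)                 # pad ids with 0
    input_mask.append(0)                # mask marks real tokens vs. padding
print(seq_max_bucket_length, len(input_ids), sum(input_mask))  # 32 32 18
```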
@@ -93,7 +102,6 @@ def create_training_instance(source_words, target_words, max_seq_length, clip_to
    if len(source_words) >= max_seq_length or len(target_words) >= max_seq_length:
        if clip_to_max_len:
            print("####lalalal")
            source_words = source_words[:min([len(source_words), max_seq_length-1])]
            target_words = target_words[:min([len(target_words), max_seq_length-1])]
        else:
@@ -123,6 +131,8 @@ def main():
parser.add_argument("--clip_to_max_len", type=bool, default=False,
help='clip sequences to maximum sequence length.')
parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
parser.add_argument("--bucket", type=list, default=[16, 32, 48, 64, 128], help='bucket sequence length')
args = parser.parse_args()
tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
@@ -179,7 +189,7 @@ def main():
        if instance is None:
            continue

        features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length)
        features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length, args.bucket)
        total_written += 1

        if total_written <= 20:

@@ -52,7 +52,7 @@ do
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--enable_data_sink="false" \
--checkpoint_path="" \
--save_checkpoint_steps=2500 \
--save_checkpoint_num=30 \

@@ -37,7 +37,7 @@ python train.py \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--enable_data_sink="false" \
--checkpoint_path="" \
--save_checkpoint_steps=2500 \
--save_checkpoint_num=30 \

@@ -134,6 +134,7 @@ class BeamSearchDecoder(nn.Cell):
                 eos_id=2,
                 compute_type=mstype.float32):
        super(BeamSearchDecoder, self).__init__(auto_prefix=False)
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.beam_width = beam_width
@@ -182,7 +183,7 @@ class BeamSearchDecoder(nn.Cell):
"""
One step for decode
"""
log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask)
log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask, self.seq_length)
log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size))
# select topk indices

@@ -15,30 +15,40 @@
"""Data operations, will be used in train.py."""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as deC
from .config import transformer_net_cfg
de.config.set_seed(1)
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true",
dataset_path=None):
dataset_path=None, bucket_boundaries=None):
"""create dataset"""
repeat_count = epoch_count
ds = de.MindDataset(dataset_path,
columns_list=["source_eos_ids", "source_eos_mask",
"target_sos_ids", "target_sos_mask",
"target_eos_ids", "target_eos_mask"],
shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
def batch_per_bucket(bucket_len, dataset_path):
dataset_path = dataset_path + "_" + str(bucket_len) + "_00"
ds = de.MindDataset(dataset_path,
columns_list=["source_eos_ids", "source_eos_mask",
"target_sos_ids", "target_sos_mask",
"target_eos_ids", "target_eos_mask"],
shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
# apply batch operations
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
# apply batch operations
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(epoch_count)
return ds
for i, _ in enumerate(bucket_boundaries):
bucket_len = bucket_boundaries[i]
ds_per = batch_per_bucket(bucket_len, dataset_path)
if i == 0:
ds = ds_per
else:
ds = ds + ds_per
ds = ds.shuffle(ds.get_dataset_size())
ds.channel_name = 'transformer'
return ds
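For reference, a hedged usage sketch of the new loader: it assumes MindSpore is installed, that per-bucket MindRecord files exist at the hypothetical path below, and the `src` package layout that `train.py` uses. Each bucket becomes its own batched dataset; the datasets are concatenated with `+` and then shuffled at the granularity of whole batches.

```python
# Usage sketch (hypothetical path; requires MindSpore and the per-bucket
# MindRecord files "<path>_16_00", "<path>_32_00", ..., "<path>_128_00").
from src.dataset import create_transformer_dataset

dataset = create_transformer_dataset(epoch_count=1,
                                     rank_size=1,
                                     rank_id=0,
                                     do_shuffle="true",
                                     enable_data_sink="false",
                                     dataset_path="/path/ende-l128-mindrecord",
                                     bucket_boundaries=[16, 32, 48, 64, 128])
print(dataset.get_dataset_size())  # total number of batches across all buckets
```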

@@ -95,12 +95,13 @@ class TransformerTrainingLoss(nn.Cell):
        self.flatten = P.Flatten()
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.flat_shape = (config.batch_size * config.seq_length,)
        self.batch_size = config.batch_size

    def construct(self, prediction_scores, label_ids, label_weights):
    def construct(self, prediction_scores, label_ids, label_weights, seq_length):
        """Defines the computation performed."""
        label_ids = self.reshape(label_ids, self.flat_shape)
        label_weights = self.cast(self.reshape(label_weights, self.flat_shape), mstype.float32)
        flat_shape = (self.batch_size * seq_length,)
        label_ids = self.reshape(label_ids, flat_shape)
        label_weights = self.cast(self.reshape(label_weights, flat_shape), mstype.float32)
        one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
        per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
@@ -128,6 +129,7 @@ class TransformerNetworkWithLoss(nn.Cell):
        self.transformer = TransformerModel(config, is_training, use_one_hot_embeddings)
        self.loss = TransformerTrainingLoss(config)
        self.cast = P.Cast()
        self.shape = P.Shape()

    def construct(self,
                  source_ids,
@@ -136,8 +138,10 @@ class TransformerNetworkWithLoss(nn.Cell):
                  target_mask,
                  label_ids,
                  label_weights):
        """Transformer network with loss."""
        prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask)
        total_loss = self.loss(prediction_scores, label_ids, label_weights)
        seq_length = self.shape(source_ids)[1]
        total_loss = self.loss(prediction_scores, label_ids, label_weights, seq_length)
        return self.cast(total_loss, mstype.float32)
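Because each batch is now padded only to its bucket boundary, the flattened shape used by the loss can no longer be a compile-time constant; it is rebuilt from the runtime sequence length taken from `source_ids`. A tiny, framework-free sketch of the shape arithmetic (toy numbers only):

```python
# Toy shape check: with bucket batching, seq_length differs per batch, so the
# flat shape used to reshape label_ids/label_weights must be built at run time.
batch_size = 96                           # toy value for illustration
for seq_length in (16, 32, 48, 64, 128):  # one bucket boundary per batch
    flat_shape = (batch_size * seq_length,)
    # label_ids of shape (batch_size, seq_length) is reshaped to flat_shape
    print(seq_length, flat_shape)
```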
@@ -156,7 +160,6 @@ class TransformerTrainOneStepCell(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        super(TransformerTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)

File diff suppressed because it is too large.

@@ -105,6 +105,9 @@ def argparse_init():
parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint/", help="Save checkpoint file path, "
"default is ./checkpoint/")
parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--bucket_boundaries", type=list, default=[16, 32, 48, 64, 128], help="sequence length for "
"different bucket")
return parser
def run_transformer_train():
@@ -129,7 +132,8 @@ def run_transformer_train():
    dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
                                         rank_id=rank_id, do_shuffle=args.do_shuffle,
                                         enable_data_sink=args.enable_data_sink,
                                         dataset_path=args.data_path)
                                         dataset_path=args.data_path,
                                         bucket_boundaries=args.bucket_boundaries)

    netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)

@@ -24,12 +24,13 @@ from mindspore.nn.optim import Adam
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import Callback
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC
from mindspore import context
from model_zoo.official.nlp.transformer.src.transformer_model import TransformerConfig
from model_zoo.official.nlp.transformer.src.transformer_for_train import TransformerNetworkWithLoss, \
TransformerTrainOneStepWithLossScaleCell
from model_zoo.official.nlp.transformer.src.config import cfg
from model_zoo.official.nlp.transformer.src.dataset import create_transformer_dataset
from model_zoo.official.nlp.transformer.src.config import cfg, transformer_net_cfg
from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr
DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"]
@@ -76,6 +77,24 @@ def get_config(version='base', batch_size=1):
    transformer_cfg = TransformerConfig(batch_size=batch_size)
    return transformer_cfg

def load_test_data(batch_size=1, data_file=None):
    """Load test dataset."""
    ds = de.MindDataset(data_file,
                        columns_list=["source_eos_ids", "source_eos_mask",
                                      "target_sos_ids", "target_sos_mask",
                                      "target_eos_ids", "target_eos_mask"],
                        shuffle=False)
    type_cast_op = deC.TypeCast(mstype.int32)
    ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
    ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
    ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
    ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
    ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
    ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

class ModelCallback(Callback):
    def __init__(self):
        super(ModelCallback, self).__init__()
@@ -120,10 +139,7 @@ def test_transformer():
    batch_size = 96
    epoch_size = 3
    config = get_config(version=version, batch_size=batch_size)
    dataset = create_transformer_dataset(epoch_count=1,
                                         do_shuffle="false",
                                         enable_data_sink="false",
                                         dataset_path=DATA_DIR)
    dataset = load_test_data(batch_size=transformer_net_cfg.batch_size, data_file=DATA_DIR)

    netwithloss = TransformerNetworkWithLoss(config, True)
