fix doc & code bugs for gnmt

pull/9529/head
zhaojichen 4 years ago
parent 6b9e402790
commit de7e8850da

File diff suppressed because it is too large.

@@ -1,7 +1,4 @@
 {
-  "training_platform": {
-    "modelarts": false
-  },
   "dataset_config": {
     "random_seed": 50,
     "epochs": 6,
@@ -9,10 +6,8 @@
     "dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
     "pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001",
     "fine_tune_dataset": null,
-    "test_dataset": null,
     "valid_dataset": null,
-    "dataset_sink_mode": true,
-    "dataset_sink_step": 2
+    "dataset_sink_mode": true
   },
   "model_config": {
     "seq_length": 51,

@@ -53,7 +53,6 @@ def get_source_list(folder: str) -> List:
 PARAM_NODES = {"dataset_config",
-               "training_platform",
                "model_config",
                "loss_scale_config",
                "learn_rate_config",
@@ -65,88 +64,99 @@ class GNMTConfig:
     Configuration for `GNMT`.
     Args:
-        random_seed (int): Random seed.
-        batch_size (int): Batch size of input dataset.
+        random_seed (int): Random seed, it can be changed.
         epochs (int): Epoch number.
-        dataset_sink_mode (bool): Whether enable dataset sink mode.
-        dataset_sink_step (int): Dataset sink step.
-        lr_scheduler (str): Whether use lr_scheduler, only support "ISR" now.
-        lr (float): Initial learning rate.
-        min_lr (float): Minimum learning rate.
-        decay_start_step (int): Step to decay.
-        warmup_steps (int): Warm up steps.
+        batch_size (int): Batch size of input dataset.
         dataset_schema (str): Path of dataset schema file.
         pre_train_dataset (str): Path of pre-training dataset file or folder.
         fine_tune_dataset (str): Path of fine-tune dataset file or folder.
         test_dataset (str): Path of test dataset file or folder.
         valid_dataset (str): Path of validation dataset file or folder.
-        ckpt_path (str): Checkpoints save path.
-        save_ckpt_steps (int): Interval of saving ckpt.
-        ckpt_prefix (str): Prefix of ckpt file.
-        keep_ckpt_max (int): Max ckpt files number.
-        seq_length (int): Length of input sequence. Default: 64.
-        vocab_size (int): The shape of each embedding vector. Default: 46192.
-        hidden_size (int): Size of embedding, attention, dim. Default: 512.
+        dataset_sink_mode (bool): Whether enable dataset sink mode.
+        seq_length (int): Length of input sequence.
+        vocab_size (int): The shape of each embedding vector.
+        hidden_size (int): Size of embedding, attention, dim.
         num_hidden_layers (int): Encoder, Decoder layers.
         intermediate_size (int): Size of intermediate layer in the Transformer
-            encoder/decoder cell. Default: 4096.
+            encoder/decoder cell.
         hidden_act (str): Activation function used in the Transformer encoder/decoder
-            cell. Default: "relu".
+            cell.
+        hidden_dropout_prob (float): The dropout probability for hidden outputs.
+        attention_dropout_prob (float): The dropout probability for Attention module.
+        initializer_range (float): Initialization value of TruncatedNormal.
+        label_smoothing (float): Label smoothing setting.
+        beam_width (int): Beam width for beam search in inferring.
+        length_penalty_weight (float): Penalty for sentence length.
+        max_decode_length (int): Max decode length for inferring.
+        input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
+            dataset.
         init_loss_scale (int): Initialized loss scale.
         loss_scale_factor (int): Loss scale factor.
         scale_window (int): Window size of loss scale.
-        beam_width (int): Beam width for beam search in inferring. Default: 4.
-        length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
-        label_smoothing (float): Label smoothing setting. Default: 0.1.
-        input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
-            dataset. Default: True.
+        lr_scheduler (str): Whether use lr_scheduler, only support "ISR" now.
+        optimizer (str): Optimizer for training, e.g. Adam, Lamb, momentum. Default: Adam.
+        lr (float): Initial learning rate.
+        min_lr (float): Minimum learning rate.
+        decay_steps (int): Decay steps.
+        lr_scheduler_power (float): A value used to calculate decayed learning rate.
+        warmup_lr_remain_steps (int or float): Start decay at 'remain_steps' iteration.
+        warmup_lr_decay_interval (int): Interval between LR decay steps.
+        decay_start_step (int): Step to decay.
+        warmup_steps (int): Warm up steps.
+        existed_ckpt (str): Using existed checkpoint to keep training or not.
+        save_ckpt_steps (int): Interval of saving ckpt.
+        keep_ckpt_max (int): Max ckpt files number.
+        ckpt_prefix (str): Prefix of ckpt file.
+        ckpt_path (str): Checkpoints save path.
         save_graphs (bool): Whether to save graphs, please set to True if mindinsight
             is wanted.
-        dtype (mstype): Data type of the input. Default: mstype.float32.
-        max_decode_length (int): Max decode length for inferring. Default: 64.
-        hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
-        attention_dropout_prob (float): The dropout probability for
-            Multi-head Self-Attention. Default: 0.1.
-        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
+        dtype (mstype): Data type of the input.
+
+    Note:
+        There are three types of learning rate scheduler: square root scheduler, polynomial
+        decay scheduler and warmup multistep learning rate scheduler.
+        In square root scheduler, the following parameters can be used: lr, decay_start_step,
+        warmup_steps and min_lr.
+        In polynomial decay scheduler, the following parameters can be used: lr, min_lr, decay_steps,
+        warmup_steps, lr_scheduler_power.
+        In warmup multistep learning rate scheduler, the following parameters can be used: lr, warmup_steps,
+        warmup_lr_remain_steps, warmup_lr_decay_interval, decay_steps, lr_scheduler_power.
     """
     def __init__(self,
-                 modelarts=False, random_seed=74,
-                 epochs=6, batch_size=64,
+                 random_seed=50,
+                 epochs=6, batch_size=128,
                  dataset_schema: str = None,
                  pre_train_dataset: str = None,
                  fine_tune_dataset: str = None,
                  test_dataset: str = None,
                  valid_dataset: str = None,
-                 dataset_sink_mode=True, dataset_sink_step=1,
+                 dataset_sink_mode=True,
                  seq_length=51, vocab_size=32320, hidden_size=1024,
                  num_hidden_layers=4, intermediate_size=4096,
                  hidden_act="tanh",
                  hidden_dropout_prob=0.2, attention_dropout_prob=0.2,
                  initializer_range=0.1,
                  label_smoothing=0.1,
-                 beam_width=5,
-                 length_penalty_weight=1.0,
+                 beam_width=2,
+                 length_penalty_weight=0.6,
                  max_decode_length=50,
                  input_mask_from_dataset=False,
-                 init_loss_scale=2 ** 10,
-                 loss_scale_factor=2, scale_window=128,
-                 lr_scheduler="", optimizer="adam",
-                 lr=1e-4, min_lr=1e-6,
-                 decay_steps=4, lr_scheduler_power=1,
+                 init_loss_scale=65536,
+                 loss_scale_factor=2, scale_window=1000,
+                 lr_scheduler="WarmupMultiStepLR",
+                 optimizer="adam",
+                 lr=2e-3, min_lr=1e-6,
+                 decay_steps=4, lr_scheduler_power=0.5,
                  warmup_lr_remain_steps=0.666, warmup_lr_decay_interval=-1,
                  decay_start_step=-1, warmup_steps=200,
-                 existed_ckpt="", save_ckpt_steps=2000, keep_ckpt_max=20,
+                 existed_ckpt="", save_ckpt_steps=3452, keep_ckpt_max=6,
                  ckpt_prefix="gnmt", ckpt_path: str = None,
-                 save_step=10000,
                  save_graphs=False,
                  dtype=mstype.float32):
         self.save_graphs = save_graphs
         self.random_seed = random_seed
-        self.modelarts = modelarts
-        self.save_step = save_step
         self.dataset_schema = dataset_schema
         self.pre_train_dataset = get_source_list(pre_train_dataset)  # type: List[str]
         self.fine_tune_dataset = get_source_list(fine_tune_dataset)  # type: List[str]
@@ -158,7 +168,6 @@ class GNMTConfig:
         self.epochs = epochs
         self.dataset_sink_mode = dataset_sink_mode
-        self.dataset_sink_step = dataset_sink_step
         self.ckpt_path = ckpt_path
         self.keep_ckpt_max = keep_ckpt_max
@@ -201,8 +210,6 @@ class GNMTConfig:
         self.decay_start_step = decay_start_step
         self.warmup_steps = warmup_steps
-        self.train_url = ""
     @classmethod
     def from_dict(cls, json_object: dict):
         """Constructs a `TransformerConfig` from a Python dictionary of parameters."""

@@ -1,7 +1,4 @@
 {
-  "training_platform": {
-    "modelarts": false
-  },
   "dataset_config": {
     "random_seed": 50,
     "epochs": 6,
@@ -11,8 +8,7 @@
     "fine_tune_dataset": null,
     "test_dataset": "/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001",
     "valid_dataset": null,
-    "dataset_sink_mode": true,
-    "dataset_sink_step": 2
+    "dataset_sink_mode": true
   },
   "model_config": {
     "seq_length": 107,
@@ -29,9 +25,9 @@
     "max_decode_length": 80
   },
   "loss_scale_config": {
-    "init_loss_scale": 8192,
+    "init_loss_scale": 65536,
     "loss_scale_factor": 2,
-    "scale_window": 128
+    "scale_window": 1000
   },
   "learn_rate_config": {
     "optimizer": "adam",

@@ -49,11 +49,12 @@ if __name__ == '__main__':
         schema_address=args.output_folder + "/" + test_src_file + ".json"
     )
     print(f" | It's writing, please wait a moment.")
-    test.write_to_tfrecord(
+    test.write_to_mindrecord(
         path=os.path.join(
             args.output_folder,
-            os.path.basename(test_src_file) + ".tfrecord"
-        )
+            os.path.basename(test_src_file) + ".mindrecord"
+        ),
+        train_mode=False
     )
     train = BiLingualDataLoader(
@@ -65,11 +66,12 @@ if __name__ == '__main__':
         schema_address=args.output_folder + "/" + train_src_file + ".json"
     )
     print(f" | It's writing, please wait a moment.")
-    train.write_to_tfrecord(
+    train.write_to_mindrecord(
         path=os.path.join(
             args.output_folder,
-            os.path.basename(train_src_file) + ".tfrecord"
-        )
+            os.path.basename(train_src_file) + ".mindrecord"
+        ),
+        train_mode=True
     )
     print(f" | Vocabulary size: {tokenizer.vocab_size}.")

@@ -14,16 +14,16 @@
 # ============================================================================
 """Base class of data loader."""
 import os
-import collections
 import numpy as np
 from mindspore.mindrecord import FileWriter
-from .schema import SCHEMA
+from .schema import SCHEMA, TEST_SCHEMA
 class DataLoader:
     """Data loader for dataset."""
     _SCHEMA = SCHEMA
+    _TEST_SCHEMA = TEST_SCHEMA
     def __init__(self):
         self._examples = []
@@ -41,7 +41,7 @@ class DataLoader:
         new_sen[:sen.shape[0]] = sen[:]
         return new_sen
-    def write_to_mindrecord(self, path, shard_num=1, desc=""):
+    def write_to_mindrecord(self, path, train_mode, shard_num=1, desc="gnmt"):
         """
         Write mindrecord file.
@@ -54,7 +54,10 @@ class DataLoader:
             path = os.path.abspath(path)
         writer = FileWriter(file_name=path, shard_num=shard_num)
-        writer.add_schema(self._SCHEMA, desc)
+        if train_mode:
+            writer.add_schema(self._SCHEMA, desc)
+        else:
+            writer.add_schema(self._TEST_SCHEMA, desc)
         if not self._examples:
             self._load()
@@ -62,41 +65,5 @@ class DataLoader:
         writer.commit()
         print(f"| Wrote to {path}.")
-    def write_to_tfrecord(self, path, shard_num=1):
-        """
-        Write to tfrecord.
-        Args:
-            path (str): Output file path.
-            shard_num (int): Shard num.
-        """
-        import tensorflow as tf
-        if not os.path.isabs(path):
-            path = os.path.abspath(path)
-        output_files = []
-        for i in range(shard_num):
-            output_file = path + "-%03d-of-%03d" % (i + 1, shard_num)
-            output_files.append(output_file)
-        # create writers
-        writers = []
-        for output_file in output_files:
-            writers.append(tf.io.TFRecordWriter(output_file))
-        if not self._examples:
-            self._load()
-        # create feature
-        features = collections.OrderedDict()
-        for example in self._examples:
-            for key in example:
-                features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=example[key].tolist()))
-            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
-            for writer in writers:
-                writer.write(tf_example.SerializeToString())
-        for writer in writers:
-            writer.close()
-        for p in output_files:
-            print(f" | Write to {p}.")
     def _add_example(self, example):
         self._examples.append(example)

@@ -19,9 +19,9 @@ import mindspore.dataset.engine as de
 import mindspore.dataset.transforms.c_transforms as deC
-def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
-                  sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True,
-                  drop_remainder=True, is_translate=False):
+def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
+                  rank_size=1, rank_id=0, shuffle=True, drop_remainder=True,
+                  is_translate=False):
     """
     Load dataset according to passed in params.
@@ -29,9 +29,7 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
         input_files (list): Data files.
         schema_file (str): Schema file path.
         batch_size (int): Batch size.
-        epoch_count (int): Epoch count.
         sink_mode (bool): Whether enable sink mode.
-        sink_step (int): Step to sink.
         rank_size (int): Rank size.
         rank_id (int): Rank id.
         shuffle (bool): Whether shuffle dataset.
@@ -57,15 +55,14 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
         print(f" | Loading {datafile}.")
     if not is_translate:
-        ds = de.TFRecordDataset(
-            input_files, schema_file,
-            columns_list=[
+        ds = de.MindDataset(
+            input_files, columns_list=[
                 "src", "src_padding",
                 "prev_opt",
                 "target", "tgt_padding"
-            ],
-            shuffle=False, num_shards=rank_size, shard_id=rank_id,
-            shard_equal_rows=True, num_parallel_workers=8)
+            ], shuffle=False, num_shards=rank_size, shard_id=rank_id,
+            num_parallel_workers=8
+        )
         ori_dataset_size = ds.get_dataset_size()
         print(f" | Dataset size: {ori_dataset_size}.")
@@ -92,13 +89,13 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
         )
         ds = ds.batch(batch_size, drop_remainder=drop_remainder)
     else:
-        ds = de.TFRecordDataset(
-            input_files, schema_file,
-            columns_list=[
+        ds = de.MindDataset(
+            input_files, columns_list=[
                 "src", "src_padding"
             ],
             shuffle=False, num_shards=rank_size, shard_id=rank_id,
-            shard_equal_rows=True, num_parallel_workers=8)
+            num_parallel_workers=8
+        )
         ori_dataset_size = ds.get_dataset_size()
         print(f" | Dataset size: {ori_dataset_size}.")
@@ -119,7 +116,7 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
     return ds
-def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: int, sink_mode: bool, sink_step: int = 1,
+def load_dataset(data_files: list, schema: str, batch_size: int, sink_mode: bool,
                  rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):
     """
     Load dataset.
@@ -128,9 +125,7 @@ def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: in
         data_files (list): Data files.
         schema (str): Schema file path.
         batch_size (int): Batch size.
-        epoch_count (int): Epoch count.
         sink_mode (bool): Whether enable sink mode.
-        sink_step (int): Step to sink.
         rank_size (int): Rank size.
         rank_id (int): Rank id.
         shuffle (bool): Whether shuffle dataset.
@@ -138,6 +133,5 @@ def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: in
     Returns:
         Dataset, dataset instance.
     """
-    return _load_dataset(data_files, schema, batch_size, epoch_count, sink_mode,
-                         sink_step, rank_size, rank_id, shuffle=shuffle,
+    return _load_dataset(data_files, schema, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle,
                          drop_remainder=drop_remainder, is_translate=is_translate)
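For reference, the translate branch above now boils down to a plain MindDataset read of the two source-side columns. A minimal standalone sketch for the single-device case (the file name is a placeholder, not taken from the repository):

import mindspore.dataset.engine as de

# Read back an inference MindRecord produced by the create_dataset script;
# one device, so num_shards=1 and shard_id=0. File name is illustrative only.
ds = de.MindDataset(
    "newstest2014.en.mindrecord",
    columns_list=["src", "src_padding"],
    shuffle=False, num_shards=1, shard_id=0,
    num_parallel_workers=8
)
print(f" | Dataset size: {ds.get_dataset_size()}.")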

@@ -21,3 +21,8 @@ SCHEMA = {
     "target": {"type": "int64", "shape": [-1]},
     "tgt_padding": {"type": "int64", "shape": [-1]},
 }
+
+TEST_SCHEMA = {
+    "src": {"type": "int64", "shape": [-1]},
+    "src_padding": {"type": "int64", "shape": [-1]},
+}
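The new TEST_SCHEMA is what write_to_mindrecord registers when train_mode=False. A minimal end-to-end sketch of that write path, assuming MindSpore's standard FileWriter API; the file name and record values are made up for illustration:

import numpy as np
from mindspore.mindrecord import FileWriter

TEST_SCHEMA = {
    "src": {"type": "int64", "shape": [-1]},
    "src_padding": {"type": "int64", "shape": [-1]},
}

# Inference records only carry the source side, so TEST_SCHEMA is sufficient.
writer = FileWriter(file_name="newstest2014.en.mindrecord", shard_num=1)
writer.add_schema(TEST_SCHEMA, "gnmt")
example = {
    "src": np.array([32, 7, 159, 2], dtype=np.int64),       # token ids (made up)
    "src_padding": np.array([1, 1, 1, 1], dtype=np.int64),  # valid-token mask
}
writer.write_raw_data([example])
writer.commit()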

@@ -189,7 +189,6 @@ def infer(config):
     eval_dataset = load_dataset(data_files=config.test_dataset,
                                 schema=config.dataset_schema,
                                 batch_size=config.batch_size,
-                                epoch_count=1,
                                 sink_mode=config.dataset_sink_mode,
                                 drop_remainder=False,
                                 is_translate=True,

@@ -16,8 +16,6 @@
 import time
 from mindspore.train.callback import Callback
-from mindspore.communication.management import get_rank
 from config import GNMTConfig
@@ -51,9 +49,6 @@ class LossCallBack(Callback):
         """step end."""
         cb_params = run_context.original_args()
         file_name = "./loss.log"
-        if self.config.modelarts:
-            import os
-            file_name = "/home/work/workspace/loss/loss_{}.log".format(os.getenv('DEVICE_ID'))
         with open(file_name, "a+") as f:
             time_stamp_current = self._get_ms_timestamp()
             f.write("time: {}, epoch: {}, step: {}, outputs: [loss: {}, overflow: {}, loss scale value: {} ].\n".format(
@@ -65,14 +60,6 @@ class LossCallBack(Callback):
                 str(cb_params.net_outputs[2].asnumpy())
             ))
-        if self.config.modelarts:
-            from modelarts.data_util import upload_output
-            rank_id = get_rank()
-            if cb_params.cur_step_num % self.config.save_step == 1 \
-                    and cb_params.cur_step_num != 1 and rank_id in [0, 8]:
-                upload_output("/home/work/workspace/loss", self.config.train_url)
-                upload_output("/cache/ckpt_0", self.config.train_url)
     @staticmethod
     def _get_ms_timestamp():
         t = time.time()

@@ -87,10 +87,7 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
     validator.check_value_type("beta2", beta2, [float], prim_name)
     validator.check_value_type("eps", eps, [float], prim_name)
     validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
-    # validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
-    # validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
-    # validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
-    # validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
 def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):

@@ -169,9 +169,7 @@ def _get_optimizer(config, network, lr):
     if config.optimizer.lower() == "adam":
         optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
     elif config.optimizer.lower() == "lamb":
-        optimizer = Lamb(network.trainable_params(), decay_steps=12000,
-                         start_learning_rate=config.lr, end_learning_rate=config.min_lr,
-                         power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01,
+        optimizer = Lamb(network.trainable_params(), learning_rate=lr,
                          eps=1e-6)
     elif config.optimizer.lower() == "momentum":
         optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
@@ -277,25 +275,21 @@ def train_parallel(config: GNMTConfig):
         data_files=config.pre_train_dataset,
         schema=config.dataset_schema,
         batch_size=config.batch_size,
-        epoch_count=config.epochs,
         sink_mode=config.dataset_sink_mode,
-        sink_step=config.dataset_sink_step,
         rank_size=MultiAscend.get_group_size(),
         rank_id=MultiAscend.get_rank()
     ) if config.pre_train_dataset else None
     fine_tune_dataset = load_dataset(
         data_files=config.fine_tune_dataset, schema=config.dataset_schema,
-        batch_size=config.batch_size, epoch_count=config.epochs,
+        batch_size=config.batch_size,
         sink_mode=config.dataset_sink_mode,
-        sink_step=config.dataset_sink_step,
         rank_size=MultiAscend.get_group_size(),
         rank_id=MultiAscend.get_rank()
     ) if config.fine_tune_dataset else None
     test_dataset = load_dataset(
         data_files=config.test_dataset, schema=config.dataset_schema,
-        batch_size=config.batch_size, epoch_count=config.epochs,
+        batch_size=config.batch_size,
         sink_mode=config.dataset_sink_mode,
-        sink_step=config.dataset_sink_step,
         rank_size=MultiAscend.get_group_size(),
         rank_id=MultiAscend.get_rank()
     ) if config.test_dataset else None
@@ -318,21 +312,15 @@ def train_single(config: GNMTConfig):
     pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
                                      schema=config.dataset_schema,
                                      batch_size=config.batch_size,
-                                     epoch_count=config.epochs,
-                                     sink_mode=config.dataset_sink_mode,
-                                     sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
+                                     sink_mode=config.dataset_sink_mode) if config.pre_train_dataset else None
     fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
                                      schema=config.dataset_schema,
                                      batch_size=config.batch_size,
-                                     epoch_count=config.epochs,
-                                     sink_mode=config.dataset_sink_mode,
-                                     sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
+                                     sink_mode=config.dataset_sink_mode) if config.fine_tune_dataset else None
     test_dataset = load_dataset(data_files=config.test_dataset,
                                 schema=config.dataset_schema,
                                 batch_size=config.batch_size,
-                                epoch_count=config.epochs,
-                                sink_mode=config.dataset_sink_mode,
-                                sink_step=config.dataset_sink_step) if config.test_dataset else None
+                                sink_mode=config.dataset_sink_mode) if config.test_dataset else None
     _build_training_pipeline(config=config,
                              pre_training_dataset=pre_train_dataset,
