|
|
|
@ -24,12 +24,13 @@ from mindspore.nn.optim import Adam
|
|
|
|
|
from mindspore.train.model import Model
|
|
|
|
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
|
|
|
|
from mindspore.train.callback import Callback
|
|
|
|
|
import mindspore.dataset.engine as de
|
|
|
|
|
import mindspore.dataset.transforms.c_transforms as deC
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from model_zoo.official.nlp.transformer.src.transformer_model import TransformerConfig
|
|
|
|
|
from model_zoo.official.nlp.transformer.src.transformer_for_train import TransformerNetworkWithLoss, \
|
|
|
|
|
TransformerTrainOneStepWithLossScaleCell
|
|
|
|
|
from model_zoo.official.nlp.transformer.src.config import cfg
|
|
|
|
|
from model_zoo.official.nlp.transformer.src.dataset import create_transformer_dataset
|
|
|
|
|
from model_zoo.official.nlp.transformer.src.config import cfg, transformer_net_cfg
|
|
|
|
|
from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr
|
|
|
|
|
|
|
|
|
|
DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"]
|
|
|
|
@ -76,6 +77,24 @@ def get_config(version='base', batch_size=1):
|
|
|
|
|
transformer_cfg = TransformerConfig(batch_size=batch_size)
|
|
|
|
|
return transformer_cfg
|
|
|
|
|
|
|
|
|
|
def load_test_data(batch_size=1, data_file=None):
|
|
|
|
|
"""Load test dataset."""
|
|
|
|
|
ds = de.MindDataset(data_file,
|
|
|
|
|
columns_list=["source_eos_ids", "source_eos_mask",
|
|
|
|
|
"target_sos_ids", "target_sos_mask",
|
|
|
|
|
"target_eos_ids", "target_eos_mask"],
|
|
|
|
|
shuffle=False)
|
|
|
|
|
type_cast_op = deC.TypeCast(mstype.int32)
|
|
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
|
|
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
|
|
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
|
|
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
|
|
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
|
|
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
|
|
|
|
|
# apply batch operations
|
|
|
|
|
ds = ds.batch(batch_size, drop_remainder=True)
|
|
|
|
|
return ds
|
|
|
|
|
|
|
|
|
|
class ModelCallback(Callback):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(ModelCallback, self).__init__()
|
|
|
|
@ -120,10 +139,7 @@ def test_transformer():
|
|
|
|
|
batch_size = 96
|
|
|
|
|
epoch_size = 3
|
|
|
|
|
config = get_config(version=version, batch_size=batch_size)
|
|
|
|
|
dataset = create_transformer_dataset(epoch_count=1,
|
|
|
|
|
do_shuffle="false",
|
|
|
|
|
enable_data_sink="false",
|
|
|
|
|
dataset_path=DATA_DIR)
|
|
|
|
|
dataset = load_test_data(batch_size=transformer_net_cfg.batch_size, data_file=DATA_DIR)
|
|
|
|
|
|
|
|
|
|
netwithloss = TransformerNetworkWithLoss(config, True)
|
|
|
|
|
|
|
|
|
|