|
|
@ -13,7 +13,7 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# ============================================================================
|
|
|
|
# ============================================================================
|
|
|
|
"""Transformer evaluation script."""
|
|
|
|
"""Transformer evaluation script."""
|
|
|
|
|
|
|
|
import os
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
@ -41,8 +41,13 @@ def run_gru_eval():
|
|
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
|
|
|
|
device_id=args.device_id, save_graphs=False)
|
|
|
|
device_id=args.device_id, save_graphs=False)
|
|
|
|
|
|
|
|
prefix = "multi30k_test_mindrecord_32"
|
|
|
|
|
|
|
|
mindrecord_file = os.path.join(args.dataset_path, prefix)
|
|
|
|
|
|
|
|
if not os.path.exists(mindrecord_file):
|
|
|
|
|
|
|
|
print("dataset file {} not exists, please check!".format(mindrecord_file))
|
|
|
|
|
|
|
|
raise ValueError(mindrecord_file)
|
|
|
|
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
|
|
|
|
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
|
|
|
|
dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
|
|
|
|
dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
|
|
|
|
dataset_size = dataset.get_dataset_size()
|
|
|
|
dataset_size = dataset.get_dataset_size()
|
|
|
|
print("dataset size is {}".format(dataset_size))
|
|
|
|
print("dataset size is {}".format(dataset_size))
|
|
|
|
network = Seq2Seq(config, is_training=False)
|
|
|
|
network = Seq2Seq(config, is_training=False)
|
|
|
|