From 9fb6f0c34b9c79799652e77025bf10b6e4414102 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Thu, 27 Aug 2020 20:24:05 +0800 Subject: [PATCH] fix sink_size bug for transformer --- model_zoo/official/nlp/transformer/train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/nlp/transformer/train.py b/model_zoo/official/nlp/transformer/train.py index f4144515f4..2ac2c09369 100644 --- a/model_zoo/official/nlp/transformer/train.py +++ b/model_zoo/official/nlp/transformer/train.py @@ -170,8 +170,14 @@ def run_transformer_train(): netwithgrads.set_train(True) model = Model(netwithgrads) - model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"), - sink_size=args.save_checkpoint_steps) + + enable_sink = (args.enable_data_sink == "true") + if enable_sink: + sink_size = args.save_checkpoint_steps + model.train(args.epoch_size*dataset.get_dataset_size()//sink_size, dataset, callbacks=callbacks, + dataset_sink_mode=enable_sink, sink_size=sink_size) + else: + model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=enable_sink) if __name__ == '__main__': run_transformer_train()