|
|
|
@ -166,7 +166,8 @@ def run_transformer_train():
|
|
|
|
|
|
|
|
|
|
netwithgrads.set_train(True)
|
|
|
|
|
model = Model(netwithgrads)
|
|
|
|
|
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
|
|
|
|
|
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"),
|
|
|
|
|
sink_size=args.save_checkpoint_steps)
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
run_transformer_train()
|
|
|
|
|