@ -170,7 +170,7 @@ def train():
ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck)
cbs.append(ckpoint_cb)
model.train(args.train_epochs, dataset, callbacks=cbs)
model.train(args.train_epochs, dataset, callbacks=cbs, dataset_sink_mode=(args.device_target != "CPU"))
if __name__ == '__main__':