diff --git a/model_zoo/official/cv/deeplabv3/train.py b/model_zoo/official/cv/deeplabv3/train.py index aa2fe888e1..99f62cf14e 100644 --- a/model_zoo/official/cv/deeplabv3/train.py +++ b/model_zoo/official/cv/deeplabv3/train.py @@ -170,7 +170,7 @@ def train(): ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck) cbs.append(ckpoint_cb) - model.train(args.train_epochs, dataset, callbacks=cbs) + model.train(args.train_epochs, dataset, callbacks=cbs, dataset_sink_mode=(args.device_target != "CPU")) if __name__ == '__main__':