diff --git a/model_zoo/official/cv/resnet_thor/src/model_thor.py b/model_zoo/official/cv/resnet_thor/src/model_thor.py index e9b4f32e24..36d457e756 100644 --- a/model_zoo/official/cv/resnet_thor/src/model_thor.py +++ b/model_zoo/official/cv/resnet_thor/src/model_thor.py @@ -118,7 +118,7 @@ class Model_Thor(Model): dataset.__loop_size__ = 1 dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order) - if dataset_sink_mode: + if dataset_sink_mode and context.get_context("device_target") != "GPU": network = connect_network_with_dataset(network, dataset_helper) network.set_train(is_train) network.phase = phase