fix thor train failed

pull/8654/head
wangmin 4 years ago
parent 80b5b86fe1
commit 41bdb1e4cf

@ -118,7 +118,7 @@ class Model_Thor(Model):
dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
if dataset_sink_mode:
if dataset_sink_mode and context.get_context("device_target") != "GPU":
network = connect_network_with_dataset(network, dataset_helper)
network.set_train(is_train)
network.phase = phase

Loading…
Cancel
Save