|
|
@ -76,8 +76,8 @@ if __name__ == "__main__":
|
|
|
|
gradients_mean=True)
|
|
|
|
gradients_mean=True)
|
|
|
|
init()
|
|
|
|
init()
|
|
|
|
elif device_target == "GPU":
|
|
|
|
elif device_target == "GPU":
|
|
|
|
init()
|
|
|
|
|
|
|
|
if device_num > 1:
|
|
|
|
if device_num > 1:
|
|
|
|
|
|
|
|
init()
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
gradients_mean=True)
|
|
|
|
gradients_mean=True)
|
|
|
|