|
|
|
@ -117,6 +117,7 @@ _cast_before_mirror = None
|
|
|
|
|
_loss_repeated_mean = None
|
|
|
|
|
_communication_backend = None
|
|
|
|
|
_has_checkpointed = False
|
|
|
|
|
_enable_all_reduce_fusion = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _checkpoint_auto_parallel_context():
|
|
|
|
@ -133,6 +134,7 @@ def _checkpoint_auto_parallel_context():
|
|
|
|
|
global _cast_before_mirror
|
|
|
|
|
global _loss_repeated_mean
|
|
|
|
|
global _communication_backend
|
|
|
|
|
global _enable_all_reduce_fusion
|
|
|
|
|
_parallel_mode = auto_parallel_context().get_parallel_mode()
|
|
|
|
|
_device_num = _get_device_num()
|
|
|
|
|
_global_rank = _get_global_rank()
|
|
|
|
@ -141,6 +143,7 @@ def _checkpoint_auto_parallel_context():
|
|
|
|
|
_cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
|
|
|
|
|
_loss_repeated_mean = auto_parallel_context().get_loss_repeated_mean()
|
|
|
|
|
_communication_backend = auto_parallel_context().get_communication_backend()
|
|
|
|
|
_enable_all_reduce_fusion = auto_parallel_context().get_enable_all_reduce_fusion()
|
|
|
|
|
_has_checkpointed = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -154,10 +157,12 @@ def _restore_auto_parallel_context():
|
|
|
|
|
global _cast_before_mirror
|
|
|
|
|
global _loss_repeated_mean
|
|
|
|
|
global _communication_backend
|
|
|
|
|
global _enable_all_reduce_fusion
|
|
|
|
|
_set_auto_parallel_context(parallel_mode=_parallel_mode, device_num=_device_num, global_rank=_global_rank,
|
|
|
|
|
parameter_broadcast=_parameter_broadcast, mirror_mean=_mirror_mean,
|
|
|
|
|
cast_before_mirror=_cast_before_mirror, loss_repeated_mean=_loss_repeated_mean)
|
|
|
|
|
auto_parallel_context().set_communication_backend(_communication_backend)
|
|
|
|
|
auto_parallel_context().set_enable_all_reduce_fusion(_enable_all_reduce_fusion)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _reset_checkpoint_auto_parallel_context():
|
|
|
|
|