|
|
|
@ -26,7 +26,7 @@ from mindspore.communication.management import get_group_size
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from utils import ClipByGlobalNorm
|
|
|
|
|
from src.utils import ClipByGlobalNorm
|
|
|
|
|
|
|
|
|
|
GRADIENT_CLIP_TYPE = 1
|
|
|
|
|
GRADIENT_CLIP_VALUE = 1.0
|
|
|
|
@ -77,6 +77,7 @@ class GPTTrainOneStepWithLossScaleCell(nn.Cell):
|
|
|
|
|
def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
|
|
|
|
|
super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.network.add_flags(defer_inline=True)
|
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.enable_global_norm = enable_global_norm
|
|
|
|
|