|
|
|
@ -377,21 +377,16 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
|
sens=None):
|
|
|
|
|
"""Defines the computation performed."""
|
|
|
|
|
weights = self.weights
|
|
|
|
|
saved = ()
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
saved = saved + (F.assign(self.saved_params[i], weights[i]),)
|
|
|
|
|
F.assign(self.saved_params[i], weights[i])
|
|
|
|
|
|
|
|
|
|
for i in range(self.quant_embedding_list_length):
|
|
|
|
|
quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
|
|
|
|
|
quant_embedding = F.depend(quant_embedding, saved)
|
|
|
|
|
assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
|
|
|
|
|
input_ids = F.depend(input_ids, assign_embedding)
|
|
|
|
|
F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
|
|
|
|
|
|
|
|
|
|
for i in range(self.quant_weight_list_length):
|
|
|
|
|
quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
|
|
|
|
|
quant_weight = F.depend(quant_weight, saved)
|
|
|
|
|
assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
|
|
|
|
|
input_ids = F.depend(input_ids, assign_weight)
|
|
|
|
|
F.assign(weights[self.quant_weight_list[i]], quant_weight)
|
|
|
|
|
|
|
|
|
|
if sens is None:
|
|
|
|
|
scaling_sens = self.loss_scale
|
|
|
|
@ -411,10 +406,10 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
|
grads = self.grad_reducer(grads)
|
|
|
|
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
|
|
|
|
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
|
|
|
|
|
restore = ()
|
|
|
|
|
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
weights[i] = F.depend(weights[i], grads)
|
|
|
|
|
restore = restore + (F.assign(weights[i], self.saved_params[i]),)
|
|
|
|
|
param = F.depend(self.saved_params[i], grads)
|
|
|
|
|
F.assign(weights[i], param)
|
|
|
|
|
|
|
|
|
|
self.get_status(init)
|
|
|
|
|
flag_sum = self.reduce_sum(init, (0,))
|
|
|
|
@ -431,7 +426,6 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
|
succ = False
|
|
|
|
|
else:
|
|
|
|
|
succ = self.optimizer(grads)
|
|
|
|
|
succ = F.depend(succ, restore)
|
|
|
|
|
return succ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -490,21 +484,16 @@ class BertTrainCell(nn.Cell):
|
|
|
|
|
label_ids):
|
|
|
|
|
"""Defines the computation performed."""
|
|
|
|
|
weights = self.weights
|
|
|
|
|
saved = ()
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
saved = saved + (F.assign(self.saved_params[i], weights[i]),)
|
|
|
|
|
F.assign(self.saved_params[i], weights[i])
|
|
|
|
|
|
|
|
|
|
for i in range(self.quant_embedding_list_length):
|
|
|
|
|
quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
|
|
|
|
|
quant_embedding = F.depend(quant_embedding, saved)
|
|
|
|
|
assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
|
|
|
|
|
input_ids = F.depend(input_ids, assign_embedding)
|
|
|
|
|
F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
|
|
|
|
|
|
|
|
|
|
for i in range(self.quant_weight_list_length):
|
|
|
|
|
quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
|
|
|
|
|
quant_weight = F.depend(quant_weight, saved)
|
|
|
|
|
assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
|
|
|
|
|
input_ids = F.depend(input_ids, assign_weight)
|
|
|
|
|
F.assign(weights[self.quant_weight_list[i]], quant_weight)
|
|
|
|
|
|
|
|
|
|
grads = self.grad(self.network, weights)(input_ids,
|
|
|
|
|
input_mask,
|
|
|
|
@ -515,11 +504,10 @@ class BertTrainCell(nn.Cell):
|
|
|
|
|
# apply grad reducer on grads
|
|
|
|
|
grads = self.grad_reducer(grads)
|
|
|
|
|
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
|
|
|
|
|
restore = ()
|
|
|
|
|
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
weights[i] = F.depend(weights[i], grads)
|
|
|
|
|
restore = restore + (F.assign(weights[i], self.saved_params[i]),)
|
|
|
|
|
param = F.depend(self.saved_params[i], grads)
|
|
|
|
|
F.assign(weights[i], param)
|
|
|
|
|
|
|
|
|
|
succ = self.optimizer(grads)
|
|
|
|
|
succ = F.depend(succ, restore)
|
|
|
|
|
return succ
|
|
|
|
|