# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""THOR momentum optimizer.

Applies ``P.ApplyMomentum`` to every parameter. Before that, the weight
gradient of each of 54 layers is transformed with per-layer matrices
(``matrix_A`` / ``matrix_G``) and a per-layer scale, using custom cube-matmul
operators.
"""
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from src.grad_reducer_thor import DistributedGradReducerThor

# Per-parameter momentum step; mapped over (gradient, weight, moment) triples
# by C.HyperMap in THOR.construct.
momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
    """Apply the momentum operator ``opt`` to one weight/moment pair.

    Args:
        opt: the momentum operator (``P.ApplyMomentum`` in THOR), called as
            ``opt(weight, moment, learning_rate, gradient, momentum)``.
        learning_rate: learning-rate tensor.
        momentum: momentum-coefficient tensor.
        gradient: gradient for ``weight``.
        weight: parameter that is updated.
        moment: accumulator paired with ``weight``.

    Returns:
        ``True``, made to depend (via ``F.depend``) on the ``opt`` call so that
        the side-effecting update is part of the returned graph value.
    """
    success = True
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success


# Element-wise sum of a tuple of tensors; used to add weight decay to a gradient.
op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Return ``gradient + weight * weight_decay`` if ``if_apply``, else ``gradient``.

    Args:
        weight_decay: scalar decay coefficient.
        if_apply: per-parameter flag (from THOR.decay_flags).
        weight: parameter tensor.
        gradient: gradient tensor for ``weight``.
    """
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


class THOR(Optimizer):
    """THOR optimizer: momentum with per-layer gradient transformation.

    For each of 54 layers (indexed 0..53), the weight gradient ``gradients[i * 3]``
    is multiplied on the left by a G-matrix and on the right by an A-matrix,
    scaled by a per-layer factor, using custom ``CusMatMulCube*`` operators.
    Layer 53 uses the "Dense" variants of those operators.

    * When ``self.thor`` is True, the matrices and their max values are
      all-reduced, normalized, cast to float16, used, and written back into
      parameters as a cache.
    * Otherwise the cached matrices and scale factors are reused as-is.

    The resulting gradients then go through optional weight decay, gradient
    scaling and ``P.ApplyMomentum``.

    NOTE(review): presumably an implementation of the THOR approximate
    second-order method for a ResNet-50-like network (53 convs + 1 dense layer);
    the layer count and gradient layout are hard-coded — confirm with caller.
    """

    def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
                 loss_scale=1.0,
                 decay_filter=lambda x: x.name not in []):
        """Build the optimizer.

        Args:
            params: trainable parameters, forwarded to ``Optimizer``.
            learning_rate: learning rate, forwarded to ``Optimizer``.
            momentum: momentum coefficient; stored as a float32 Parameter.
            matrix_A: per-layer A matrices (indexed 0..53 in construct);
                presumably owned by the network's THOR layers — confirm.
            matrix_G: per-layer G matrices (indexed 0..53 in construct).
            A_inv_max: per-layer scalar values associated with ``matrix_A``;
                construct divides the A matrix by it.
            G_inv_max: per-layer scalar values associated with ``matrix_G``;
                construct divides the G matrix by it.
            weight_decay: decay coefficient; stored multiplied by ``loss_scale``.
            loss_scale: loss-scale factor, forwarded to ``Optimizer``.
            decay_filter: predicate on each parameter deciding whether decay is
                applied. The default always returns True (``[]`` is empty), so
                every parameter is decayed when ``weight_decay > 0``.

        Raises:
            ValueError: if ``momentum`` is a float smaller than 0.0.
        """
        super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        # One zero-initialized accumulator per parameter, consumed by ApplyMomentum.
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.matrix_A = ParameterTuple(matrix_A)
        self.matrix_G = ParameterTuple(matrix_G)
        self.A_inv_max = ParameterTuple(A_inv_max)
        self.G_inv_max = ParameterTuple(G_inv_max)
        # Custom Ascend cube-matmul kernels: "Fracz" variants for layers 0..52,
        # "Dense" variants for the last layer (index 53).
        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
        # transpose/shape/reshape are not used by construct() below.
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        # Indices of parameters whose name contains "conv" or "end_point", with
        # len(self.params) appended as an end sentinel. Not read by construct().
        self.weight_idx = []
        for i in range(len(self.params)):
            if "conv" in self.params[i].name or "end_point" in self.params[i].name:
                self.weight_idx.append(i)
        self.weight_idx.append(len(self.params))
        # One scale factor per layer index (54 entries), multiplied into the
        # per-layer scale in construct. 12544 = 112**2, 3136 = 56**2,
        # 784 = 28**2, 196 = 14**2, 49 = 7**2: presumably 1 / (H * W) of each
        # layer's output feature map for ResNet-50 on 224x224 input, and 1.0
        # for the final dense layer — TODO confirm against the network.
        self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                            1.0]
        mean = _get_mirror_mean()
        degree = _get_device_num()
        # Separate reducers for A-max, G-max, A and G tuples. The integer
        # argument (2/5/3/4) presumably selects a fusion/communication group —
        # NOTE(review): confirm in src/grad_reducer_thor.
        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
        # matrix_A_inv / matrix_G_inv stay empty and are not read in construct().
        self.matrix_A_inv = ()
        self.matrix_G_inv = ()
        self.matrix_max_inv = ()

        # Per-layer cached scale factor (float32, shape [1], initialized to 1,
        # not trained). Despite the name, construct() stores
        # A_max * G_max * feature_map[i] here, not an inverse.
        for i in range(54):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
        self.log = P.Log()
        self.exp = P.Exp()
        # sqrt is not used by construct() below.
        self.sqrt = P.Sqrt()
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        self.assign = P.Assign()
        self.cast = P.Cast()
        # Selects the recompute branch in construct(); set to True here and
        # presumably toggled from outside between steps — confirm with caller.
        self.thor = True
        # Scaled by loss_scale, presumably because decay is added before
        # self.scale_grad(), which (in the base Optimizer) undoes the loss
        # scale — NOTE(review): confirm. This assignment replaces any
        # weight_decay value the base class may have stored.
        self.weight_decay = weight_decay * loss_scale
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)

    def construct(self, gradients):
        # Transform per-layer weight gradients, then apply decay, scaling and
        # momentum. Returns the HyperMap result of momentum_opt.
        #
        # Gradient layout assumed by the indexing below: for layers i = 0..52,
        # gradients[i * 3] is the transformed weight gradient and
        # gradients[i * 3 + 1], gradients[i * 3 + 2] pass through unchanged
        # (presumably BatchNorm gamma/beta — confirm). For layer 53 only
        # gradients[159] is kept, so 160 gradients are expected in total.
        params = self.params
        moments = self.moments
        if self.thor:
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            matrix_A_max_allreduce = ()
            matrix_G_max_allreduce = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                A_max = self.A_inv_max[i]
                G_max = self.G_inv_max[i]
                # Make each read depend on layer i's gradient, presumably so it
                # is ordered after the backward pass has produced g.
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                A_max = F.depend(A_max, g)
                G_max = F.depend(G_max, g)
                matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
                matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
                matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
                matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
            # Reduce the collected tuples across devices.
            matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
            matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
            matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
            matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                temp_a = matrix_A_allreduce[i]
                temp_g = matrix_G_allreduce[i]
                temp_a = self.cast(temp_a, mstype.float32)
                temp_g = self.cast(temp_g, mstype.float32)
                # exp(-log(x)) == 1 / x for x > 0: divide A by its reduced max.
                matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
                matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
                matrix_A_inv_max = self.exp(matrix_A_inv_max)
                temp_a = self.mul(temp_a, matrix_A_inv_max)
                # Same normalization for G.
                matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
                matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
                matrix_G_inv_max = self.exp(matrix_G_inv_max)
                temp_g = self.mul(temp_g, matrix_G_inv_max)
                # Scale factor handed to the right-multiply kernel:
                # A_max * G_max * feature_map[i].
                temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
                temp_max = self.mul(temp_max, self.feature_map[i])
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                # Left-multiply by G, right-multiply by A with the scale factor;
                # the last layer (53) uses the Dense kernels.
                if i == 53:
                    g = self.cube_matmul_left_fc(temp_g, g)
                    g = self.cube_matmul_right_fc(g, temp_a, temp_max)
                else:
                    g = self.cube_matmul_left(temp_g, g)
                    g = self.cube_matmul_right_mul(g, temp_a, temp_max)
                # Cache the normalized float16 matrices and the scale factor so
                # the non-thor branch can reuse them.
                fake_A = self.assign(self.matrix_A[i], temp_a)
                fake_G = self.assign(self.matrix_G[i], temp_g)
                fake_max = self.assign(self.matrix_max_inv[i], temp_max)
                # Tie the Assign side effects to g so they are not dropped.
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                g = F.depend(g, fake_max)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads
        else:
            # Reuse the matrices and scale factors cached by the last thor step.
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_max = self.matrix_max_inv[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                matrix_max = F.depend(matrix_max, g)
                if i == 53:
                    g = self.cube_matmul_left_fc(matrix_G, g)
                    g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g,)
                else:
                    g = self.cube_matmul_left(matrix_G, g)
                    g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads

        # Add weight_decay * weight to each gradient whose decay flag is True.
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
        return success