diff --git a/example/resnet50_imagenet2012_THOR/model/thor.py b/example/resnet50_imagenet2012_THOR/model/thor.py index 0da1714fe6..6786cb7485 100644 --- a/example/resnet50_imagenet2012_THOR/model/thor.py +++ b/example/resnet50_imagenet2012_THOR/model/thor.py @@ -151,6 +151,8 @@ class THOR(Optimizer): temp_g = self.mul(temp_g, matrix_G_inv_max) temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i]) temp_max = self.mul(temp_max, self.feature_map[i]) + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) if i == 53: g = self.cube_matmul_left_fc(temp_g, g) g = self.cube_matmul_right_fc(g, temp_a, temp_max) diff --git a/example/resnet50_imagenet2012_THOR/model/thor_layer.py b/example/resnet50_imagenet2012_THOR/model/thor_layer.py index fea74605b6..995c7b01d0 100644 --- a/example/resnet50_imagenet2012_THOR/model/thor_layer.py +++ b/example/resnet50_imagenet2012_THOR/model/thor_layer.py @@ -171,7 +171,6 @@ class Conv2d_Thor(_Conv): self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) self.fake_G = Tensor( np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape)) - self.fake_G_inv_max = Tensor(np.zeros([1,]).astype(np.float32)) self.shape = P.Shape() self.reshape = P.Reshape() @@ -286,7 +285,6 @@ class Conv2d_Thor(_Conv): matrix_A_inv = self.device_shape_pad(matrix_A_inv) matrix_A_inv = self.reshape(matrix_A_inv, self.matrix_A_device_temp_shape) matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3)) - self.G_inv_max = self.fake_G_inv_max self.matrix_A_inv = matrix_A_inv self.matrix_G_inv = self.fake_G out = self.conv2d(x, self.weight)