BUGFIX: Correct the calculation of the output 'd_x' of the operator LayerNormGradGrad

pull/9979/head
hedongdong 4 years ago
parent 855834e7cd
commit 6cc9d2c087

@ -58,8 +58,7 @@ class LayerNormGradGradGpuKernel : public GpuKernel {
cudaMemsetAsync(global_sum2, 0, input_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemsetAsync global_sum2 failed");
const T epsilon = 10e-12;
LayerNormGradGrad(input_row_, input_col_, param_dim_, global_sum1, global_sum2, epsilon, dy, x, mean, var, gamma,
LayerNormGradGrad(input_row_, input_col_, param_dim_, global_sum1, global_sum2, epsilon_, dy, x, mean, var, gamma,
grad_dx, grad_dg, grad_db, d_dy, d_x, d_gamma, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -88,6 +87,12 @@ class LayerNormGradGradGpuKernel : public GpuKernel {
param_dim_ *= input_shape[i];
}
epsilon_ = 1e-12;
auto type_id = TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
if (std::strncmp(type_id, "kNumberTypeFloat16", std::strlen(type_id)) == 0) {
epsilon_ = 1e-7;
}
InitSizeLists();
return true;
}
@ -122,6 +127,7 @@ class LayerNormGradGradGpuKernel : public GpuKernel {
int input_col_;
int param_dim_;
int input_size_;
T epsilon_;
};
} // namespace kernel
} // namespace mindspore

@ -696,13 +696,13 @@ def get_bprop_layer_norm(self):
@bprop_getters.register(G.LayerNormGrad)
def get_bprop_layer_norm_grad(self):
    """Grad definition for `LayerNormGrad` operation."""
    layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)

    def bprop(x, dy, variance, mean, gamma, out, dout):
        # dout[0], dout[1], dout[2] are forwarded into the grad_dx, grad_dg and
        # grad_db argument positions of LayerNormGradGrad.
        d_x, d_dy, d_gamma = layer_norm_grad_grad(
            x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
        # Gradients are returned positionally, one per bprop input, in the order
        # (x, dy, variance, mean, gamma). variance and mean get zero gradients,
        # so d_gamma must be the last element, not the third.
        return d_x, d_dy, zeros_like(variance), zeros_like(mean), d_gamma

    return bprop

@ -1105,6 +1105,7 @@ class LayerNormGradGrad(PrimitiveWithInfer):
def __call__(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
    """Direct invocation is not implemented; always raises NotImplementedError."""
    raise NotImplementedError
def infer_shape(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
    """Return the `x`, `dy` and `gamma` arguments unchanged, as a 3-tuple.

    NOTE(review): presumably each argument is the shape of the corresponding
    input, so the three outputs take the shapes of x, dy and gamma; confirm
    against the PrimitiveWithInfer contract.
    """
    return x, dy, gamma

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save