From 1fd05aeb1861e9b79f2cac25b230b0a973010d7e Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Mon, 30 Nov 2020 23:15:41 -0500 Subject: [PATCH] Fix the bug of broadcasting the parameters in IGamma and LBeta --- mindspore/nn/layer/math.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index ec759a5c15..ade0c977cb 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -669,9 +669,11 @@ class IGamma(Cell): domain_error = self.logicalor(self.less(x, 0), self.less(a, 0)) use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) ax = a * self.log(x) - x - self.lgamma(a) - boradcastto = P.BroadcastTo(self.shape(ax)) - a = boradcastto(a) - x = boradcastto(x) + para_shape = self.shape(ax) + boradcastto = P.BroadcastTo(para_shape) + if para_shape != (): + x = boradcastto(x) + a = boradcastto(a) x_is_zero = self.equal(x, 0) if a_dtype == mstype.float16: log_maxfloat = self.log_maxfloat16 @@ -744,9 +746,11 @@ class LBeta(Cell): _check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name) _check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name) x_plus_y = x + y - boradcastto = P.BroadcastTo(self.shape(x_plus_y)) - x = boradcastto(x) - y = boradcastto(y) + para_shape = self.shape(x_plus_y) + boradcastto = P.BroadcastTo(para_shape) + if para_shape != (): + x = boradcastto(x) + y = boradcastto(y) comp_less = self.less(x, y) x_min = self.select(comp_less, x, y) y_max = self.select(comp_less, y, x)