diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index a92094f50d..859c175d24 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -333,9 +333,10 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p w_axis = 2; } int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group"); - if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) { - MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis] - << " (channels) must be divisible by group = " << group; + if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) && + ((x_shape[c_axis] / group) != w_shape[c_axis])) { + MS_LOG(EXCEPTION) << "x_shape[C_in] / group must equal to w_shape[C_in] = " << w_shape[c_axis] << ", but got " + << (x_shape[c_axis] / group); } int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel"); if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) { diff --git a/tests/ut/python/model/test_lenet_core_after_exception.py b/tests/ut/python/model/test_lenet_core_after_exception.py index fdc6e81ab7..fde3147d79 100644 --- a/tests/ut/python/model/test_lenet_core_after_exception.py +++ b/tests/ut/python/model/test_lenet_core_after_exception.py @@ -53,5 +53,6 @@ def test_lenet5_exception(): predict = Tensor(in1) label = Tensor(in2) net = train_step_with_loss_warp(LeNet5()) - with pytest.raises(ValueError): + with pytest.raises(RuntimeError) as info: _executor.compile(net, predict, label) + assert "x_shape[C_in] / group must equal to w_shape[C_in] = " in str(info.value)