diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu
index 13e42d4650..253908082a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cu
@@ -45,6 +45,15 @@ __global__ void ValidateInputAndInferShape(const T *range_start, const T *range_
 
   if (*error_code == DynamicRangeErrorCode::kOk) {
     int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
+
+    // Guard against rounding error in the division above. The last emitted value is recomputed with one multiply and
+    // one add; the count is reduced when that value passes `end` or lies within epsilon of it (`end` is exclusive).
+    double last_value = start + (delta * (real_output_shape - 1));
+    double epsilon = 1e-6;
+    if ((end > start && last_value > end) || (start > end && last_value < end) || fabs(last_value - end) < epsilon) {
+      real_output_shape--;
+    }
+
     if (real_output_shape > max_output_size) {
       *error_code = DynamicRangeErrorCode::kMaxSizeExceeded;
     }
diff --git a/tests/st/ops/gpu/test_range_op.py b/tests/st/ops/gpu/test_range_op.py
index 102f45bfd9..92c62a275d 100644
--- a/tests/st/ops/gpu/test_range_op.py
+++ b/tests/st/ops/gpu/test_range_op.py
@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore.ops import operations as P
 
 class RangeNet(nn.Cell):
-    def __init__(self, maxlen=10000):
+    def __init__(self, maxlen=50):
         super(RangeNet, self).__init__()
         self.range = P.Range(maxlen)
 
@@ -30,6 +30,40 @@ class RangeNet(nn.Cell):
         return self.range(start, limit, delta)
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_range_precision_end_equals_last_element():
+    """Output must match np.arange when start + k * delta lands on the exclusive limit under float32 rounding."""
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    range_net = RangeNet(100)
+    ms_out = range_net(Tensor(1000.04, mstype.float32),
+                       Tensor(1001.04, mstype.float32),
+                       Tensor(0.01, mstype.float32)).asnumpy()
+    np_expected = np.arange(1000.04, 1001.04, 0.01, dtype=np.float32)
+    np.testing.assert_allclose(ms_out, np_expected, rtol=1e-5)
+
+    range_net = RangeNet(1000)
+    ms_out = range_net(Tensor(100, mstype.float32),
+                       Tensor(101, mstype.float32),
+                       Tensor(0.001, mstype.float32)).asnumpy()
+    np_expected = np.arange(100, 101, 0.001, dtype=np.float32)
+    np.testing.assert_allclose(ms_out, np_expected, rtol=1e-5)
+
+    range_net = RangeNet(799900)
+    ms_out = range_net(Tensor(1, mstype.float32),
+                       Tensor(8000, mstype.float32),
+                       Tensor(0.01, mstype.float32)).asnumpy()
+    np_expected = np.arange(1, 8000, 0.01, dtype=np.float32)
+    np.testing.assert_allclose(ms_out, np_expected, rtol=1e-5)
+
+    range_net = RangeNet(53)
+    ms_out = range_net(Tensor(-12000, mstype.float32),
+                       Tensor(-12053, mstype.float32),
+                       Tensor(-1, mstype.float32)).asnumpy()
+    np_expected = np.arange(-12000, -12053, -1, dtype=np.float32)
+    np.testing.assert_allclose(ms_out, np_expected, rtol=1e-5)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -97,7 +131,7 @@ def test_range_invalid_max_output_length():
 @pytest.mark.env_onecard
 def test_range_invalid_input():
     with pytest.raises(RuntimeError) as info:
-        range_net = RangeNet(3500)
+        range_net = RangeNet()
         _ = range_net(Tensor(0, mstype.int32), Tensor(5, mstype.int32), Tensor(0, mstype.int32)).asnumpy()
     assert "delta cannot be equal to zero" in str(info.value)
 
@@ -107,11 +141,11 @@ def test_range_invalid_input():
     assert "number of elements in the output exceeds maxlen" in str(info.value)
 
     with pytest.raises(RuntimeError) as info:
-        range_net = RangeNet(3500)
+        range_net = RangeNet()
         _ = range_net(Tensor(20, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
     assert "delta cannot be positive when limit < start" in str(info.value)
 
     with pytest.raises(RuntimeError) as info:
-        range_net = RangeNet(3500)
+        range_net = RangeNet()
         _ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(-4, mstype.int32)).asnumpy()
     assert "delta cannot be negative when limit > start" in str(info.value)