diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc index 8b9267f820..010c691e75 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc @@ -21,9 +21,10 @@ namespace mindspore { namespace device { namespace ascend { -constexpr uint64_t kAscendDeviceMemGB = 30; +constexpr uint64_t kAscendInitDeviceMemGB = 30; +constexpr uint64_t kAscendMaxDeviceMemGB = 31; constexpr uint64_t kMemSizeGB = 30; -constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB); +constexpr uint64_t kAscendDeviceMemSize = (kAscendInitDeviceMemGB << kMemSizeGB); void AscendMemoryManager::MallocDeviceMemory() { auto context_mem = GetDeviceMemSizeFromContext(); @@ -58,8 +59,8 @@ uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { auto gb_str = variable_memory_max_size.substr(0, pos); auto gb_var = std::stoull(gb_str); MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var; - if (gb_var > kAscendDeviceMemGB || gb_var == 0) { - MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB"; + if (gb_var > kAscendMaxDeviceMemGB || gb_var == 0) { + MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-31]GB"; } return gb_var << kMemSizeGB; } diff --git a/mindspore/context.py b/mindspore/context.py index 8c7bdc5d17..d4914aef5a 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -225,8 +225,8 @@ class _Context: """set values of variable_memory_max_size and graph_memory_max_size""" if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern): raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") - if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: - raise ValueError("Context param variable_memory_max_size should be less than 31GB.") + if int(variable_memory_max_size[:-2]) > _DEVICE_APP_MEMORY_SIZE: + raise ValueError("Context param variable_memory_max_size should be not greater than 31GB.") variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2]) graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024" diff --git a/tests/ut/python/pynative_mode/test_context.py b/tests/ut/python/pynative_mode/test_context.py index 77e71e969b..53d6e97f78 100644 --- a/tests/ut/python/pynative_mode/test_context.py +++ b/tests/ut/python/pynative_mode/test_context.py @@ -115,7 +115,7 @@ def test_variable_memory_max_size(): with pytest.raises(ValueError): context.set_context(variable_memory_max_size="1G") with pytest.raises(ValueError): - context.set_context(variable_memory_max_size="31GB") + context.set_context(variable_memory_max_size="32GB") context.set_context(variable_memory_max_size="3GB")