diff --git a/mindspore/_version_check.py b/mindspore/_version_check.py index 2f435ce33e..5d42a5c78f 100644 --- a/mindspore/_version_check.py +++ b/mindspore/_version_check.py @@ -280,4 +280,17 @@ def check_version_and_env_config(): except ImportError as e: env_checker.check_env(e) + +def _set_pb_env(): + """Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" + if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": + logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\ + When the parameter is too large, it may cause memory limit error.\ + This can be solved by set env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.") + elif os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "": + logger.warning("Set the env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python` to prevent memory overflow.") + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + + check_version_and_env_config() +_set_pb_env() diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index c4d35b4c80..86f014f5da 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -49,17 +49,6 @@ _ckpt_mutex = Lock() SLICE_SIZE = 512 * 1024 * 1024 -def _set_pb_env(): - """Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" - if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": - logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\ - When the parameter is too large, it may cause memory limit error.\ - This can be solved by set env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.") - else: - os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" - logger.debug("Set the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.") - - def _special_process_par(par, new_par): """ Processes the special condition. @@ -885,6 +874,3 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) return merged_parameter - - -_set_pb_env()