diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index d4a4e015a7..7ff7323deb 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -47,6 +47,16 @@ _ckpt_mutex = Lock() SLICE_SIZE = 512 * 1024 * 1024 +def _set_pb_env(): + """Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" + if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": + logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\ + When the parameter is too large, it may cause memory limit error.") + else: + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + logger.debug("Set the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.") + + def _special_process_par(par, new_par): """ Processes the special condition. @@ -787,3 +797,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) return merged_parameter + + +_set_pb_env()