|
|
|
@ -47,6 +47,16 @@ _ckpt_mutex = Lock()
|
|
|
|
|
# 512 * 1024 * 1024 = 536870912 (i.e. 512 MiB if the unit is bytes -- TODO confirm the unit at the use sites).
SLICE_SIZE = 512 * 1024 * 1024
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _set_pb_env():
    """
    Set env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION` to prevent memory overflow.

    If the variable is already set to "cpp", it is left unchanged and a warning is logged.
    In every other case (unset, or set to any other value) it is set to "python" and a
    debug message is logged.
    """
    if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
        # Respect an explicit user choice of the cpp backend; only warn about the risk.
        # Adjacent string literals are used instead of a backslash continuation inside the
        # literal, which would embed the next source line's indentation in the message.
        logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`. "
                       "When the parameter is too large, it may cause memory limit error.")
    else:
        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
        logger.debug("Set the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _special_process_par(par, new_par):
|
|
|
|
|
"""
|
|
|
|
|
Processes the special condition.
|
|
|
|
@ -785,3 +795,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
|
|
|
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
|
|
|
|
|
|
|
|
|
|
return merged_parameter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Configure `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION` via `_set_pb_env`.
# NOTE(review): this looks like a module-level call, so it would run once on import -- confirm.
_set_pb_env()
|
|
|
|
|