|
|
|
@ -49,17 +49,6 @@ _ckpt_mutex = Lock()
|
|
|
|
|
SLICE_SIZE = 512 * 1024 * 1024
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _set_pb_env():
|
|
|
|
|
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow."""
|
|
|
|
|
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
|
|
|
|
|
logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\
|
|
|
|
|
When the parameter is too large, it may cause memory limit error.\
|
|
|
|
|
This can be solved by set env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.")
|
|
|
|
|
else:
|
|
|
|
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
|
|
|
|
logger.debug("Set the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _special_process_par(par, new_par):
|
|
|
|
|
"""
|
|
|
|
|
Processes the special condition.
|
|
|
|
@ -892,6 +881,3 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
|
|
|
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
|
|
|
|
|
|
|
|
|
|
return merged_parameter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_set_pb_env()
|
|
|
|
|