|
|
|
@ -38,6 +38,19 @@ default_envs = {
|
|
|
|
|
# NOTE(review): not referenced in the visible part of this file; presumably the
# default number of GPUs per node -- confirm against its users.
GPUS = 8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_gpu_ids(gpus):
    """Return the GPU device ids to use for ``gpus`` worker processes.

    When the ``CUDA_VISIBLE_DEVICES`` environment variable is set and
    non-empty, its comma-separated entries are parsed as integers and the
    first ``gpus`` of them are returned.  Otherwise the ids
    ``0 .. gpus - 1`` are returned.

    Args:
        gpus (int): number of GPU ids requested.

    Returns:
        list[int]: the selected device ids.

    Raises:
        EnvironmentError: if ``CUDA_VISIBLE_DEVICES`` lists fewer devices
            than ``gpus``.
        ValueError: if an entry of ``CUDA_VISIBLE_DEVICES`` is not an
            integer.
    """
    # Read the variable once so the check and the parse see the same value.
    visible = os.getenv("CUDA_VISIBLE_DEVICES")
    if not visible:
        # Unset or empty: fall back to the first ``gpus`` ordinal ids.
        return list(range(gpus))

    ids = [int(i) for i in visible.split(",")]
    if gpus > len(ids):
        # More GPUs were requested than the environment exposes.
        raise EnvironmentError(
            "The count of env CUDA_VISIBLE_DEVICES (%d) should not be less "
            "than the passed gpus: %s" % (len(ids), gpus))
    return ids[:gpus]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def start_procs(gpus, entrypoint, entrypoint_args, log_dir):
|
|
|
|
|
procs = []
|
|
|
|
|
log_fns = []
|
|
|
|
@ -61,8 +74,8 @@ def start_procs(gpus, entrypoint, entrypoint_args, log_dir):
|
|
|
|
|
all_nodes_devices_endpoints += "%s:617%d" % (n, i)
|
|
|
|
|
nranks = num_nodes * gpus
|
|
|
|
|
# ======== for dist training =======
|
|
|
|
|
|
|
|
|
|
for i in range(gpus):
|
|
|
|
|
gpu_ids = get_gpu_ids(gpus)
|
|
|
|
|
for i in gpu_ids:
|
|
|
|
|
curr_env = {}
|
|
|
|
|
curr_env.update(default_envs)
|
|
|
|
|
curr_env.update({
|
|
|
|
|