|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import math
|
|
|
|
|
|
|
|
|
|
import distributed_splitter as splitter
|
|
|
|
@ -26,6 +27,10 @@ LOOKUP_TABLE_TYPE = "lookup_table"
|
|
|
|
|
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
|
|
|
|
|
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
|
|
|
|
|
|
|
|
|
|
# for checkpoint
|
|
|
|
|
SUCCESS = "_SUCCESS"
|
|
|
|
|
SERIAL_VAR_NAME = "SERIAL_NUMBER"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VarBlock:
|
|
|
|
|
def __init__(self, varname, offset, size):
|
|
|
|
@ -153,7 +158,8 @@ class DistributeTranspiler:
|
|
|
|
|
pservers="127.0.0.1:6174",
|
|
|
|
|
trainers=1,
|
|
|
|
|
split_method=splitter.round_robin,
|
|
|
|
|
sync_mode=True):
|
|
|
|
|
sync_mode=True,
|
|
|
|
|
checkpoint_dir=None):
|
|
|
|
|
"""
|
|
|
|
|
Transpile the program to distributed data-parallelism programs.
|
|
|
|
|
The main_program will be transformed to use a remote parameter server
|
|
|
|
@ -315,22 +321,22 @@ class DistributeTranspiler:
|
|
|
|
|
"sync_mode": self.sync_mode
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
serial_var = program.global_block().create_var(
|
|
|
|
|
name="SERIAL_NUMBER",
|
|
|
|
|
persistable=True,
|
|
|
|
|
type=core.VarDesc.VarType.RAW)
|
|
|
|
|
if checkpoint_dir and self.is_chief:
|
|
|
|
|
program.global_block().create_var(
|
|
|
|
|
name=SERIAL_VAR_NAME,
|
|
|
|
|
persistable=True,
|
|
|
|
|
type=core.VarDesc.VarType.RAW)
|
|
|
|
|
|
|
|
|
|
save_vars = []
|
|
|
|
|
for var in self.origin_program.list_vars():
|
|
|
|
|
if self.is_persistable(var):
|
|
|
|
|
save_vars.append(var.name)
|
|
|
|
|
save_vars = []
|
|
|
|
|
for var in self.origin_program.list_vars():
|
|
|
|
|
if self._is_persistable(var):
|
|
|
|
|
save_vars.append(var.name)
|
|
|
|
|
|
|
|
|
|
program.global_block().append_op(
|
|
|
|
|
type="checkpoint_save",
|
|
|
|
|
inputs={"X": save_vars},
|
|
|
|
|
outputs={"Serial": serial_var},
|
|
|
|
|
attrs={"overwrite": False,
|
|
|
|
|
"dir": "/workspace/ckpt/"})
|
|
|
|
|
program.global_block().append_op(
|
|
|
|
|
type="checkpoint_save",
|
|
|
|
|
inputs={"X": save_vars},
|
|
|
|
|
attrs={"overwrite": True,
|
|
|
|
|
"dir": checkpoint_dir})
|
|
|
|
|
|
|
|
|
|
# step4: Concat the parameters splits together after recv.
|
|
|
|
|
for varname, splited_var in param_var_mapping.iteritems():
|
|
|
|
@ -512,13 +518,6 @@ class DistributeTranspiler:
|
|
|
|
|
pserver_program.sync_with_cpp()
|
|
|
|
|
return pserver_program
|
|
|
|
|
|
|
|
|
|
def is_persistable(self, var):
|
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.RAW :
|
|
|
|
|
return False
|
|
|
|
|
return var.persistable
|
|
|
|
|
|
|
|
|
|
def get_train_startup_program(self, checkpoint_load_dir=None):
|
|
|
|
|
"""
|
|
|
|
|
Get train startup program.
|
|
|
|
@ -532,13 +531,16 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
load_vars = []
|
|
|
|
|
for var in startup_prog.list_vars():
|
|
|
|
|
if self.is_persistable(var):
|
|
|
|
|
if self._is_persistable(var):
|
|
|
|
|
load_vars.append(var.name)
|
|
|
|
|
|
|
|
|
|
serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
|
|
|
|
|
|
|
|
|
|
startup_prog.global_block().append_op(
|
|
|
|
|
type="checkpoint_load",
|
|
|
|
|
outputs={"Out": load_vars},
|
|
|
|
|
attrs={"dir": checkpoint_load_dir})
|
|
|
|
|
inputs={"X": load_vars},
|
|
|
|
|
attrs={"dir": checkpoint_load_dir,
|
|
|
|
|
"Serial": serial_number})
|
|
|
|
|
return startup_prog
|
|
|
|
|
|
|
|
|
|
def get_startup_program(self,
|
|
|
|
@ -603,12 +605,55 @@ class DistributeTranspiler:
|
|
|
|
|
if not checkpoint_load_dir:
|
|
|
|
|
return s_prog
|
|
|
|
|
|
|
|
|
|
serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
|
|
|
|
|
|
|
|
|
|
s_prog.global_block().append_op(
|
|
|
|
|
type="checkpoint_load",
|
|
|
|
|
inputs={"X": load_vars},
|
|
|
|
|
attrs={"dir": checkpoint_load_dir})
|
|
|
|
|
attrs={"dir": checkpoint_load_dir,
|
|
|
|
|
"Serial": serial_number})
|
|
|
|
|
|
|
|
|
|
return s_prog
|
|
|
|
|
|
|
|
|
|
def _is_persistable(self, var):
|
|
|
|
|
"""only save LodTensor variable"""
|
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.RAW :
|
|
|
|
|
return False
|
|
|
|
|
return var.persistable
|
|
|
|
|
|
|
|
|
|
def _get_lastest_checkpoint_dir(self, checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
get the biggest number in checkpoint_dir, which has _SUCCESS
|
|
|
|
|
"""
|
|
|
|
|
if not checkpoint_dir.strip():
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
def has_success(checkpoint_dir, cur_dir):
|
|
|
|
|
"""
|
|
|
|
|
is _SUCCESS in this dir
|
|
|
|
|
"""
|
|
|
|
|
if not os.path.isdir(cur_dir):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
int(cur_dir)
|
|
|
|
|
except ValueError:
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
|
|
|
|
|
if os.path.isfile(success_path):
|
|
|
|
|
return int(cur_dir)
|
|
|
|
|
|
|
|
|
|
current_dir = 0
|
|
|
|
|
dirs = os.listdir(checkpoint_dir)
|
|
|
|
|
for cur_dir in dirs:
|
|
|
|
|
success_num = has_success(checkpoint_dir, cur_dir)
|
|
|
|
|
if success_num > current_dir:
|
|
|
|
|
current_dir = success_num
|
|
|
|
|
return str(current_dir)
|
|
|
|
|
|
|
|
|
|
# transpiler function for dis lookup_table
|
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
|
|
|
|
|
eplist):
|
|
|
|
|