|
|
|
@ -733,6 +733,9 @@ def load_inference_model(dirname,
|
|
|
|
|
if not os.path.isdir(dirname):
|
|
|
|
|
raise ValueError("There is no directory named '%s'", dirname)
|
|
|
|
|
|
|
|
|
|
if training_role == "PSERVER":
|
|
|
|
|
_load_lookup_table_vars(executor, dirname, program, role_id)
|
|
|
|
|
|
|
|
|
|
if model_filename is not None:
|
|
|
|
|
model_filename = os.path.basename(model_filename)
|
|
|
|
|
else:
|
|
|
|
@ -749,10 +752,7 @@ def load_inference_model(dirname,
|
|
|
|
|
load_persistables(executor, dirname, program, params_filename)
|
|
|
|
|
|
|
|
|
|
if pserver_endpoints:
|
|
|
|
|
_endpoints_replacement(program, pserver_endpoints)
|
|
|
|
|
|
|
|
|
|
if training_role == "PSERVER":
|
|
|
|
|
_load_lookup_table_vars(executor, dirname, program, role_id)
|
|
|
|
|
program = _endpoints_replacement(program, pserver_endpoints)
|
|
|
|
|
|
|
|
|
|
feed_target_names = program.desc.get_feed_target_names()
|
|
|
|
|
fetch_target_names = program.desc.get_fetch_target_names()
|
|
|
|
@ -871,8 +871,10 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id):
|
|
|
|
|
def _endpoints_replacement(program, endpoints):
|
|
|
|
|
ENDPOINT_MAP = "epmap"
|
|
|
|
|
for op in program.global_block().ops:
|
|
|
|
|
if op.attrs.has_key(ENDPOINT_MAP):
|
|
|
|
|
op.attrs[ENDPOINT_MAP] = endpoints
|
|
|
|
|
if op.has_attr(ENDPOINT_MAP):
|
|
|
|
|
op.set_attr(ENDPOINT_MAP, endpoints)
|
|
|
|
|
program = program.clone()
|
|
|
|
|
return program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_parameter_value(para, executor):
|
|
|
|
|