From 051eaa5fc7551583d5f29fe6092becf1874d80ce Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Tue, 14 Aug 2018 10:14:55 +0800
Subject: [PATCH] add distributed attrs

---
 python/paddle/fluid/io.py                          | 14 ++++++++------
 .../fluid/transpiler/distribute_transpiler.py      | 12 +++++++++---
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 44f7f12b9c..87c91475ba 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -733,6 +733,9 @@ def load_inference_model(dirname,
     if not os.path.isdir(dirname):
         raise ValueError("There is no directory named '%s'", dirname)
 
+    if training_role == "PSERVER":
+        _load_lookup_table_vars(executor, dirname, program, role_id)
+
     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
     else:
@@ -749,10 +752,7 @@ def load_inference_model(dirname,
     load_persistables(executor, dirname, program, params_filename)
 
     if pserver_endpoints:
-        _endpoints_replacement(program, pserver_endpoints)
-
-    if training_role == "PSERVER":
-        _load_lookup_table_vars(executor, dirname, program, role_id)
+        program = _endpoints_replacement(program, pserver_endpoints)
 
     feed_target_names = program.desc.get_feed_target_names()
     fetch_target_names = program.desc.get_fetch_target_names()
@@ -871,8 +871,10 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id):
 def _endpoints_replacement(program, endpoints):
     ENDPOINT_MAP = "epmap"
     for op in program.global_block().ops:
-        if op.attrs.has_key(ENDPOINT_MAP):
-            op.attrs[ENDPOINT_MAP] = endpoints
+        if op.has_attr(ENDPOINT_MAP):
+            op.set_attr(ENDPOINT_MAP, endpoints)
+    program = program.clone()
+    return program
 
 
 def get_parameter_value(para, executor):
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 15675b4e9f..9897837ae5 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -501,6 +501,8 @@ class DistributeTranspiler(object):
             checkpoint_block_id = self._create_checkpoint_save_block(
                 pserver_program, table_opt_block.idx)
 
+            pserver_program._distributed_lookup_table = self.table_name
+
         # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
         # not be executed, so it's safe to use optimize_block to hold the place
         if self.has_distributed_lookup_table:
@@ -527,9 +529,13 @@ class DistributeTranspiler(object):
             outputs={},
             attrs=attrs)
 
-        # add slice vars
-        slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint)
-        pserver_program._slice_vars_and_atts = slice_vars_and_atts
+        # add distributed attrs
+        pserver_program._slice_vars_and_atts = self._get_slice_vars_and_atts(
+            endpoint)
+        pserver_program._is_distributed = True
+        pserver_program._endpoints = self.pserver_endpoints
+        pserver_program._is_chief = self.trainer_id == 0
+        pserver_program._distributed_lookup_table = self.table_name if self.table_name else None
 
         pserver_program._sync_with_cpp()
         return pserver_program
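
Note (reviewer example, not part of the patch): the io.py hunks imply the call
pattern sketched below. It assumes load_inference_model also gained
training_role and role_id parameters in this change (both are referenced in the
hunks but declared outside them); the save location and endpoint list are
illustrative.

    # Hedged sketch of driving the patched load_inference_model on a
    # parameter server; training_role/role_id are assumed parameters.
    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]  # illustrative

    # On a PSERVER, the patch loads the per-role lookup-table shards and,
    # when pserver_endpoints is given, re-points every op carrying an
    # "epmap" attribute at `endpoints` (see _endpoints_replacement above).
    program, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname="./inference_model",  # assumed save location
        executor=exe,
        pserver_endpoints=endpoints,
        training_role="PSERVER",
        role_id=0)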
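
Note (reviewer example, not part of the patch): a sketch of how the attrs
attached in distribute_transpiler.py are produced and can be read back. The
transpiler calls are standard Fluid API of this era; building a network and
optimizer on the default main program beforehand is assumed and elided.

    # Hedged sketch: get a pserver program from the transpiler and inspect
    # the private attrs this patch attaches to it.
    import paddle.fluid as fluid

    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id=0,
        pservers="127.0.0.1:6174,127.0.0.1:6175",  # illustrative
        trainers=2)
    pserver_program = t.get_pserver_program("127.0.0.1:6174")

    print(pserver_program._is_distributed)            # True
    print(pserver_program._endpoints)                 # all pserver endpoints
    print(pserver_program._is_chief)                  # trainer_id == 0
    print(pserver_program._distributed_lookup_table)  # table name or None
    print(pserver_program._slice_vars_and_atts)       # per-endpoint slice info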