|
|
@ -851,15 +851,26 @@ class TheOnePSRuntime(RuntimeBase):
|
|
|
|
|
|
|
|
|
|
|
|
return is_valid
|
|
|
|
return is_valid
|
|
|
|
|
|
|
|
|
|
|
|
def _save_sparse_params(self, executor, dirname, context, main_program):
|
|
|
|
def _save_sparse_params(self, executor, dirname, context, main_program,
|
|
|
|
|
|
|
|
mode):
|
|
|
|
|
|
|
|
from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
|
|
|
|
|
|
|
|
distributed_varnames = get_sparse_tablenames(
|
|
|
|
|
|
|
|
self.compiled_strategy.origin_main_program, True)
|
|
|
|
values = []
|
|
|
|
values = []
|
|
|
|
for id, names in context.items():
|
|
|
|
for id, names in context.items():
|
|
|
|
|
|
|
|
if names not in distributed_varnames:
|
|
|
|
|
|
|
|
# only save sparse param to local
|
|
|
|
|
|
|
|
self._worker.recv_and_save_model(id, dirname)
|
|
|
|
|
|
|
|
# save sparse & distributed param on server
|
|
|
|
|
|
|
|
self._worker.save_one_model(id, dirname, mode)
|
|
|
|
values.extend(names)
|
|
|
|
values.extend(names)
|
|
|
|
self._worker.save_one_model(id, dirname, 0)
|
|
|
|
|
|
|
|
return values
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
|
|
def _save_distributed_persistables(self, executor, dirname, main_program,
|
|
|
|
def _save_distributed_persistables(self,
|
|
|
|
mode):
|
|
|
|
executor,
|
|
|
|
|
|
|
|
dirname,
|
|
|
|
|
|
|
|
main_program,
|
|
|
|
|
|
|
|
mode=0):
|
|
|
|
|
|
|
|
|
|
|
|
denses = self.compiled_strategy.get_the_one_recv_context(
|
|
|
|
denses = self.compiled_strategy.get_the_one_recv_context(
|
|
|
|
is_dense=True,
|
|
|
|
is_dense=True,
|
|
|
@ -870,14 +881,14 @@ class TheOnePSRuntime(RuntimeBase):
|
|
|
|
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
|
|
|
|
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
|
|
|
|
use_origin_program=True)
|
|
|
|
use_origin_program=True)
|
|
|
|
|
|
|
|
|
|
|
|
recv_sparse_varnames = self._save_sparse_params(executor, dirname,
|
|
|
|
sparse_varnames = self._save_sparse_params(executor, dirname, sparses,
|
|
|
|
sparses, main_program)
|
|
|
|
main_program, mode)
|
|
|
|
|
|
|
|
|
|
|
|
recv_dense_varnames = []
|
|
|
|
recv_dense_varnames = []
|
|
|
|
for id, names in denses.items():
|
|
|
|
for id, names in denses.items():
|
|
|
|
recv_dense_varnames.extend(names)
|
|
|
|
recv_dense_varnames.extend(names)
|
|
|
|
|
|
|
|
|
|
|
|
saved_varnames = recv_sparse_varnames
|
|
|
|
saved_varnames = sparse_varnames
|
|
|
|
|
|
|
|
|
|
|
|
remaining_vars = list(
|
|
|
|
remaining_vars = list(
|
|
|
|
filter(
|
|
|
|
filter(
|
|
|
@ -925,6 +936,7 @@ class TheOnePSRuntime(RuntimeBase):
|
|
|
|
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
|
|
|
|
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Todo(MrChengmo): Save optimizer status
|
|
|
|
self._save_distributed_persistables(executor, dirname, main_program,
|
|
|
|
self._save_distributed_persistables(executor, dirname, main_program,
|
|
|
|
mode)
|
|
|
|
mode)
|
|
|
|
|
|
|
|
|
|
|
@ -971,8 +983,7 @@ class TheOnePSRuntime(RuntimeBase):
|
|
|
|
|
|
|
|
|
|
|
|
program = Program.parse_from_string(program_desc_str)
|
|
|
|
program = Program.parse_from_string(program_desc_str)
|
|
|
|
program._copy_dist_param_info_from(fluid.default_main_program())
|
|
|
|
program._copy_dist_param_info_from(fluid.default_main_program())
|
|
|
|
self._ps_inference_save_persistables(
|
|
|
|
self._ps_inference_save_persistables(executor, dirname, program)
|
|
|
|
executor, dirname, program, mode=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_inference_model(self, *args, **kwargs):
|
|
|
|
def _save_inference_model(self, *args, **kwargs):
|
|
|
|
self._ps_inference_save_inference_model(*args, **kwargs)
|
|
|
|
self._ps_inference_save_inference_model(*args, **kwargs)
|
|
|
|