|
|
|
@ -18,14 +18,12 @@ import os
|
|
|
|
|
import time
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid import core
|
|
|
|
|
from paddle.fluid import io
|
|
|
|
|
from paddle.fluid import Program
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
"load_inference_model", "load_persistable_vars",
|
|
|
|
|
"load_persistables_for_increment", "load_persistables_for_inference",
|
|
|
|
|
"convert_dist_to_sparse_program"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
@ -80,19 +78,28 @@ def __get_prefetch_op_tuples(main_program):
|
|
|
|
|
return prefetch_op_tuples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_dist_to_sparse_program(main_program):
|
|
|
|
|
if not main_program._distributed_lookup_table:
|
|
|
|
|
def convert_dist_to_sparse_program(program):
|
|
|
|
|
"""
|
|
|
|
|
WARNING: this function will only be used for distributed training with distributed lookup table.
|
|
|
|
|
when we train model with distributed lookup table but want to do the local inference, we can use
|
|
|
|
|
this function to convert the train program with distributed lookup table to sparse lookup table.
|
|
|
|
|
|
|
|
|
|
:param program(Program): the trainer program, as obtained from the distribute transpiler.
|
|
|
|
|
:return:
|
|
|
|
|
program: the given Program, with its distributed lookup table replaced by a sparse lookup table.
|
|
|
|
|
"""
|
|
|
|
|
if not program._distributed_lookup_table:
|
|
|
|
|
_logger.warn(
|
|
|
|
|
"There are no distributed lookup tables need to be converted")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# create table param and grad var in pserver program
|
|
|
|
|
origin_emb_var = "{}.origin".format(main_program._distributed_lookup_table)
|
|
|
|
|
emb_var = main_program._distributed_lookup_table
|
|
|
|
|
main_program.global_block()._rename_var(emb_var, origin_emb_var)
|
|
|
|
|
origin_param_var = main_program.global_block().vars[origin_emb_var]
|
|
|
|
|
origin_emb_var = "{}.origin".format(program._distributed_lookup_table)
|
|
|
|
|
emb_var = program._distributed_lookup_table
|
|
|
|
|
program.global_block()._rename_var(emb_var, origin_emb_var)
|
|
|
|
|
origin_param_var = program.global_block().vars[origin_emb_var]
|
|
|
|
|
|
|
|
|
|
param_var = main_program.global_block().create_var(
|
|
|
|
|
param_var = program.global_block().create_var(
|
|
|
|
|
name=emb_var,
|
|
|
|
|
shape=origin_param_var.shape,
|
|
|
|
|
dtype=origin_param_var.dtype,
|
|
|
|
@ -100,28 +107,28 @@ def convert_dist_to_sparse_program(main_program):
|
|
|
|
|
persistable=True)
|
|
|
|
|
# parameter must be selected rows
|
|
|
|
|
param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
|
|
|
|
|
main_program._sync_with_cpp()
|
|
|
|
|
program._sync_with_cpp()
|
|
|
|
|
|
|
|
|
|
prefetch_op_tuples = __get_prefetch_op_tuples(main_program)
|
|
|
|
|
prefetch_op_tuples = __get_prefetch_op_tuples(program)
|
|
|
|
|
|
|
|
|
|
split_ids_id = prefetch_op_tuples[0]
|
|
|
|
|
|
|
|
|
|
for idx in range(split_ids_id + 2, split_ids_id - 1, -1):
|
|
|
|
|
main_program.global_block()._remove_op(idx)
|
|
|
|
|
main_program.desc.flush()
|
|
|
|
|
program.global_block()._remove_op(idx)
|
|
|
|
|
program.desc.flush()
|
|
|
|
|
|
|
|
|
|
in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2])
|
|
|
|
|
|
|
|
|
|
for in_out_pair in in_out_pairs:
|
|
|
|
|
idx = split_ids_id
|
|
|
|
|
ids = main_program.global_block().vars[in_out_pair[0]]
|
|
|
|
|
out = main_program.global_block().vars[in_out_pair[1]]
|
|
|
|
|
__insert_lookup_sparse_table_op(main_program, idx, ids, param_var, out)
|
|
|
|
|
main_program.desc.flush()
|
|
|
|
|
return main_program
|
|
|
|
|
ids = program.global_block().vars[in_out_pair[0]]
|
|
|
|
|
out = program.global_block().vars[in_out_pair[1]]
|
|
|
|
|
__insert_lookup_sparse_table_op(program, idx, ids, param_var, out)
|
|
|
|
|
program.desc.flush()
|
|
|
|
|
return program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_persistable_vars(executor, dirname, program, lookup_table_var):
|
|
|
|
|
def _load_persistable_vars(executor, dirname, program, lookup_table_vars):
|
|
|
|
|
def _is_checkpoint_var(exclude_fluid_vars=None):
|
|
|
|
|
"""
|
|
|
|
|
the checkpoint will not save or load all the variables.
|
|
|
|
@ -159,7 +166,81 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var):
|
|
|
|
|
|
|
|
|
|
return is_valid
|
|
|
|
|
|
|
|
|
|
def _load_lookup_table_vars(executor, dirname, main_program,
|
|
|
|
|
io.load_vars(
|
|
|
|
|
executor,
|
|
|
|
|
dirname=dirname,
|
|
|
|
|
main_program=program,
|
|
|
|
|
predicate=_is_checkpoint_var(lookup_table_vars),
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_persistables_for_increment(dirname, executor, program,
                                    lookup_table_var, lookup_table_var_path):
    """
    WARNING: this function is only meant for distributed training with a
    distributed lookup table.

    For incremental training, the pserver loads not only the dense variables
    but also the matching lookup table var. Because the lookup table var is
    sliced with HASH, the correct slice must be loaded.

    :param dirname(str): The directory path
    :param executor(Executor): The executor used to run the load operations.
    :param program(Program): The parameter server program, which will run on Pserver.
    :param lookup_table_var: the distributed lookup table var name.
    :param lookup_table_var_path: the distributed lookup table var location.
    :return: None
    :raises ValueError: if `dirname` is not a directory, if
        `lookup_table_var_path` does not exist, or if `program` is not a
        `Program` instance.
    """

    def _load_lookup_table_slice(executor, main_program, lookup_table_var,
                                 lookup_table_var_path):
        # Run a throw-away program holding a single `load` op whose output is
        # the lookup table variable of `main_program`.
        emb_var = main_program.global_block().var(lookup_table_var)

        load_program = Program()
        load_block = load_program.global_block()
        load_block.append_op(
            type='load',
            inputs={},
            outputs={'Out': [emb_var]},
            attrs={'file_path': lookup_table_var_path})
        executor.run(load_program)

    # Validate the inputs before anything is loaded. The messages are
    # formatted with `%` so the offending path shows up in the text itself
    # rather than as a second exception argument.
    if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'" % dirname)

    if not os.path.exists(lookup_table_var_path):
        raise ValueError(
            "There is no file named '%s'" % lookup_table_var_path)

    if not isinstance(program, Program):
        raise ValueError("program must be an instance of fluid.Program")

    _logger.info("Start Load Sparse Program With "
                 "Distributed Lookup Table Vars from {}, time = {}".format(
                     dirname, time.ctime()))

    # The lookup table name is handed to the generic loader separately
    # (presumably so it is excluded there -- confirm against
    # `_load_persistable_vars`), then loaded from its own file.
    _load_persistable_vars(executor, dirname, program, [lookup_table_var])
    _load_lookup_table_slice(executor, program, lookup_table_var,
                             lookup_table_var_path)

    _logger.info("Finish Load Sparse Program With "
                 "Distributed Lookup Table Vars from {}, time = {}".format(
                     dirname, time.ctime()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_persistables_for_inference(dirname, executor, program,
|
|
|
|
|
lookup_table_var_name):
|
|
|
|
|
"""
|
|
|
|
|
WARNING: this function will only be used for inference with distributed lookup table.
|
|
|
|
|
Inference with distributed lookup table is a little funky, this function will load distributed
|
|
|
|
|
lookup table vars into sparse var, can be used in local inference mode.
|
|
|
|
|
|
|
|
|
|
:param dirname(str): The directory path
|
|
|
|
|
:param executor(Executor): The executor to run for loading inference model.
|
|
|
|
|
:param program(Program): The parameter server program, which will run on Pserver.
|
|
|
|
|
:param lookup_table_var_name: the distributed lookup tables var name.
|
|
|
|
|
:return: None
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __load_lookup_table_vars(executor, dirname, main_program,
|
|
|
|
|
lookup_table_vars):
|
|
|
|
|
if not os.path.isdir(dirname):
|
|
|
|
|
raise ValueError("There is no directory named '%s'", dirname)
|
|
|
|
@ -209,30 +290,13 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var):
|
|
|
|
|
global_block.append_op(type='delete_var', inputs={'X': sums})
|
|
|
|
|
executor.run(convert_program)
|
|
|
|
|
|
|
|
|
|
_logger.info("Start Load Sparse Program With "
|
|
|
|
|
"Distributed Lookup Table Vars from {}, time = {}".format(
|
|
|
|
|
dirname, time.ctime()))
|
|
|
|
|
|
|
|
|
|
lookup_table_vars = [lookup_table_var]
|
|
|
|
|
|
|
|
|
|
io.load_vars(
|
|
|
|
|
executor,
|
|
|
|
|
dirname=dirname,
|
|
|
|
|
main_program=program,
|
|
|
|
|
predicate=_is_checkpoint_var(lookup_table_vars),
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
_load_lookup_table_vars(executor, dirname, program, lookup_table_vars)
|
|
|
|
|
|
|
|
|
|
_logger.info("Finish Load Sparse Program With "
|
|
|
|
|
"Distributed Lookup Table Vars from {}, time = {}".format(
|
|
|
|
|
dirname, time.ctime()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_inference_model(dirname, executor, lookup_table_var_name):
|
|
|
|
|
if not os.path.isdir(dirname):
|
|
|
|
|
raise ValueError("There is no directory named '%s'", dirname)
|
|
|
|
|
|
|
|
|
|
if program:
|
|
|
|
|
if not isinstance(program, Program):
|
|
|
|
|
raise ValueError("program must be an instance of fluid.Program")
|
|
|
|
|
else:
|
|
|
|
|
local_model = os.path.join(dirname, model_filename)
|
|
|
|
|
|
|
|
|
|
with open(local_model, "rb") as f:
|
|
|
|
@ -244,13 +308,16 @@ def load_inference_model(dirname, executor, lookup_table_var_name):
|
|
|
|
|
raise ValueError("Unsupported program version: %d\n" %
|
|
|
|
|
program._version())
|
|
|
|
|
|
|
|
|
|
# Binary data also need version.
|
|
|
|
|
load_persistable_vars(executor, dirname, program, lookup_table_var_name)
|
|
|
|
|
_logger.info("Start Load Sparse Program With "
|
|
|
|
|
"Distributed Lookup Table Vars from {}, time = {}".format(
|
|
|
|
|
dirname, time.ctime()))
|
|
|
|
|
|
|
|
|
|
feed_target_names = program.desc.get_feed_target_names()
|
|
|
|
|
fetch_target_names = program.desc.get_fetch_target_names()
|
|
|
|
|
fetch_targets = [
|
|
|
|
|
program.global_block().var(name) for name in fetch_target_names
|
|
|
|
|
]
|
|
|
|
|
_load_persistable_vars(executor, dirname, program, [lookup_table_var_name])
|
|
|
|
|
__load_lookup_table_vars(executor, dirname, program,
|
|
|
|
|
[lookup_table_var_name])
|
|
|
|
|
|
|
|
|
|
_logger.info("Finish Load Sparse Program With "
|
|
|
|
|
"Distributed Lookup Table Vars from {}, time = {}".format(
|
|
|
|
|
dirname, time.ctime()))
|
|
|
|
|
|
|
|
|
|
return [program, feed_target_names, fetch_targets]
|
|
|
|
|
return program
|
|
|
|
|