Fix save and load lookup table/optimizer vars ()

*  fix mkdir conflict

*  fix load/save lookup tables

 test=develop

* add lookup_table_utils

* fix load optimizer vars on pserver

* delete lookup table utils

* fix save and load lookup tables

* fix load optimizer var

* fix load optimizer var, test=develop

* fix python 3 style, test=develop

* move lookup_table_utils to contrib utils
tangwei12 7 years ago committed by GitHub
parent 2fc32b17a2
commit 3639d99f99

@@ -67,6 +67,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
                       framework::proto::VarType::FP32,
                       "The sparse table only support FP32");
     w_t->Get(ids_t, out_t, true, is_test);
+    out_t->set_lod(ids_t.lod());
   }
 };
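The added `out_t->set_lod(ids_t.lod())` propagates the LoD (sequence boundaries) of the input ids to the looked-up output. A minimal Python sketch of what preserving LoD means, using `fluid.core.LoDTensor` directly; the values, shapes, and offset-based LoD format are illustrative assumptions, not part of this commit:

```python
import numpy as np
import paddle.fluid as fluid

# An ids LoDTensor with two sequences of lengths 2 and 3 (offset-based LoD).
ids = fluid.core.LoDTensor()
ids.set(np.array([[1], [3], [2], [5], [4]], dtype="int64"), fluid.CPUPlace())
ids.set_lod([[0, 2, 5]])

# After a sparse lookup the output rows correspond 1:1 to the input ids,
# so the output should carry the same LoD to keep the sequence structure.
out = fluid.core.LoDTensor()
out.set(np.random.rand(5, 8).astype("float32"), fluid.CPUPlace())
out.set_lod(ids.lod())

assert out.lod() == ids.lod()
```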

@@ -127,6 +127,9 @@ class SumKernel : public framework::OpKernel<T> {
       math::scatter::MergeAdd<DeviceContext, T> merge_add;
       merge_add(context.template device_context<DeviceContext>(), inputs,
                 out);
+      out->SyncIndex();
     } else {
       // no data, just set a empty out tensor.
       out->mutable_value()->mutable_data<T>(framework::make_ddim({0}),
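`SyncIndex()` rebuilds the id-to-row index of the merged `SelectedRows` output so later lookups by row id hit the freshly merged rows. A conceptual Python sketch of that invariant using a plain data structure (`SelectedRowsLike` is hypothetical, not the Paddle API):

```python
class SelectedRowsLike:
    """Sparse rows identified by ids, plus a reverse index id -> position."""

    def __init__(self, rows, value):
        self.rows = rows            # e.g. [7, 2, 9]
        self.value = value          # value[i] stores the data for id rows[i]
        self.id_to_index = {}

    def sync_index(self):
        # Rebuild the reverse index after rows/value were rewritten by a merge.
        self.id_to_index = {row_id: i for i, row_id in enumerate(self.rows)}

    def get(self, row_id):
        return self.value[self.id_to_index[row_id]]


merged = SelectedRowsLike(rows=[7, 2, 9], value=[[0.1], [0.2], [0.3]])
merged.sync_index()                 # without this, get() would use a stale index
assert merged.get(2) == [0.2]
```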

@@ -13,8 +13,10 @@
 # limitations under the License.

 from __future__ import print_function
+from . import lookup_table_utils
+from .lookup_table_utils import *
 from . import hdfs_utils
 from .hdfs_utils import *

+__all__ = lookup_table_utils.__all__
 __all__ = hdfs_utils.__all__
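For reference, a minimal sketch of this re-export pattern with a hypothetical package `my_utils` and submodules `foo_utils`/`bar_utils` (not part of this commit); note that re-assigning `__all__` for each submodule keeps only the last list, so the export lists are usually concatenated:

```python
# my_utils/__init__.py  (hypothetical package, for illustration only)
from __future__ import print_function

from . import foo_utils
from .foo_utils import *
from . import bar_utils
from .bar_utils import *

# Concatenate rather than re-assign, so `from my_utils import *`
# exposes the public names of both submodules.
__all__ = foo_utils.__all__ + bar_utils.__all__
```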

(File diff suppressed because it is too large.)

@@ -1698,6 +1698,7 @@ class Program(object):
         p._copy_param_info_from(self)
         p._copy_data_info_from(self)
+        p._copy_dist_param_info_from(self)
         return p

     def _prune(self, targets):
@@ -1938,6 +1939,25 @@ class Program(object):
                 "program, with represent the same topology")
         self.global_block()._copy_param_info_from(other.global_block())

+    def _copy_dist_param_info_from(self, other):
+        """
+        Copy the information of distributed information from other program.
+
+        Args:
+            other(Program): Other program
+
+        Returns:
+            None
+        """
+        if not isinstance(other, Program):
+            raise TypeError("_copy_dist_param_info_from should be invoked with "
+                            "Program")
+        self._is_distributed = other._is_distributed
+        self._is_chief = other._is_chief
+        self._slice_vars_and_attrs = other._slice_vars_and_attrs
+        self._endpoints = other._endpoints
+        self._distributed_lookup_table = other._distributed_lookup_table
+
     def _copy_data_info_from(self, other):
         """
         Copy the information of data variables from other program.
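With `clone()` now calling `_copy_dist_param_info_from`, a cloned program still knows whether it is distributed, whether this trainer is chief, the pserver endpoints, and the name of the distributed lookup table. A hedged sketch of the effect; the attribute values below are placeholders normally filled in by `DistributeTranspiler`, set by hand here only for illustration:

```python
import paddle.fluid as fluid

main_prog = fluid.Program()
# Placeholders for what the transpiler would normally set.
main_prog._is_distributed = True
main_prog._is_chief = True                       # trainer 0
main_prog._endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
main_prog._distributed_lookup_table = "emb_table"

# After this change, clone() carries the distributed attributes over, so
# save/load helpers on the clone can still find the lookup table and pservers.
cloned = main_prog.clone()
assert cloned._distributed_lookup_table == "emb_table"
assert cloned._endpoints == main_prog._endpoints
```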

@@ -165,6 +165,7 @@ def save_vars(executor,
         save_vars(
             executor,
+            main_program=main_program,
             dirname=dirname,
             vars=list(filter(predicate, main_program.list_vars())),
             filename=filename)
@@ -172,11 +173,18 @@ def save_vars(executor,
         save_program = Program()
         save_block = save_program.global_block()

+        if main_program is None:
+            main_program = default_main_program()
+        if not isinstance(main_program, Program):
+            raise TypeError("program should be as Program type or None")
+
         save_var_map = {}
         for each_var in vars:
             # NOTE: don't save the variable which type is RAW
             if each_var.type == core.VarDesc.VarType.RAW:
                 continue
+            if each_var.name == main_program._distributed_lookup_table:
+                continue
             new_var = _clone_var_in_block_(save_block, each_var)
             if filename is None:
                 save_block.append_op(
@@ -198,6 +206,16 @@ def save_vars(executor,
                 outputs={},
                 attrs={'file_path': os.path.join(dirname, filename)})

+        # if there is lookup table, the trainer 0 will notify all pserver to save.
+        if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
+            lookup_table_filename = os.path.join(dirname, "__lookup_table__")
+            attrs = {}
+            attrs['epmap'] = main_program._endpoints
+            attrs['dir'] = lookup_table_filename
+            attrs['lookup_table'] = main_program._distributed_lookup_table
+
+            save_block.append_op(
+                type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
+
         executor.run(save_program)
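With this change, saving persistables on the chief trainer both writes the local dense variables and appends a `checkpoint_notify` op that asks every pserver in `epmap` to dump its shard of the distributed lookup table under `__lookup_table__`. A hedged trainer-side usage sketch; the toy network, endpoint list, and directory name are illustrative assumptions, not code from this commit:

```python
import paddle.fluid as fluid

# A toy network with a distributed lookup table.
ids = fluid.layers.data(name="ids", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="float32")
emb = fluid.layers.embedding(
    input=ids, size=[100000, 8], is_sparse=True, is_distributed=True)
pool = fluid.layers.sequence_pool(input=emb, pool_type="sum")
pred = fluid.layers.fc(input=pool, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=label))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

trainer_id = 0
t = fluid.DistributeTranspiler()
t.transpile(trainer_id,
            pservers="127.0.0.1:6170,127.0.0.1:6171",
            trainers=2)
trainer_program = t.get_trainer_program()

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Only the chief trainer (trainer 0) triggers the checkpoint_notify op that
# tells the pservers to dump their lookup-table shards; other trainers skip it.
if trainer_id == 0:
    fluid.io.save_persistables(exe, "./model_dir", main_program=trainer_program)
```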
@@ -379,11 +397,22 @@ def load_vars(executor,
         load_prog = Program()
         load_block = load_prog.global_block()

+        if main_program is None:
+            main_program = default_main_program()
+        if not isinstance(main_program, Program):
+            raise TypeError("program should be as Program type or None")
+
+        load_slice_vars = []
+        for each_var in main_program._slice_vars_and_attrs:
+            load_slice_vars.append(each_var[2].name)
+
         load_var_map = {}
         for each_var in vars:
             assert isinstance(each_var, Variable)
             if each_var.type == core.VarDesc.VarType.RAW:
                 continue
+            if each_var.name in load_slice_vars:
+                continue
             new_var = _clone_var_in_block_(load_block, each_var)
             if filename is None:
                 load_block.append_op(
@@ -406,9 +435,6 @@ def load_vars(executor,
                 attrs={'file_path': os.path.join(dirname, filename)})

         executor.run(load_prog)

-        if main_program is None:
-            main_program = default_main_program()
-
         # load slice vars on pserver, if have it.
         _load_slice_up_vars(executor, dirname,
                             main_program._slice_vars_and_attrs)
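On the pserver side, sliced variables (and the optimizer accumulators whose shape changed with the slice) are skipped by the generic load loop and handled by `_load_slice_up_vars`, which reloads the full tensor and slices out this pserver's block. A hedged pserver-side sketch; it assumes the same toy network and transpiler configuration as the trainer-side sketch above, re-built in the pserver process:

```python
import paddle.fluid as fluid

# Assumes the toy network above has been rebuilt in this process.
current_endpoint = "127.0.0.1:6170"              # this pserver's endpoint

t = fluid.DistributeTranspiler()
t.transpile(0, pservers="127.0.0.1:6170,127.0.0.1:6171", trainers=2)
pserver_program = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_program)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)

# load_persistables consults pserver_program._slice_vars_and_attrs: sliced
# params and the optimizer vars registered there are skipped by the plain
# load ops and rebuilt by _load_slice_up_vars from the full checkpoints.
fluid.io.load_persistables(exe, "./model_dir", main_program=pserver_program)
```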
@@ -618,13 +644,6 @@ def save_inference_model(dirname,
     if main_program is None:
         main_program = default_main_program()

-    # if there is lookup table, the trainer 0 will notify all pserver to save.
-    if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
-        lookup_table_filename = os.path.join(dirname, "__lookup_table__")
-        _save_lookup_tables_by_notify(executor, lookup_table_filename,
-                                      main_program._distributed_lookup_table,
-                                      main_program._endpoints)
-
     # when a pserver and a trainer running on the same machine, mkdir may conflict
     try:
         os.makedirs(dirname)
@@ -642,6 +661,9 @@ def save_inference_model(dirname,
     # it can only be loaded for inference directly. If it's false, the whole
     # original program and related meta are saved so that future usage can be
     # more flexible.

+    origin_program = main_program.clone()
+
     if export_for_deployment:
         main_program = main_program.clone()
         global_block = main_program.global_block()
@@ -666,8 +688,11 @@ def save_inference_model(dirname,
         with open(model_basename + ".main_program", "wb") as f:
             f.write(main_program.desc.serialize_to_string())

+    main_program._copy_dist_param_info_from(origin_program)
+
     if params_filename is not None:
         params_filename = os.path.basename(params_filename)

     save_persistables(executor, dirname, main_program, params_filename)
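Because pruning for deployment rebuilds `main_program`, the distributed attributes are copied back from a clone taken before pruning (`origin_program`), so the final `save_persistables` call can still notify the pservers about the lookup table. A hedged usage sketch; `trainer_program`, `pred`, and the feed name `"ids"` refer to the trainer-side toy network above and are assumptions:

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())

# Prune and export an inference program from the distributed trainer program.
fluid.io.save_inference_model(
    dirname="./inference_model",
    feeded_var_names=["ids"],
    target_vars=[pred],
    executor=exe,
    main_program=trainer_program)
# The pruned program keeps the distributed attributes copied back from
# origin_program, so the save_persistables call inside still triggers the
# pserver-side dump of the lookup-table shards on the chief trainer.
```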
@@ -897,6 +922,9 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
         slice_var = var_tuple[2]
         end = start + slice_var.shape[0]

+        orig_var_name = orig_var.name
+        orig_var.name = "{}.origin".format(orig_var_name)
+
         clone_orig_var = load_block.create_var(
             name=orig_var.name,
             type=orig_var.type,
@@ -915,7 +943,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
             type='load',
             inputs={},
             outputs={'Out': [clone_orig_var]},
-            attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
+            attrs={'file_path': os.path.join(dirname, orig_var_name)})

         load_block.append_op(
             type="slice",
             inputs={'Input': clone_orig_var},
@@ -924,6 +952,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
                 'starts': [start],
                 'ends': [end]})
         need_delete_vars.append(clone_orig_var)
+
     load_block.append_op(
         type='delete_var',
         inputs={'X': need_delete_vars}, )
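The helper loads the full checkpoint under a temporary `<name>.origin` variable (so it cannot collide with the already-created sliced variable of the same original name), slices out the `[start, end)` rows owned by this pserver, and finally deletes the temporary. A conceptual numpy sketch of the same offset arithmetic (not the Fluid ops; the shapes and block split are illustrative):

```python
import numpy as np

# Full parameter as written by the trainer: 10 rows of width 4.
full = np.arange(10 * 4, dtype="float32").reshape(10, 4)

# Suppose the transpiler split it into two 5-row blocks and this pserver
# holds the second block; skip_dim0 is the total height of earlier blocks.
skip_dim0 = 5
slice_height = 5

start = skip_dim0
end = start + slice_height

# Equivalent of: load the full tensor under "<name>.origin", then slice it.
local_block = full[start:end]
assert local_block.shape == (5, 4)
```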

@ -644,6 +644,9 @@ in a single call.")
else: else:
recv_inputs.append(single_trainer_var) recv_inputs.append(single_trainer_var)
self._slice_params_and_optimizes = self._get_slice_vars_and_attrs(
endpoint)
# step 3 # step 3
# Create a union-find data structure from optimize ops, # Create a union-find data structure from optimize ops,
# If two ops are connected, we could add these two ops # If two ops are connected, we could add these two ops
@ -766,7 +769,7 @@ in a single call.")
grad_to_block_id, merged_var, grad_to_block_id, merged_var,
lr_ops) lr_ops)
# dedup grad to ids list # dedup grad to ids list
grad_to_block_id = list(set(grad_to_block_id)) grad_to_block_id = list(set(grad_to_block_id))
# append global ops # append global ops
if global_ops: if global_ops:
@ -827,8 +830,8 @@ in a single call.")
attrs=attrs) attrs=attrs)
# add distributed attrs # add distributed attrs
pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs( pserver_program._slice_vars_and_attrs = list(
endpoint) self._slice_params_and_optimizes.values())
pserver_program._sync_with_cpp() pserver_program._sync_with_cpp()
# save pserver program to generate pserver side startup relatively. # save pserver program to generate pserver side startup relatively.
@ -941,12 +944,12 @@ to transpile() call.")
outputs={"Out": startup_tmpvar}) outputs={"Out": startup_tmpvar})
# add slice vars # add slice vars
s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint) s_prog._slice_vars_and_attrs = pserver_program._slice_vars_and_attrs
return s_prog return s_prog
def _get_slice_vars_and_attrs(self, endpoint): def _get_slice_vars_and_attrs(self, endpoint):
slice_vars_and_attrs = [] slice_vars_and_attrs = {}
block_suffix = "block" block_suffix = "block"
for param in self.param_grad_ep_mapping[endpoint]["params"]: for param in self.param_grad_ep_mapping[endpoint]["params"]:
orig_var_name, block_name, _ = self._get_varname_parts(param.name) orig_var_name, block_name, _ = self._get_varname_parts(param.name)
@ -960,8 +963,7 @@ to transpile() call.")
slice_vars = self.param_var_mapping[orig_var_name] slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]: for slice_var in slice_vars[:block_idx]:
skip_dim0 += slice_var.shape[0] skip_dim0 += slice_var.shape[0]
slice_vars_and_attrs.append([orig_var, skip_dim0, param]) slice_vars_and_attrs[param.name] = [orig_var, skip_dim0, param]
return slice_vars_and_attrs return slice_vars_and_attrs
# ====================== private transpiler functions ===================== # ====================== private transpiler functions =====================
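`_get_slice_vars_and_attrs` now returns a dict keyed by the sliced variable's name, which lets the optimizer-handling code below register extra entries (the sliced accumulators) under their own names without duplicating parameters. A conceptual sketch of the structure and the `skip_dim0` offset computation, in plain Python rather than the transpiler classes; the parameter name and block heights are made up:

```python
# Each entry maps a pserver-side (sliced) variable name to
# [origin_var, skip_dim0, slice_var]:
#   origin_var : the full-size variable as it exists on the trainer
#   skip_dim0  : how many rows of earlier blocks precede this slice
#   slice_var  : the block actually stored on this pserver
slice_vars_and_attrs = {}

# Hypothetical split of a 10-row parameter "w" into two 5-row blocks.
blocks = {"w.block0": 5, "w.block1": 5}

skip_dim0 = 0
for name, height in blocks.items():
    slice_vars_and_attrs[name] = ["w", skip_dim0, name]
    skip_dim0 += height

assert slice_vars_and_attrs["w.block1"][1] == 5   # block1 starts at row 5
```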
@ -1662,10 +1664,10 @@ to transpile() call.")
if key in ["Param", "Grad", "LearningRate"]: if key in ["Param", "Grad", "LearningRate"]:
continue continue
var = self.origin_program.global_block().vars[opt_op.input(key)[0]] var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
param_var = new_inputs["Param"]
# update accumulator variable shape # update accumulator variable shape
param_shape = new_inputs["Param"].shape new_shape = self._get_optimizer_input_shape(
new_shape = self._get_optimizer_input_shape(opt_op.type, key, opt_op.type, key, var.shape, param_var.shape)
var.shape, param_shape)
tmpvar = pserver_block.create_var( tmpvar = pserver_block.create_var(
name=var.name, name=var.name,
persistable=var.persistable, persistable=var.persistable,
@ -1673,6 +1675,13 @@ to transpile() call.")
shape=new_shape) shape=new_shape)
new_inputs[key] = tmpvar new_inputs[key] = tmpvar
# var shape been changed
if new_shape != var.shape:
slice_var_args = self._slice_params_and_optimizes[
param_var.name]
self._slice_params_and_optimizes[
var.name] = [var, slice_var_args[1], tmpvar]
# change output's ParamOut variable # change output's ParamOut variable
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
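Optimizer accumulators (e.g. Adam's moment tensors or momentum's velocity) are created per parameter and must be sliced exactly like the parameter, so whenever `new_shape != var.shape` the accumulator is recorded in `_slice_params_and_optimizes` with the same row offset as its parameter; that is what lets `load_vars` later skip the plain load and reconstruct the accumulator slice from the full checkpoint. A conceptual sketch of that bookkeeping in plain Python (the names, shapes, and registry layout are illustrative, not the transpiler internals):

```python
# Registry filled for the parameter slices first (see _get_slice_vars_and_attrs):
#   name -> [origin_name, skip_dim0, sliced_shape]
slice_registry = {
    "w.block1": ["w", 5, (5, 4)],
}

# Hypothetical optimizer accumulators that mirror the full parameter "w".
accumulators = {"moment1_0": (10, 4), "moment2_0": (10, 4)}
param_slice_shape = slice_registry["w.block1"][2]

for acc_name, full_shape in accumulators.items():
    new_shape = param_slice_shape          # accumulator follows the param slice
    if new_shape != full_shape:
        # Reuse the parameter's row offset so loading can slice the same rows
        # out of the full accumulator checkpoint.
        offset = slice_registry["w.block1"][1]
        slice_registry[acc_name] = [acc_name, offset, new_shape]

assert slice_registry["moment1_0"][1] == 5
```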
