add doc string for executor and update API.spec

test=develop
dongdaxiang 6 years ago
parent d52586a97d
commit b95b80bc76

File diff suppressed because one or more lines are too long

@@ -164,6 +164,8 @@ class DownpourWorker : public HogwildWorker {
   void CollectLabelInfo(size_t table_id);
 
  private:
+  bool need_to_push_dense_;
+  bool need_to_push_sparse_;
   DownpourWorkerParameter param_;
   // just save the value in param_ for easy access
   std::map<uint64_t, std::string> label_var_name_;

@@ -58,6 +58,9 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
     skip_ops_[i] = param_.skip_ops(i);
   }
 
+  need_to_push_sparse_ = param_.push_sparse();
+  need_to_push_dense_ = param_.push_dense();
+
   fleet_ptr_ = FleetWrapper::GetInstance();
   fetch_config_ = desc.fetch_config();
 }
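The two new members cache the proto switches once per worker, so the hot per-batch loop below only tests plain bools. A rough pure-Python analogue of what Initialize now does (DummyParam is a stand-in for the DownpourWorkerParameter message, not a real Paddle type):

    # Stand-in for DownpourWorkerParameter; push_sparse/push_dense default
    # to true in the .proto change later in this commit.
    class DummyParam(object):
        push_sparse = True
        push_dense = True

    class WorkerSketch(object):
        def initialize(self, param):
            # cache the switches once, as DownpourWorker::Initialize does
            self.need_to_push_sparse = param.push_sparse
            self.need_to_push_dense = param.push_dense

    w = WorkerSketch()
    w.initialize(DummyParam())
    assert w.need_to_push_sparse and w.need_to_push_dense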
@@ -239,76 +242,81 @@ void DownpourWorker::TrainFilesWithProfiler() {
       }
     }
-    for (size_t i = 0; i < param_.program_config(0).push_sparse_table_id_size();
-         ++i) {
-      uint64_t tid = static_cast<uint64_t>(
-          param_.program_config(0).push_sparse_table_id(i));
-      TableParameter table;
-      for (auto i : param_.sparse_table()) {
-        if (i.table_id() == tid) {
-          table = i;
-          break;
-        }
-      }
-      timeline.Start();
-      fleet_ptr_->PushSparseVarsWithLabelAsync(
-          *thread_scope_, tid, features_[tid], feature_labels_[tid],
-          sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
-          &feature_grads_[tid], &push_sparse_status_);
-      timeline.Pause();
-      push_sparse_time += timeline.ElapsedSec();
-      total_time += timeline.ElapsedSec();
-    }
-    timeline.Start();
-    for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
-         ++i) {
-      uint64_t tid = static_cast<uint64_t>(
-          param_.program_config(0).push_dense_table_id(i));
-      fleet_ptr_->PushDenseVarsAsync(
-          *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
-    }
-    timeline.Pause();
-    push_dense_time += timeline.ElapsedSec();
-    total_time += timeline.ElapsedSec();
-    VLOG(3) << "push sparse and dense gradient done.";
-    int32_t tmp_push_dense_wait_times = -1;
-    int32_t tmp_push_sparse_wait_times = -1;
-    static uint32_t push_dense_wait_times =
-        static_cast<uint32_t>(tmp_push_dense_wait_times);
-    static uint32_t push_sparse_wait_times =
-        static_cast<uint32_t>(tmp_push_sparse_wait_times);
-    if (push_dense_status_.size() >= push_dense_wait_times) {
-      for (auto& t : push_dense_status_) {
-        t.wait();
-      }
-      push_dense_status_.resize(0);
-    }
-    if (tmp_push_dense_wait_times == -1) {
-      push_dense_status_.resize(0);
-    }
-    if (push_sparse_status_.size() >= push_sparse_wait_times) {
-      for (auto& t : push_sparse_status_) {
-        t.wait();
-      }
-      push_sparse_status_.resize(0);
-    }
-    if (tmp_push_sparse_wait_times == -1) {
-      push_sparse_status_.resize(0);
-    }
-    VLOG(3) << "going to increase thread version";
-    VLOG(3) << "push dense table id size: "
-            << param_.program_config(0).push_dense_table_id_size();
-    for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
-         ++i) {
-      uint64_t tid = static_cast<uint64_t>(
-          param_.program_config(0).push_dense_table_id(i));
-      pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
-    }
+    if (need_to_push_sparse_) {
+      for (size_t i = 0;
+           i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_sparse_table_id(i));
+        TableParameter table;
+        for (auto i : param_.sparse_table()) {
+          if (i.table_id() == tid) {
+            table = i;
+            break;
+          }
+        }
+        timeline.Start();
+        fleet_ptr_->PushSparseVarsWithLabelAsync(
+            *thread_scope_, tid, features_[tid], feature_labels_[tid],
+            sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
+            &feature_grads_[tid], &push_sparse_status_);
+        timeline.Pause();
+        push_sparse_time += timeline.ElapsedSec();
+        total_time += timeline.ElapsedSec();
+      }
+    }
+    if (need_to_push_dense_) {
+      timeline.Start();
+      for (size_t i = 0;
+           i < param_.program_config(0).push_dense_table_id_size(); ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_dense_table_id(i));
+        fleet_ptr_->PushDenseVarsAsync(
+            *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
+      }
+      timeline.Pause();
+      push_dense_time += timeline.ElapsedSec();
+      total_time += timeline.ElapsedSec();
+    }
+    VLOG(3) << "push sparse and dense gradient done.";
+    int32_t tmp_push_dense_wait_times = -1;
+    int32_t tmp_push_sparse_wait_times = -1;
+    static uint32_t push_dense_wait_times =
+        static_cast<uint32_t>(tmp_push_dense_wait_times);
+    static uint32_t push_sparse_wait_times =
+        static_cast<uint32_t>(tmp_push_sparse_wait_times);
+    if (push_dense_status_.size() >= push_dense_wait_times) {
+      for (auto& t : push_dense_status_) {
+        t.wait();
+      }
+      push_dense_status_.resize(0);
+    }
+    if (tmp_push_dense_wait_times == -1) {
+      push_dense_status_.resize(0);
+    }
+    if (need_to_push_sparse_) {
+      if (push_sparse_status_.size() >= push_sparse_wait_times) {
+        for (auto& t : push_sparse_status_) {
+          t.wait();
+        }
+        push_sparse_status_.resize(0);
+      }
+      if (tmp_push_sparse_wait_times == -1) {
+        push_sparse_status_.resize(0);
+      }
+      VLOG(3) << "going to increase thread version";
+      VLOG(3) << "push dense table id size: "
+              << param_.program_config(0).push_dense_table_id_size();
+      for (size_t i = 0;
+           i < param_.program_config(0).push_dense_table_id_size(); ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_dense_table_id(i));
+        pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
+      }
+    }
     PrintFetchVars();
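The wait-times idiom above relies on unsigned wrap-around: static_cast<uint32_t>(-1) is UINT32_MAX, so push_dense_status_.size() >= push_dense_wait_times never holds and the == -1 branch simply clears the pending push handles without blocking (fire-and-forget pushes). A small Python sketch of the same arithmetic:

    # mimics static_cast<uint32_t>(-1): two's-complement wrap to UINT32_MAX
    push_dense_wait_times = -1 & 0xFFFFFFFF
    assert push_dense_wait_times == 4294967295
    statuses = ["future-1", "future-2"]  # stand-ins for push status futures
    if len(statuses) >= push_dense_wait_times:  # never true with -1 config
        pass  # this is where the C++ code would block on t.wait()
    del statuses[:]  # matches push_dense_status_.resize(0)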

@@ -46,6 +46,8 @@ message DownpourWorkerParameter {
   repeated TableParameter dense_table = 2;
   repeated string skip_ops = 3;
   repeated ProgramConfig program_config = 4;
+  bool push_sparse = 5 [ default = true ];
+  bool push_dense = 6 [ default = true ];
 }
 
 message FetchConfig {
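Because the bracketed defaults are proto2-style, an unset field reads back as true, so existing training configs keep pushing gradients without any change. A hedged sketch of inspecting the message from Python (the generated-module path is an assumption and may differ by build):

    # Assumed path for the generated bindings; adjust to your build layout.
    from paddle.fluid.proto import trainer_desc_pb2 as pb

    param = pb.DownpourWorkerParameter()
    assert param.push_sparse and param.push_dense  # proto2 defaults apply
    param.push_sparse = False  # what infer mode sets via gen_worker_desc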

File diff suppressed because it is too large

@@ -46,10 +46,13 @@ from . import regularizer
 from . import average
 from . import metrics
 from . import transpiler
+from . import incubate
 from . import distribute_lookup_table
 from .param_attr import ParamAttr, WeightNormParamAttr
 from .data_feeder import DataFeeder
 from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope, _Scope
+from .incubate import fleet
+from .incubate import data_generator
 from .transpiler import DistributeTranspiler, \
     memory_optimize, release_memory, DistributeTranspilerConfig
 from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
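With these re-exports, the incubate packages become reachable from fluid directly. A quick smoke test of the import paths added here (assuming a build that ships the incubate package):

    import paddle.fluid as fluid
    from paddle.fluid.incubate import fleet, data_generator  # new re-exports
    print(fleet, data_generator)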

@@ -25,6 +25,10 @@ class DeviceWorker(object):
         Init.
         """
         self.program_ = None
+        self.infer_ = None
+
+    def set_infer(self, infer=False):
+        self.infer_ = infer
 
     def set_fleet_desc(self, fleet_desc):
         """
"""
@@ -125,8 +129,7 @@ class DownpourSGD(DeviceWorker):
         for i in self.fleet_desc_.trainer_param.dense_table:
             if i.table_id in dense_table_set:
                 dense_table = pull_thread.dense_table.add()
-                dense_table.dense_value_name.extend(
-                    i.dense_variable_name)
+                dense_table.dense_value_name.extend(i.dense_variable_name)
                 dense_table.table_id = \
                     i.table_id
         sparse_table = downpour.sparse_table.add()
@@ -149,11 +152,13 @@ class DownpourSGD(DeviceWorker):
             if i.table_id in dense_table_set:
                 dense_table = downpour.dense_table.add()
                 dense_table.table_id = i.table_id
-                dense_table.dense_value_name.extend(
-                    i.dense_variable_name)
+                dense_table.dense_value_name.extend(i.dense_variable_name)
                 dense_table.dense_grad_name.extend(
                     i.dense_gradient_variable_name)
         downpour.skip_ops.extend(self.fleet_desc_.trainer_param.skip_op)
+        if self.infer_:
+            downpour.push_dense = False
+            downpour.push_sparse = False
 
 
 class DeviceWorkerFactory(object):
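The infer_ flag set on the device worker is what ultimately flips the two proto switches above. A condensed, runnable sketch of the gating logic in gen_worker_desc (the classes are stand-ins; table and skip-op wiring is elided):

    class DownpourDescSketch(object):  # stand-in for the proto message
        push_dense = True
        push_sparse = True

    class DeviceWorkerSketch(object):
        def __init__(self):
            self.infer_ = None

        def set_infer(self, infer=False):
            self.infer_ = infer

        def gen_worker_desc(self, downpour):
            # ... dense/sparse tables and skip ops are filled in here ...
            if self.infer_:
                downpour.push_dense = False
                downpour.push_sparse = False

    worker = DeviceWorkerSketch()
    worker.set_infer(True)
    desc = DownpourDescSketch()
    worker.gen_worker_desc(desc)
    assert not (desc.push_dense or desc.push_sparse)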

@@ -612,24 +612,22 @@ class Executor(object):
     def _run_inference(self, exe, feed):
         return exe.run(feed)
 
-    def infer_from_dataset(self,
-                           program=None,
-                           dataset=None,
-                           fetch_list=None,
-                           scope=None,
-                           thread=0,
-                           opt_info=None):
-        pass
-
-    def train_from_dataset(self,
-                           program=None,
-                           dataset=None,
-                           scope=None,
-                           thread=0,
-                           debug=False,
-                           fetch_list=None,
-                           fetch_info=None,
-                           print_period=100):
+    def _dump_debug_info(self, program=None, trainer=None):
+        with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
+            fout.write(trainer._desc())
+        if program._fleet_opt:
+            with open("fleet_desc.prototxt", "w") as fout:
+                fout.write(str(program._fleet_opt["fleet_desc"]))
+
+    def _prepare_trainer(self,
+                         program=None,
+                         dataset=None,
+                         scope=None,
+                         thread=0,
+                         debug=False,
+                         fetch_list=None,
+                         fetch_info=None,
+                         print_period=100):
         if scope is None:
             scope = global_scope()
         if fetch_list is None:
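Factoring the dump into _dump_debug_info gives the train and infer paths identical debug artifacts: a "<id(program)>_train_desc.prototxt" holding the generated TrainerDesc, plus "fleet_desc.prototxt" when the program carries a fleet configuration. For example, after a debug run (exe, main_prog and dataset assumed to be set up as in the docstring examples below):

    import glob
    exe.train_from_dataset(program=main_prog, dataset=dataset, debug=True)
    print(glob.glob("*_train_desc.prototxt"))  # one file per program id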
@@ -648,23 +646,148 @@ class Executor(object):
         if thread <= 0:
             if dataset.thread_num <= 0:
                 raise RuntimeError(
-                    "You should set thread num first, either in Dataset or in Executor.train_from_dataset"
-                )
+                    "You should set thread num first, either in Dataset "
+                    "or in Executor.train_from_dataset")
             else:
                 trainer.set_thread(dataset.thread_num)
         else:
             trainer.set_thread(thread)
         trainer.set_debug(debug)
         trainer.set_fetch_var_and_info(fetch_list, fetch_info, print_period)
+        return trainer
+
+    def infer_from_dataset(self,
+                           program=None,
+                           dataset=None,
+                           scope=None,
+                           thread=0,
+                           debug=False,
+                           fetch_list=None,
+                           fetch_info=None,
+                           print_period=100):
+        """
+        Infer from a pre-defined Dataset. The interface is almost the same
+        as train_from_dataset, except that in distributed training the
+        trainer runs in inference mode, so gradient pushes are disabled.
+        infer_from_dataset() makes multi-threaded evaluation very easy.
+
+        Args:
+            program(Program|CompiledProgram): the program that needs to be run,
+                if not provided, then default_main_program (not compiled) will be used.
+            dataset(paddle.fluid.Dataset): dataset created outside this function,
+                a user should provide a well-defined dataset before calling this function.
+                Please check the document of Dataset if needed.
+            scope(Scope): the scope used to run this program, you can switch it to different scope
+                for each run. default is global_scope
+            thread(int): number of threads a user wants to run in this function. If set to 0
+                or a negative value, the thread num set in the Dataset will be used
+            debug(bool): whether a user wants to run infer_from_dataset in debug mode
+            fetch_list(Variable List): fetch variable list, each variable
+                will be printed during inference
+            fetch_info(String List): print information for each variable
+            print_period(int): the number of mini-batches for each print
+
+        Example:
+            .. code-block:: python
+
+                import paddle.fluid as fluid
+                place = fluid.CPUPlace()
+                exe = fluid.Executor(place)
+                x = fluid.layers.data(name="x", shape=[1], dtype="int64")
+                y = fluid.layers.data(name="y", shape=[1], dtype="int64")
+                dataset = fluid.DatasetFactory().create_dataset()
+                dataset.set_use_var([x, y])
+                filelist = ["dataA.txt", "dataB.txt"]
+                dataset.set_filelist(filelist)
+                exe.run(fluid.default_startup_program())
+                exe.infer_from_dataset(program=fluid.default_main_program(),
+                                       dataset=dataset)
+
+        """
+        trainer = self._prepare_trainer(
+            program=program,
+            dataset=dataset,
+            scope=scope,
+            thread=thread,
+            debug=debug,
+            fetch_list=fetch_list,
+            fetch_info=fetch_info,
+            print_period=print_period)
+        trainer.set_infer(True)
+        trainer.gen_trainer_desc()
+        dataset._prepare_to_run()
+        if debug:
+            self._dump_debug_info(program=program, trainer=trainer)
+        self._default_executor.run_from_dataset(program.desc, scope,
+                                                dataset.dataset,
+                                                trainer._desc())
+
+    def train_from_dataset(self,
+                           program=None,
+                           dataset=None,
+                           scope=None,
+                           thread=0,
+                           debug=False,
+                           fetch_list=None,
+                           fetch_info=None,
+                           print_period=100):
+        """
+        Train from a pre-defined Dataset. Dataset is defined in paddle.fluid.dataset.
+        Given a program (or a compiled program), train_from_dataset will consume all
+        data samples in dataset. Input scope can be given by users; by default, scope
+        is global_scope(). The number of threads used in training is the value of the
+        thread argument when it is positive, otherwise the thread num set in the
+        Dataset. Debug can be set so that the executor displays the run time of every
+        operator and the throughput of the current training task.
+
+        Note: train_from_dataset will destroy all resources created within executor for each run.
+
+        Args:
+            program(Program|CompiledProgram): the program that needs to be run,
+                if not provided, then default_main_program (not compiled) will be used.
+            dataset(paddle.fluid.Dataset): dataset created outside this function,
+                a user should provide a well-defined dataset before calling this function.
+                Please check the document of Dataset if needed.
+            scope(Scope): the scope used to run this program, you can switch it to different scope
+                for each run. default is global_scope
+            thread(int): number of threads a user wants to run in this function. If set to 0
+                or a negative value, the thread num set in the Dataset will be used
+            debug(bool): whether a user wants to run train_from_dataset in debug mode
+            fetch_list(Variable List): fetch variable list, each variable
+                will be printed during training
+            fetch_info(String List): print information for each variable
+            print_period(int): the number of mini-batches for each print
+
+        Example:
+            .. code-block:: python
+
+                import paddle.fluid as fluid
+                place = fluid.CPUPlace()
+                exe = fluid.Executor(place)
+                x = fluid.layers.data(name="x", shape=[1], dtype="int64")
+                y = fluid.layers.data(name="y", shape=[1], dtype="int64")
+                dataset = fluid.DatasetFactory().create_dataset()
+                dataset.set_use_var([x, y])
+                dataset.set_thread(2)
+                filelist = ["dataA.txt", "dataB.txt"]
+                dataset.set_filelist(filelist)
+                exe.run(fluid.default_startup_program())
+                exe.train_from_dataset(program=fluid.default_main_program(),
+                                       dataset=dataset)
+
+        """
+        trainer = self._prepare_trainer(
+            program=program,
+            dataset=dataset,
+            scope=scope,
+            thread=thread,
+            debug=debug,
+            fetch_list=fetch_list,
+            fetch_info=fetch_info,
+            print_period=print_period)
         trainer.gen_trainer_desc()
         dataset._prepare_to_run()
         if debug:
-            #with open("train_desc.prototxt", "w") as fout:
-            with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
-                fout.write(trainer._desc())
-            if program._fleet_opt:
-                with open("fleet_desc.prototxt", "w") as fout:
-                    fout.write(str(program._fleet_opt["fleet_desc"]))
+            self._dump_debug_info(program=program, trainer=trainer)
         self._default_executor.run_from_dataset(program.desc, scope,
                                                 dataset.dataset,
                                                 trainer._desc())
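The thread resolution implemented by _prepare_trainer, extracted as plain Python for clarity: a positive thread argument wins, otherwise the Dataset's thread_num is used, and if both are unset it raises.

    def resolve_thread(thread, dataset_thread_num):
        # mirrors the branch in _prepare_trainer above
        if thread <= 0:
            if dataset_thread_num <= 0:
                raise RuntimeError("You should set thread num first, either "
                                   "in Dataset or in Executor.train_from_dataset")
            return dataset_thread_num
        return thread

    assert resolve_thread(0, 4) == 4  # falls back to Dataset.thread_num
    assert resolve_thread(8, 4) == 8  # an explicit positive argument wins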

@@ -35,6 +35,7 @@ class TrainerDesc(object):
         self.fleet_desc_ = None
         self.device_worker_ = None
         self.program_ = None
+        self.infer_ = False
 
     def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
         for i, v in enumerate(fetch_vars):
@@ -52,6 +53,9 @@
     def set_device_worker(self, device_worker):
         self.device_worker_ = device_worker
 
+    def set_infer(self, infer):
+        self.infer_ = infer
+
     def set_fleet_desc(self, fleet_desc):
         self.fleet_desc_ = fleet_desc
@@ -77,6 +81,7 @@
     def gen_trainer_desc(self):
         super(MultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "MultiTrainer"
+        self.device_worker_.set_infer(self.infer_)
         self.device_worker_.gen_worker_desc(self.proto_desc)
@@ -94,5 +99,6 @@
         self.proto_desc.class_name = "DistMultiTrainer"
         if self.program_ == None:
             print("None program")
+        self.device_worker_.set_infer(self.infer_)
         self.device_worker_.set_program(self.program_)
         self.device_worker_.gen_worker_desc(self.proto_desc)
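gen_trainer_desc snapshots infer_ into the device worker at generation time, which is why the executor must call set_infer before gen_trainer_desc (the infer_from_dataset body above orders the calls that way). A minimal stand-in demonstration of the ordering constraint:

    class TrainerSketch(object):  # stand-in for TrainerDesc
        def __init__(self):
            self.infer_ = False
            self.worker_infer = None  # what the device worker would receive

        def set_infer(self, infer):
            self.infer_ = infer

        def gen_trainer_desc(self):
            self.worker_infer = self.infer_  # snapshot happens here

    t = TrainerSketch()
    t.set_infer(True)   # must precede gen_trainer_desc to take effect
    t.gen_trainer_desc()
    assert t.worker_infer is True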
