support patch data, add load_one_table, fix bug (#18509)

(1)support patch data (merge slots of instances of same line id, modify dense layer which
changes its size)
(2)add fleet load_one_table interface, support load from paddle model and load from pslib model
(3)fix push sparse bug which cause push sparse cost more time(about 10% in my testcase)
(4)when some slots are not in one of your network (join/update, etc.),data feed、collect label info、push/pull sparse will skip these slots, instead of throw error.
(5)add more debug info in TrainFilesWithProfiler
DDDivano-patch-1
jiaqi 6 years ago committed by GitHub
parent fd3aad6cb3
commit d18aabb472
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,7 +42,11 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (name == use_slots_[i]) {
feed_vec_[i] = var->GetMutable<LoDTensor>();
if (var == nullptr) {
feed_vec_[i] = nullptr;
} else {
feed_vec_[i] = var->GetMutable<LoDTensor>();
}
}
}
}
@ -164,6 +168,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->fp_ = nullptr;
this->thread_id_ = 0;
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->input_channel_ = nullptr;
this->output_channel_ = nullptr;
this->consume_channel_ = nullptr;
@ -247,6 +252,11 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num;
}
template <typename T>
void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
@ -591,6 +601,9 @@ void MultiSlotDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (feed_vec_[i] == nullptr) {
continue;
}
const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
@ -684,6 +697,18 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
// VLOG(3) << line;
char* endptr = const_cast<char*>(str);
int pos = 0;
if (parse_ins_id_) {
int num = strtol(&str[pos], &endptr, 10);
CHECK(num == 1); // NOLINT
pos = endptr - str + 1;
size_t len = 0;
while (str[pos + len] != ' ') {
++len;
}
instance->ins_id_ = std::string(str + pos, len);
pos += len + 1;
VLOG(3) << "ins_id " << instance->ins_id_;
}
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
@ -699,7 +724,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
for (int j = 0; j < num; ++j) {
float feasign = strtof(endptr, &endptr);
// if float feasign is equal to zero, ignore it
if (fabs(feasign) < 1e-6) {
// except when slot is dense
if (fabs(feasign) < 1e-6 && !use_slots_is_dense_[i]) {
continue;
}
FeatureKey f;
@ -710,7 +736,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
for (int j = 0; j < num; ++j) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
// if uint64 feasign is equal to zero, ignore it
if (feasign == 0) {
// except when slot is dense
if (feasign == 0 && !use_slots_is_dense_[i]) {
continue;
}
FeatureKey f;
@ -838,6 +865,9 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
}
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (feed_vec_[i] == nullptr) {
continue;
}
int total_instance = offset[i].back();
const auto& type = all_slots_type_[i];
if (type[0] == 'f') { // float

@ -102,6 +102,8 @@ class DataFeed {
virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
@ -212,6 +214,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetConsumeChannel(void* channel);
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void LoadIntoMemory();
protected:
@ -221,6 +224,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_;
int thread_num_;
bool parse_ins_id_;
std::ifstream file_;
std::shared_ptr<FILE> fp_;
paddle::framework::ChannelObject<T>* input_channel_;

File diff suppressed because it is too large Load Diff

@ -57,6 +57,10 @@ class Dataset {
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// set channel num
virtual void SetChannelNum(int channel_num) = 0;
// set merge by ins id
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins) = 0;
// get file list
virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num
@ -98,6 +102,8 @@ class Dataset {
virtual int64_t GetMemoryDataSize() = 0;
// get shuffle data size
virtual int64_t GetShuffleDataSize() = 0;
// merge by ins id
virtual void MergeByInsId() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
@ -120,6 +126,9 @@ class DatasetImpl : public Dataset {
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual void SetChannelNum(int channel_num);
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
@ -145,6 +154,7 @@ class DatasetImpl : public Dataset {
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
virtual int64_t GetShuffleDataSize();
virtual void MergeByInsId() {}
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
@ -169,12 +179,18 @@ class DatasetImpl : public Dataset {
int64_t fleet_send_batch_size_;
int64_t fleet_send_sleep_seconds_;
std::vector<std::thread> preload_threads_;
bool merge_by_insid_;
bool erase_duplicate_feas_;
bool keep_unmerged_ins_;
int min_merge_size_;
std::vector<std::string> merge_slots_list_;
};
// use std::vector<MultiSlotType> as data type
// use std::vector<MultiSlotType> or Record as data type
class MultiSlotDataset : public DatasetImpl<Record> {
public:
MultiSlotDataset() {}
virtual void MergeByInsId();
virtual ~MultiSlotDataset() {}
};

@ -89,7 +89,12 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
VLOG(3) << "sparse_key_names_[" << i
<< "]: " << sparse_key_names_[table_id][i];
Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
if (fea_var == nullptr) {
continue;
}
LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var "
<< sparse_key_names_[table_id][i] << " is null";
int64_t* ids = tensor->data<int64_t>();
size_t fea_idx = 0;
// tensor->lod()[0].size() == batch_size + 1
@ -128,7 +133,11 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
std::string slot_name = sparse_key_names_[table_id][i];
std::string emb_slot_name = sparse_value_names_[table_id][i];
Variable* var = thread_scope_->FindVar(slot_name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << slot_name << " is null";
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(emb_slot_name);
@ -198,6 +207,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
int cur_batch;
int batch_cnt = 0;
uint64_t total_inst = 0;
double op_sum_time = 0;
std::unordered_map<std::string, double> op_to_time;
timeline.Start();
while ((cur_batch = device_reader_->Next()) > 0) {
timeline.Pause();
@ -346,7 +357,27 @@ void DownpourWorker::TrainFilesWithProfiler() {
for (size_t i = 0; i < op_total_time.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt);
if (op_to_time.find(op_name[i]) == op_to_time.end()) {
op_to_time[op_name[i]] = 0.0;
}
op_to_time[op_name[i]] += op_total_time[i];
op_sum_time += op_total_time[i];
}
for (auto& i : op_to_time) {
fprintf(stderr, "op [%s] run total time: [%f]ms\n", i.first.c_str(),
i.second / batch_cnt);
}
fprintf(stderr, "op run total time: %fs\n", op_sum_time / batch_cnt);
fprintf(stderr, "train total time: %fs\n", total_time / batch_cnt);
fprintf(stderr, "pull sparse time: %fs\n",
pull_sparse_time / batch_cnt);
fprintf(stderr, "fill sparse time: %fs\n",
fill_sparse_time / batch_cnt);
fprintf(stderr, "push sparse time: %fs\n",
push_sparse_time / batch_cnt);
fprintf(stderr, "push dense time: %fs\n", push_dense_time / batch_cnt);
fprintf(stderr, "collect label time: %fs\n",
collect_label_time / batch_cnt);
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "pull sparse time percent: %f\n",

@ -27,8 +27,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
@ -156,7 +158,11 @@ void FleetWrapper::PullSparseVarsSync(
fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
@ -291,29 +297,34 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
grad_dim = emb_dim - 2;
}
CHECK_GE(grad_dim, 0);
push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) {
t.resize(emb_dim + offset);
}
uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
exit(-1);
}
float* g = g_tensor->data<float>();
Variable* var = scope.FindVar(sparse_key_names[i]);
CHECK(var != nullptr) << "var[" << sparse_key_names[i] << "] not found";
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
if (tensor == nullptr) {
LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null";
exit(-1);
}
int len = tensor->numel();
int64_t* ids = tensor->data<int64_t>();
push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) {
t.resize(emb_dim + offset);
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == nullptr) {
LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null";
exit(-1);
}
float* g = g_tensor->data<float>();
if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) {
int dim = emb_dim + offset;
Eigen::Map<
@ -355,6 +366,79 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
void FleetWrapper::LoadFromPaddleModel(Scope& scope, const uint64_t table_id,
std::vector<std::string> var_list,
std::string model_path,
std::string model_proto_file,
bool load_combine) {
// load ProgramDesc from model file
auto read_proto_func = [](const std::string& filename) -> ProgramDesc {
std::string contents;
std::ifstream fin(filename, std::ios::in | std::ios::binary);
fin.seekg(0, std::ios::end);
contents.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&contents[0], contents.size());
fin.close();
ProgramDesc program_desc(contents);
return program_desc;
};
const ProgramDesc old_program = read_proto_func(model_proto_file);
Scope* old_scope = new Scope();
auto& old_block = old_program.Block(0);
auto place = platform::CPUPlace();
std::vector<std::string> old_param_list;
for (auto& t : var_list) {
VarDesc* old_var_desc = old_block.FindVar(t);
if (old_var_desc == nullptr) {
continue;
}
// init variable in scope
Variable* old_var = old_scope->Var(old_var_desc->Name());
InitializeVariable(old_var, old_var_desc->GetType());
old_param_list.push_back(t);
if (load_combine) {
continue;
}
// load variable from model
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", model_path + "/" + old_var_desc->Name()});
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load", {}, {{"Out", {old_var_desc->Name()}}}, attrs);
load_op->Run(*old_scope, place);
}
if (load_combine) {
std::sort(old_param_list.begin(), old_param_list.end());
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", model_path});
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load_combine", {}, {{"Out", old_param_list}}, attrs);
load_op->Run(*old_scope, place);
}
for (auto& t : old_param_list) {
Variable* old_var = old_scope->Var(t);
// old model data, here we assume data type is float
LoDTensor* old_tensor = old_var->GetMutable<LoDTensor>();
float* old_data = old_tensor->data<float>();
// new model data, here we assume data type is float
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* data = tensor->data<float>();
// copy from old data to new data
if (old_tensor->numel() > tensor->numel()) {
memcpy(data, old_data, tensor->numel() * sizeof(float));
} else {
memcpy(data, old_data, old_tensor->numel() * sizeof(float));
}
}
delete old_scope;
PushDenseParamSync(scope, table_id, old_param_list);
}
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->load(path, std::to_string(mode));
@ -368,6 +452,21 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
#endif
}
void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret =
pslib_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
<< ", from path: " << path << " failed";
}
#else
VLOG(0) << "FleetWrapper::LoadModel does nothing when no pslib";
#endif
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));

@ -131,9 +131,18 @@ class FleetWrapper {
// flush all push requests
void ClientFlush();
// load from paddle model
void LoadFromPaddleModel(Scope& scope, const uint64_t table_id, // NOLINT
std::vector<std::string> var_list,
std::string model_path, std::string model_proto_file,
bool load_combine);
// mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);

@ -64,7 +64,7 @@ void HogwildWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
device_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
device_reader_->AddFeedVar(thread_scope_->Var(name), name);
device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
}
}

@ -99,6 +99,10 @@ void BindDataset(py::module* m) {
.def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize,
py::call_guard<py::gil_scoped_release>())
.def("set_queue_num", &framework::Dataset::SetChannelNum,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId,
py::call_guard<py::gil_scoped_release>())
.def("merge_by_lineid", &framework::Dataset::MergeByInsId,
py::call_guard<py::gil_scoped_release>());
}

@ -57,7 +57,10 @@ void BindFleetWrapper(py::module* m) {
&framework::FleetWrapper::CreateClient2ClientConnection)
.def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable)
.def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable)
.def("client_flush", &framework::FleetWrapper::ClientFlush);
.def("client_flush", &framework::FleetWrapper::ClientFlush)
.def("load_from_paddle_model",
&framework::FleetWrapper::LoadFromPaddleModel)
.def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable);
} // end FleetWrapper
} // end namespace pybind
} // end namespace paddle

@ -237,6 +237,7 @@ class InMemoryDataset(DatasetBase):
self.proto_desc.name = "MultiSlotInMemoryDataFeed"
self.fleet_send_batch_size = 80000
self.queue_num = None
self.merge_by_lineid = False
def _prepare_to_run(self):
"""
@ -258,7 +259,7 @@ class InMemoryDataset(DatasetBase):
Set Dataset output queue num, training threads get data from queues
Args:
set_queue_num(int): dataset output queue num
queue_num(int): dataset output queue num
Examples:
.. code-block:: python
@ -287,6 +288,40 @@ class InMemoryDataset(DatasetBase):
"""
self.fleet_send_batch_size = fleet_send_batch_size
def set_merge_by_lineid(self,
var_list,
erase_duplicate_feas=True,
min_merge_size=2,
keep_unmerged_ins=True):
"""
Set merge by line id, instances of same line id will be merged after
shuffle, you should parse line id in data generator.
Args:
var_list(list): slots that can be merge. each element in var_list
is Variable. some slots such as show and click, we
usually don't merge them for same line id, so user
should specify which slot can be merged.
erase_duplicate_feas(bool): whether erase duplicate feasigns when
merge. default is True.
min_merge_size(int): minimal size to merge. default is 2.
keep_unmerged_ins(bool): whether to keep unmerged ins, such as
ins with unique id or the num of ins with
same id is less than min_merge_size.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_lineid()
"""
var_name_list = [i.name for i in var_list]
self.dataset.set_merge_by_lineid(var_name_list, erase_duplicate_feas,
min_merge_size, keep_unmerged_ins)
self.merge_by_lineid = True
def load_into_memory(self):
"""
Load data into memory
@ -386,6 +421,10 @@ class InMemoryDataset(DatasetBase):
self.dataset.global_shuffle()
if fleet is not None:
fleet._role_maker._barrier_worker()
if self.merge_by_lineid:
self.dataset.merge_by_lineid()
if fleet is not None:
fleet._role_maker._barrier_worker()
def release_memory(self):
"""
@ -530,6 +569,9 @@ class QueueDataset(DatasetBase):
Global shuffle is not supported in QueueDataset
NotImplementedError will be raised
Args:
fleet(Fleet): fleet singleton. Default None.
Examples:
.. code-block:: python
@ -547,9 +589,12 @@ class QueueDataset(DatasetBase):
class FileInstantDataset(DatasetBase):
"""
FileInstantDataset, it will process data streamly.
Example:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory.create_dataset("FileInstantDataset")
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory.create_dataset("FileInstantDataset")
"""
def __init__(self):
@ -561,8 +606,7 @@ class FileInstantDataset(DatasetBase):
def local_shuffle(self):
"""
Local shuffle
FileInstantDataset does not support local shuffle
Local shuffle, FileInstantDataset does not support local shuffle
"""
raise NotImplementedError(
"FileInstantDataset does not support local shuffle, "

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import os
import sys
from optimizer_factory import *
from google.protobuf import text_format
@ -191,6 +192,7 @@ class PSLib(Fleet):
when using fleet, it will save sparse and dense feature
Args:
executor(Executor): fluid executor
dirname(str): save path. It can be hdfs/afs path or local path
main_program(Program): fluid program, default None
kwargs: use define property, current support following
@ -261,6 +263,115 @@ class PSLib(Fleet):
decay)
self._role_maker._barrier_worker()
def load_one_table(self, table_id, model_path, **kwargs):
"""
load pslib model for one table or load params from paddle model
Args:
table_id(int): load table id
model_path(str): load model path, can be local or hdfs/afs path
kwargs(dict): user defined params, currently support following:
only for load pslib model for one table:
mode(int): load model mode. 0 is for load whole model, 1 is
for load delta model (load diff), default is 0.
only for load params from paddle model:
scope(Scope): Scope object
model_proto_file(str): path of program desc proto binary
file, can be local or hdfs/afs file
load_combine(bool): load from a file or splited param files
default False.
Examples:
.. code-block:: python
# load pslib model for one table
fleet.load_one_table(0, "hdfs:/my_fleet_model/20190714/0/")
fleet.load_one_table(1, "hdfs:/xx/xxx", mode = 0)
# load params from paddle model
fleet.load_one_table(2, "hdfs:/my_paddle_model/",
scope = my_scope,
model_proto_file = "./my_program.bin",
load_combine = False)
# below is how to save proto binary file
with open("my_program.bin", "wb") as fout:
my_program = fluid.default_main_program()
fout.write(my_program.desc.serialize_to_string())
"""
mode = kwargs.get("mode", 0)
scope = kwargs.get("scope", None)
model_proto_file = kwargs.get("model_proto_file", None)
load_combine = kwargs.get("load_combine", False)
if scope is not None and model_proto_file is not None:
self._load_one_table_from_paddle_model(
scope, table_id, model_path, model_proto_file, load_combine)
else:
self._fleet_ptr.load_model_one_table(table_id, model_path, mode)
def _load_one_table_from_paddle_model(self,
scope,
table_id,
model_path,
model_proto_file,
load_combine=False):
"""
load params from paddle model, and push params to pserver
Args:
scope(Scope): Scope object
table_id(int): the id of table to load
model_path(str): path of paddle model, can be local or hdfs/afs file
model_proto_file(str): path of program desc proto binary file,
can be local or hdfs/afs file
load_combine(bool): load from a file or splited param files
"""
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
# get fs config from fleet_desc
fs_name = self._opt_info["fleet_desc"].fs_client_param.uri
fs_ugi = self._opt_info["fleet_desc"].fs_client_param.user + "," + \
self._opt_info["fleet_desc"].fs_client_param.passwd
hadoop_bin = self._opt_info["fleet_desc"].fs_client_param.hadoop_bin
# download model_path if it's hdfs/afs
if model_path.startswith("hdfs:") or model_path.startswith("afs:"):
dest = "./model_for_load_table_%s" % table_id
cmd = hadoop_bin + " fs -D fs.default.name=" + fs_name + \
" -D hadoop.job.ugi=" + fs_ugi + " -get " + model_path + \
" " + dest
ret = os.system(cmd)
if ret != 0:
raise RuntimeError("download model failed")
model_path = dest
# download model_proto_file if it's hdfs/afs
if model_proto_file.startswith("hdfs:") or \
model_proto_file.startswith("afs:"):
dest = "./model_proto_file_for_load_table_%s" % table_id
cmd = hadoop_bin + " fs -D fs.default.name=" + fs_name + \
" -D hadoop.job.ugi=" + fs_ugi + " -get " + \
model_proto_file + " " + dest
ret = os.system(cmd)
if ret != 0:
raise RuntimeError("download model proto file failed")
model_proto_file = dest
for i in self._opt_info["fleet_desc"].trainer_param.dense_table:
if table_id is not None and table_id != i.table_id:
continue
var_list = [var for var in i.dense_variable_name]
skip = False
for var in var_list:
if scope.find_var(var) is None:
skip = True
break
if skip:
continue
self._fleet_ptr.load_from_paddle_model(
scope, table_id, var_list, model_path, model_proto_file,
load_combine)
self._role_maker._barrier_worker()
def _set_opt_info(self, opt_info):
"""
this function saves the result from DistributedOptimizer.minimize()

@ -76,7 +76,8 @@ class TrainerDesc(object):
return self.proto_desc.SerializeToString()
def __str__(self):
return str(self.proto_desc)
from google.protobuf import text_format
return text_format.MessageToString(self.proto_desc)
class MultiTrainer(TrainerDesc):

Loading…
Cancel
Save