checkpoint notify

port
tangwei12 7 years ago
parent 30880844bb
commit ae12281d9b

@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -36,12 +37,14 @@ class CheckpointNotifyOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir"); std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) { for (size_t i = 0; i < epmap.size(); i++) {
VLOG(3) << "sending to " << epmap[i] << " to checkpoint notify ... "; VLOG(3) << "sending " << dir <<" to " << epmap[i] << " to checkpoint notify ... ";
rpc_client->AsyncCheckpointNotify(epmap[i], dir); auto serial_looku_table = string::Sprintf("%s/%s.%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table);
} }
rpc_client->Wait(); rpc_client->Wait();
} }
@ -57,6 +60,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({"127.0.0.1:6164"}); .SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>( AddAttr<std::string>(
"dir", "(string, default '') indicate the folder checkpoint will use"); "dir", "(string, default '') indicate the folder checkpoint will use");
AddAttr<std::string>(
"lookup_table", "(string, default '') the lookup table name");
AddComment(R"DOC( AddComment(R"DOC(
Prefetch operator Prefetch operator

@ -208,11 +208,14 @@ class RequestCheckpointNotify final : public RequestBase {
auto scope = request_->GetMutableLocalScope(); auto scope = request_->GetMutableLocalScope();
std::string checkpoint_notify = request_->Varname(); std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->Varname(); std::string checkpoint_dir = request_->OutVarname();
framework::Variable* invar = nullptr; framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, request_handler_->Handle(checkpoint_notify, scope, invar, &outvar,
checkpoint_dir); checkpoint_dir);
Finish(reply_, &responder_); Finish(reply_, &responder_);

@ -22,6 +22,7 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -124,6 +125,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar, framework::Variable** outvar,
const std::string& out_var_name) { const std::string& out_var_name) {
auto lt_varname = string::Sprintf("%s.path", varname);
auto *lt_var = scope->FindVar(lt_varname)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true; return true;
} }

@ -87,7 +87,7 @@ class SaveOp : public framework::OperatorBase {
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, var); SaveLodTensor(filename, place, var);
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(filename, place, var); SaveSelectedRows(scope, place, var);
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
false, false,
@ -128,9 +128,17 @@ class SaveOp : public framework::OperatorBase {
fout.close(); fout.close();
} }
void SaveSelectedRows(const std::string &filename, void SaveSelectedRows(const framework::Scope &scope,
const platform::Place &place, const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto lt_varname = string::Sprintf("%s.path", Input("X"));
auto *lt_var = scope.FindVar(lt_varname)->GetMutable<std::string>();
PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable %s for SaveSelectedRows",
lt_varname);
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
auto &selectedRows = var->Get<framework::SelectedRows>(); auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool // get device context from pool

@ -471,7 +471,10 @@ def save_checkpoint(executor,
trainer_id, trainer_id,
trainer_args=None, trainer_args=None,
main_program=None, main_program=None,
max_num_checkpoints=3): max_num_checkpoints=3,
lookup_table=None,
ps_endpoint_list=None
):
""" """
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
@ -500,7 +503,7 @@ def save_checkpoint(executor,
if trainer_id == 0: if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program) save_persist_vars_without_grad(executor, cur_dir, main_program)
save_pserver_vars_by_notify(executor, cur_dir, "") save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table)
_scroll_delete(checkpoint_dir, max_num_checkpoints) _scroll_delete(checkpoint_dir, max_num_checkpoints)
@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success(cur_dir) _write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, epmap): def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list):
""" """
""" """
cur_dir = _get_lookuptable_dir(dirname) cur_dir = _get_lookuptable_dir(dirname)
@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap):
checkpoint_notify_block = checkpoint_notify_program.global_block() checkpoint_notify_block = checkpoint_notify_program.global_block()
attrs = {} attrs = {}
attrs['epmap'] = None attrs['epmap'] = ps_endpoint_list
attrs['dir'] = cur_dir attrs['dir'] = cur_dir
attrs['lookup_table'] = lookup_table
checkpoint_notify_block.append_op( checkpoint_notify_block.append_op(
type='checkpoint_notify', inputs={}, output={}, attrs=attrs) type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(checkpoint_notify_program) executor.run(checkpoint_notify_program)
@ -783,3 +787,4 @@ def get_latest_checkpoint_serial(checkpoint_dir):
if success_num > current_dir: if success_num > current_dir:
current_dir = success_num current_dir = success_num
return current_dir return current_dir

@ -838,13 +838,15 @@ class DistributeTranspiler:
""" """
import os import os
pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW)
checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block = pserver_program.create_block(pre_block_idx)
checkpoint_save_block.append_op( checkpoint_save_block.append_op(
type='save', type='save',
inputs={'X': [self.table_name]}, inputs={'X': [self.table_name]},
outputs={}, outputs={},
attrs={ attrs={
'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name) 'file_path': self.table_name)
}) })
return checkpoint_save_block.idx return checkpoint_save_block.idx

Loading…
Cancel
Save