add checkpoint_notify in python

port
tangwei12 7 years ago
parent 1c2e9bdd49
commit 985026ce42

@ -76,7 +76,7 @@ class CheckpointNotifyOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(checkpointnotify, ops::CheckpointNotifyOp,
REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp,
paddle::framework::EmptyGradOpMaker,
ops::CheckpointNotifyOpMaker,
ops::CheckpointNotifyOpShapeInference);

@ -382,7 +382,7 @@ class Operator(object):
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
'ncclInit', 'channel_create', 'channel_close', 'channel_send',
'channel_recv', 'select'
'channel_recv', 'select', 'checkpoint_notify'
}
def __init__(self,

@ -613,7 +613,7 @@ def save_pserver_vars_by_notify(executor, dirname, epmap):
attrs['dir'] = cur_dir
checkpoint_notify_block.append_op(
type='checkpointnotify', inputs={}, output={}, attrs=attrs)
type='checkpoint_notify', inputs={}, output={}, attrs=attrs)
executor.run(checkpoint_notify_program)

Loading…
Cancel
Save