Trainer save load params (#10386)

* Load/save the params from the params_path

* Switch to use load_persistables and save_persistables

* Instaed of setup the executor to run program and scope. Pass the program to the load_persistables
simplify_fluid_api_recognize_digit
Jeff Wang 7 years ago committed by GitHub
parent 5812076e7d
commit bd66eed50a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,7 +13,9 @@
# limitations under the License.
import core
import framework
import executor
import io
__all__ = ['Inferencer', ]
@ -29,6 +31,15 @@ class Inferencer(object):
# 4. load params from param_path into scope
self.scope = core.Scope()
self.place = place
self.startup_program = framework.Program()
# TODO: generate the startup_program with network_func
exe = executor.Executor(place)
exe.run(self.startup_program, scope=self.scope)
if param_path:
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def infer(self, inputs):
# run self.program

@ -18,6 +18,7 @@ import framework
import executor
import data_feeder
import contextlib
import io
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module
@ -93,8 +94,7 @@ class Trainer(object):
if param_path:
# load params from param_path into scope
# TODO(yuyang): This depends on parameters implementation.
pass
io.load_persistables(exe, dirname=param_path)
def dist_transpile_if_necessary(self, optimize_ops, params_grads):
if "PADDLE_TRAINING_ROLE" not in os.environ:
@ -172,7 +172,9 @@ class Trainer(object):
def save_params(self, param_path):
# reference: save_persistables in io.py
pass
exe = executor.Executor(self.place)
io.save_persistables(
exe, dirname=param_path, main_program=self.startup_program)
@staticmethod
def _check_and_get_place(place):

Loading…
Cancel
Save