|
|
|
@ -1,4 +1,6 @@
|
|
|
|
|
import collections
|
|
|
|
|
import gzip
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
import py_paddle.swig_paddle as api
|
|
|
|
|
|
|
|
|
@ -96,6 +98,18 @@ class SGD(object):
|
|
|
|
|
self.__gradient_machine__.prefetch(in_args)
|
|
|
|
|
self.__parameter_updater__.getParametersRemote()
|
|
|
|
|
|
|
|
|
|
def save_parameter(self, dir_name, file_name):
|
|
|
|
|
if not os.path.exists(dir_name):
|
|
|
|
|
os.makedirs(dir_name)
|
|
|
|
|
param_file_name = dir_name + "/" + file_name + '.tar.gz'
|
|
|
|
|
assert not os.path.exists(param_file_name)
|
|
|
|
|
self.__parameter_updater__.catchUpWith()
|
|
|
|
|
self.__parameter_updater__.apply()
|
|
|
|
|
self.__parameter_updater__.getParametersRemote(True, True)
|
|
|
|
|
with gzip.open(param_file_name, 'w') as f:
|
|
|
|
|
self.__parameters__.to_tar(f)
|
|
|
|
|
self.__parameter_updater__.restore()
|
|
|
|
|
|
|
|
|
|
def train(self, reader, num_passes=1, event_handler=None, feeding=None):
|
|
|
|
|
"""
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|