@@ -1,4 +1,3 @@
-import py_paddle.swig_paddle as swig_api
 import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
 import paddle.trainer_config_helpers.optimizers as v1_optimizers
 """
@@ -17,6 +16,7 @@ __all__ = [

 class Optimizer(object):
     def __init__(self, **kwargs):
+        import py_paddle.swig_paddle as swig_api
         if 'batch_size' in kwargs:
             del kwargs['batch_size']  # not important for python library.

@@ -25,8 +25,6 @@ class Optimizer(object):
         self.__opt_conf_proto__ = config_parser_utils.parse_optimizer_config(
             __impl__)
-        if swig_api is None:
-            raise RuntimeError("paddle.v2 currently need swig_paddle")
         self.__opt_conf__ = swig_api.OptimizationConfig.createFromProto(
             self.__opt_conf_proto__)
@@ -37,18 +35,22 @@
         For each optimizer(SGD, Adam), GradientMachine should enable different
         buffers.
         """
+        import py_paddle.swig_paddle as swig_api
         tmp = swig_api.ParameterOptimizer.create(self.__opt_conf__)
         assert isinstance(tmp, swig_api.ParameterOptimizer)
         return tmp.getParameterTypes()

     def __create_local_updater__(self):
+        import py_paddle.swig_paddle as swig_api
         return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)

     def __create_remote_updater__(self, pass_num, use_sparse_updater):
+        import py_paddle.swig_paddle as swig_api
         return swig_api.ParameterUpdater.createRemoteUpdater(
             self.__opt_conf__, pass_num, use_sparse_updater)

     def __create_new_remote_updater__(self, pserver_spec, use_etcd):
+        import py_paddle.swig_paddle as swig_api
         return swig_api.ParameterUpdater.createNewRemoteUpdater(
             self.__opt_conf__, pserver_spec, use_etcd)
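
This diff defers the py_paddle.swig_paddle import from module load time to the methods that actually use it, so paddle.v2.optimizer can be imported on machines without the swig binding; the old "if swig_api is None" guard is dropped because a missing binding now surfaces as an ordinary ImportError at first use. Below is a minimal, hedged usage sketch of how Optimizer is normally reached from the v2 API; the Momentum optimizer, paddle.init, and the trainer.SGD signature are assumptions about the surrounding paddle.v2 API, not part of this diff.

    # Illustrative sketch only -- assumes the paddle.v2 API named above exists
    # with these signatures.
    import paddle.v2 as paddle

    paddle.init(use_gpu=False, trainer_count=1)

    # Constructing an optimizer is what now triggers the lazy
    # "import py_paddle.swig_paddle" inside Optimizer.__init__.
    optimizer = paddle.optimizer.Momentum(momentum=0.9, learning_rate=1e-3)

    # A trainer would obtain its parameter updater through the
    # __create_local_updater__ / __create_remote_updater__ helpers, e.g.:
    # trainer = paddle.trainer.SGD(cost=cost,
    #                              parameters=parameters,
    #                              update_equation=optimizer)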