|
|
@ -1,13 +1,12 @@
|
|
|
|
import collections
|
|
|
|
import collections
|
|
|
|
|
|
|
|
|
|
|
|
import py_paddle.swig_paddle as api
|
|
|
|
import py_paddle.swig_paddle as api
|
|
|
|
from py_paddle import DataProviderConverter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from data_feeder import DataFeeder
|
|
|
|
from data_feeder import DataFeeder
|
|
|
|
|
|
|
|
from topology import Topology
|
|
|
|
from . import event as v2_event
|
|
|
|
from . import event as v2_event
|
|
|
|
from . import optimizer as v2_optimizer
|
|
|
|
from . import optimizer as v2_optimizer
|
|
|
|
from . import parameters as v2_parameters
|
|
|
|
from . import parameters as v2_parameters
|
|
|
|
from . import topology as v2_topology
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['ITrainer', 'SGD']
|
|
|
|
__all__ = ['ITrainer', 'SGD']
|
|
|
|
|
|
|
|
|
|
|
@ -69,7 +68,6 @@ class SGD(ITrainer):
|
|
|
|
test_data_reader=None,
|
|
|
|
test_data_reader=None,
|
|
|
|
event_handler=None,
|
|
|
|
event_handler=None,
|
|
|
|
batch_size=32,
|
|
|
|
batch_size=32,
|
|
|
|
data_types=None,
|
|
|
|
|
|
|
|
reader_dict=None):
|
|
|
|
reader_dict=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
@ -83,13 +81,12 @@ class SGD(ITrainer):
|
|
|
|
occurred.
|
|
|
|
occurred.
|
|
|
|
:type event_handler: (BaseEvent) => None
|
|
|
|
:type event_handler: (BaseEvent) => None
|
|
|
|
:param batch_size: Not important, will be removed after data refactor.
|
|
|
|
:param batch_size: Not important, will be removed after data refactor.
|
|
|
|
:param data_types: Not important, will be removed after data refactor.
|
|
|
|
|
|
|
|
:return:
|
|
|
|
:return:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if event_handler is None:
|
|
|
|
if event_handler is None:
|
|
|
|
event_handler = default_event_handler
|
|
|
|
event_handler = default_event_handler
|
|
|
|
|
|
|
|
|
|
|
|
topology = v2_topology.Topology(topology)
|
|
|
|
topology = Topology(topology)
|
|
|
|
|
|
|
|
|
|
|
|
__check_train_args__(**locals())
|
|
|
|
__check_train_args__(**locals())
|
|
|
|
|
|
|
|
|
|
|
@ -109,10 +106,7 @@ class SGD(ITrainer):
|
|
|
|
assert isinstance(pass_evaluator, api.Evaluator)
|
|
|
|
assert isinstance(pass_evaluator, api.Evaluator)
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|
|
|
|
|
|
|
|
data_types_lists = [data_type[1] for data_type in topology.data_type()]
|
|
|
|
feeder = DataFeeder(topology.data_type(), reader_dict)
|
|
|
|
converter = DataProviderConverter(input_types=data_types_lists)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
feeder = DataFeeder(data_types, reader_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for pass_id in xrange(num_passes):
|
|
|
|
for pass_id in xrange(num_passes):
|
|
|
|
event_handler(v2_event.BeginPass(pass_id))
|
|
|
|
event_handler(v2_event.BeginPass(pass_id))
|
|
|
@ -195,7 +189,7 @@ def __check_train_args__(train_data_reader, topology, parameters,
|
|
|
|
raise ValueError('test_data_reader should be a function, which can '
|
|
|
|
raise ValueError('test_data_reader should be a function, which can '
|
|
|
|
'return a iterator')
|
|
|
|
'return a iterator')
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(topology, v2_topology.Topology):
|
|
|
|
if not isinstance(topology, Topology):
|
|
|
|
raise ValueError('topology should be a model config')
|
|
|
|
raise ValueError('topology should be a model config')
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(parameters, v2_parameters.Parameters):
|
|
|
|
if not isinstance(parameters, v2_parameters.Parameters):
|
|
|
|