|
|
|
@ -1,12 +1,13 @@
|
|
|
|
|
import collections
|
|
|
|
|
|
|
|
|
|
import py_paddle.swig_paddle as api
|
|
|
|
|
from paddle.proto.ModelConfig_pb2 import ModelConfig
|
|
|
|
|
from py_paddle import DataProviderConverter
|
|
|
|
|
|
|
|
|
|
from paddle.proto.ModelConfig_pb2 import ModelConfig
|
|
|
|
|
from . import event as v2_event
|
|
|
|
|
from . import layer as v2_layer
|
|
|
|
|
from . import optimizer as v2_optimizer
|
|
|
|
|
from . import parameters as v2_parameters
|
|
|
|
|
from . import event as v2_event
|
|
|
|
|
|
|
|
|
|
__all__ = ['ITrainer', 'SGD']
|
|
|
|
|
|
|
|
|
@ -73,7 +74,7 @@ class SGD(ITrainer):
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|
|
|
|
|
|
:param train_data_reader:
|
|
|
|
|
:param topology: Network Topology, a protobuf ModelConfig message.
|
|
|
|
|
:param topology: Network Topology, use one or more Layers to represent it.
|
|
|
|
|
:param parameters: The parameter pools.
|
|
|
|
|
:param num_passes: The total train passes.
|
|
|
|
|
:param test_data_reader:
|
|
|
|
@ -87,6 +88,8 @@ class SGD(ITrainer):
|
|
|
|
|
if event_handler is None:
|
|
|
|
|
event_handler = default_event_handler
|
|
|
|
|
|
|
|
|
|
topology = v2_layer.parse_network(topology)
|
|
|
|
|
|
|
|
|
|
__check_train_args__(**locals())
|
|
|
|
|
|
|
|
|
|
gm = api.GradientMachine.createFromConfigProto(
|
|
|
|
|