|
|
|
@ -7,17 +7,10 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig
|
|
|
|
|
from . import optimizer as v2_optimizer
|
|
|
|
|
from . import parameters as v2_parameters
|
|
|
|
|
|
|
|
|
|
__all__ = ['ITrainer', 'SGDTrainer', 'CompleteTrainOneBatch', 'BaseEvent']
|
|
|
|
|
__all__ = ['ITrainer', 'SGDTrainer', 'EndIteration']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseEvent(object):
|
|
|
|
|
"""
|
|
|
|
|
Just a marker class
|
|
|
|
|
"""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CompleteTrainOneBatch(BaseEvent):
|
|
|
|
|
class EndIteration(object):
|
|
|
|
|
"""
|
|
|
|
|
Event On One Batch Training Complete.
|
|
|
|
|
"""
|
|
|
|
@ -117,7 +110,7 @@ class SGDTrainer(ITrainer):
|
|
|
|
|
cost = cost_vec.sum() / len(data_batch)
|
|
|
|
|
updater.finishBatch(cost)
|
|
|
|
|
event_handler(
|
|
|
|
|
CompleteTrainOneBatch(
|
|
|
|
|
EndIteration(
|
|
|
|
|
pass_id=pass_id, batch_id=batch_id, cost=cost))
|
|
|
|
|
|
|
|
|
|
updater.finishPass()
|
|
|
|
|