fix grad debug event (#4536)

武毅 8 years ago committed by GitHub
parent c3b46d1683
commit 3f874143fe

@@ -10,7 +10,8 @@ There are:
 * EndPass
 """
 __all__ = [
-    'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult'
+    'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult',
+    'EndForwardBackward'
 ]
@@ -73,6 +74,17 @@ class BeginIteration(object):
         self.batch_id = batch_id
 
 
+class EndForwardBackward(object):
+    """
+    Event On One Batch ForwardBackward Complete.
+    """
+
+    def __init__(self, pass_id, batch_id, gm):
+        self.pass_id = pass_id
+        self.batch_id = batch_id
+        self.gm = gm
+
+
 class EndIteration(WithMetric):
     """
     Event On One Batch Training Complete.
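
The new event hands the gradient machine (`gm`) to the user's event handler right after forward/backward and before the parameter update, which is the point where per-batch gradients can still be inspected for debugging. A minimal handler sketch, assuming the usual `paddle.v2.event` import path; it otherwise uses only the attributes and calls shown in this diff:

import paddle.v2.event as v2_event

def event_handler(event):
    # EndForwardBackward fires after forward/backward but before the
    # parameter update, so the batch's gradients are still intact on event.gm.
    if isinstance(event, v2_event.EndForwardBackward):
        n_params = sum(1 for _ in event.gm.getNonStaticParameters())
        print("pass %d, batch %d: forward/backward done, %d trainable params" %
              (event.pass_id, event.batch_id, n_params))
    elif isinstance(event, v2_event.EndIteration):
        if event.batch_id % 100 == 0:
            print("pass %d, batch %d, cost %f" %
                  (event.pass_id, event.batch_id, event.cost))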

@@ -164,11 +164,18 @@ class SGD(object):
                     pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
                 self.__gradient_machine__.eval(batch_evaluator)
+                event_handler(
+                    v2_event.EndForwardBackward(
+                        pass_id=pass_id,
+                        batch_id=batch_id,
+                        gm=self.__gradient_machine__))
                 for each_param in self.__gradient_machine__.getNonStaticParameters(
                 ):
                     self.__parameter_updater__.update(each_param)
                 cost_sum = out_args.sum()
                 cost = cost_sum / len(data_batch)
+                self.__parameter_updater__.finishBatch(cost)
+                batch_evaluator.finish()
                 event_handler(
                     v2_event.EndIteration(
                         pass_id=pass_id,
@@ -176,8 +183,6 @@ class SGD(object):
                         cost=cost,
                         evaluator=batch_evaluator,
                         gm=self.__gradient_machine__))
-                self.__parameter_updater__.finishBatch(cost)
-                batch_evaluator.finish()
             self.__parameter_updater__.finishPass()
             pass_evaluator.finish()
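
A handler like the sketch above is passed to the trainer through `train()`. A hedged wiring example, assuming the standard v2 `SGD` constructor and `train()` signature; `cost`, `parameters`, `optimizer`, and `train_reader` are placeholders defined by the surrounding training script:

import paddle.v2 as paddle

# Hypothetical wiring; cost, parameters, optimizer and train_reader are
# assumed to be defined elsewhere in the training script.
trainer = paddle.trainer.SGD(
    cost=cost, parameters=parameters, update_equation=optimizer)
trainer.train(
    reader=train_reader, num_passes=2, event_handler=event_handler)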
