|
|
|
@ -25,7 +25,7 @@ class CompleteTrainOneBatch(BaseEvent):
|
|
|
|
|
self.pass_id = pass_id
|
|
|
|
|
self.batch_id = batch_id
|
|
|
|
|
self.cost = cost
|
|
|
|
|
self.paramters = parameters
|
|
|
|
|
self.parameters = parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_event_handler(event):
|
|
|
|
@ -44,6 +44,17 @@ class ITrainer(object):
|
|
|
|
|
|
|
|
|
|
class LazyParameterPool(v2_parameters.IParameterPool):
|
|
|
|
|
"""
|
|
|
|
|
Lazy Parameter Pool stores a reference to GradientMachine. User could invoke
|
|
|
|
|
`get_parameter` if needed, but the operation is lazy. It means the parameter
|
|
|
|
|
will only fetched from GPU or Parameter Server if `get_parameter` is
|
|
|
|
|
invoked. Also, set flag = writable will make a extra host2device copy after
|
|
|
|
|
reading/modifying parameter.
|
|
|
|
|
|
|
|
|
|
This class is not exposed to User. User should treat this class as a normal
|
|
|
|
|
IParameterPool.
|
|
|
|
|
|
|
|
|
|
See IParameterPool for usage documentation.
|
|
|
|
|
|
|
|
|
|
:type __gradient_machine__: api.GradientMachine
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -130,12 +141,22 @@ class CustomizeUpdateEquation(object):
|
|
|
|
|
shape)
|
|
|
|
|
g = param.getBuf(api.PARAMETER_GRADIENT).toNumpyArrayInplace(
|
|
|
|
|
).reshape(shape)
|
|
|
|
|
args = [v, g]
|
|
|
|
|
for arg in self.local_params[conf.name]:
|
|
|
|
|
args.append(arg)
|
|
|
|
|
self.__callback__(*args)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
v = param.getBuf(api.PARAMETER_VALUE).copyToNumpyArray().reshape(
|
|
|
|
|
shape)
|
|
|
|
|
g = param.getBuf(api.PARAMETER_GRADIENT).copyToNumpyArray().reshape(
|
|
|
|
|
shape)
|
|
|
|
|
|
|
|
|
|
args = [v, g]
|
|
|
|
|
for arg in self.local_params[conf.name]:
|
|
|
|
|
args.append(arg)
|
|
|
|
|
self.__callback__(*args)
|
|
|
|
|
|
|
|
|
|
if api.isUsingGpu():
|
|
|
|
|
param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(v.flatten(
|
|
|
|
|
).astype('float32'))
|
|
|
|
|
# discard gradient changed.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SGDTrainer(ITrainer):
|
|
|
|
|