|
|
|
@ -39,7 +39,7 @@ class SGD(object):
|
|
|
|
|
:type parameters: paddle.v2.parameters.Parameters
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, cost, parameters, update_equation):
|
|
|
|
|
def __init__(self, cost, parameters, update_equation, extra_layers=None):
|
|
|
|
|
|
|
|
|
|
if not isinstance(parameters, v2_parameters.Parameters):
|
|
|
|
|
raise TypeError('parameters should be parameters')
|
|
|
|
@ -47,7 +47,7 @@ class SGD(object):
|
|
|
|
|
if not isinstance(update_equation, v2_optimizer.Optimizer):
|
|
|
|
|
raise TypeError("update equation parameter must be "
|
|
|
|
|
"paddle.v2.optimizer.Optimizer")
|
|
|
|
|
topology = Topology(cost)
|
|
|
|
|
topology = Topology(cost, extra_layers=extra_layers)
|
|
|
|
|
self.__optimizer__ = update_equation
|
|
|
|
|
self.__topology__ = topology
|
|
|
|
|
self.__parameters__ = parameters
|
|
|
|
|