|
|
|
@ -157,8 +157,23 @@ class TrainOneStepCell(Cell):
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
|
|
|
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
|
|
|
>>> #1) Using the WithLossCell existing provide
|
|
|
|
|
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
|
|
|
|
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
|
|
|
|
>>>
|
|
|
|
|
>>> #2) Using user-defined WithLossCell
|
|
|
|
|
>>>class MyWithLossCell(nn.cell):
|
|
|
|
|
>>> def __init__(self, backbone, loss_fn):
|
|
|
|
|
>>> super(WithLossCell, self).__init__(auto_prefix=False)
|
|
|
|
|
>>> self._backbone = backbone
|
|
|
|
|
>>> self._loss_fn = loss_fn
|
|
|
|
|
>>>
|
|
|
|
|
>>> def construct(self, x, y, label):
|
|
|
|
|
>>> out = self._backbone(x, y)
|
|
|
|
|
>>> return self._loss_fn(out, label)
|
|
|
|
|
>>>
|
|
|
|
|
>>> loss_net = MyWithLossCell(net, loss_fn)
|
|
|
|
|
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, network, optimizer, sens=1.0):
|
|
|
|
|
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
|
|
|
|