|
|
|
@ -304,15 +304,19 @@ class WithEvalCell(Cell):
|
|
|
|
|
>>> eval_net = nn.WithEvalCell(net, loss_fn)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, network, loss_fn):
|
|
|
|
|
def __init__(self, network, loss_fn, add_cast_fp32=False):
|
|
|
|
|
super(WithEvalCell, self).__init__(auto_prefix=False)
|
|
|
|
|
self._network = network
|
|
|
|
|
self._loss_fn = loss_fn
|
|
|
|
|
self.add_cast_fp32 = add_cast_fp32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def construct(self, data, label):
|
|
|
|
|
outputs = self._network(data)
|
|
|
|
|
label = _mp_cast_helper(mstype.float32, label)
|
|
|
|
|
loss = self._loss_fn(F.cast(outputs, mstype.float32), label)
|
|
|
|
|
if self.add_cast_fp32:
|
|
|
|
|
label = _mp_cast_helper(mstype.float32, label)
|
|
|
|
|
outputs = F.cast(outputs, mstype.float32)
|
|
|
|
|
loss = self._loss_fn(outputs, label)
|
|
|
|
|
return loss, outputs, label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|