|
|
@ -16,15 +16,15 @@
|
|
|
|
train step wrap
|
|
|
|
train step wrap
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
import mindspore.nn as nn
|
|
|
|
import mindspore.nn as nn
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
from mindspore import ParameterTuple
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
|
|
|
from mindspore import Parameter, ParameterTuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainStepWrap(nn.Cell):
|
|
|
|
class TrainStepWrap(nn.Cell):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
TrainStepWrap definition
|
|
|
|
TrainStepWrap definition
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, network):
|
|
|
|
def __init__(self, network):
|
|
|
|
super(TrainStepWrap, self).__init__()
|
|
|
|
super(TrainStepWrap, self).__init__()
|
|
|
|
self.network = network
|
|
|
|
self.network = network
|
|
|
@ -39,10 +39,12 @@ class TrainStepWrap(nn.Cell):
|
|
|
|
grads = self.grad(self.network, weights)(x, label)
|
|
|
|
grads = self.grad(self.network, weights)(x, label)
|
|
|
|
return self.optimizer(grads)
|
|
|
|
return self.optimizer(grads)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NetWithLossClass(nn.Cell):
|
|
|
|
class NetWithLossClass(nn.Cell):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
NetWithLossClass definition
|
|
|
|
NetWithLossClass definition
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, network):
|
|
|
|
def __init__(self, network):
|
|
|
|
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
|
|
|
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
|
|
|
self.loss = nn.SoftmaxCrossEntropyWithLogits()
|
|
|
|
self.loss = nn.SoftmaxCrossEntropyWithLogits()
|
|
|
@ -61,6 +63,7 @@ class TrainStepWrap2(nn.Cell):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
TrainStepWrap2 definition
|
|
|
|
TrainStepWrap2 definition
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, network, sens):
|
|
|
|
def __init__(self, network, sens):
|
|
|
|
super(TrainStepWrap2, self).__init__()
|
|
|
|
super(TrainStepWrap2, self).__init__()
|
|
|
|
self.network = network
|
|
|
|
self.network = network
|
|
|
@ -76,13 +79,16 @@ class TrainStepWrap2(nn.Cell):
|
|
|
|
grads = self.grad(self.network, weights)(x, self.sens)
|
|
|
|
grads = self.grad(self.network, weights)(x, self.sens)
|
|
|
|
return self.optimizer(grads)
|
|
|
|
return self.optimizer(grads)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_step_with_sens(network, sens):
|
|
|
|
def train_step_with_sens(network, sens):
|
|
|
|
return TrainStepWrap2(network, sens)
|
|
|
|
return TrainStepWrap2(network, sens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainStepWrapWithoutOpt(nn.Cell):
|
|
|
|
class TrainStepWrapWithoutOpt(nn.Cell):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
TrainStepWrapWithoutOpt definition
|
|
|
|
TrainStepWrapWithoutOpt definition
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, network):
|
|
|
|
def __init__(self, network):
|
|
|
|
super(TrainStepWrapWithoutOpt, self).__init__()
|
|
|
|
super(TrainStepWrapWithoutOpt, self).__init__()
|
|
|
|
self.network = network
|
|
|
|
self.network = network
|
|
|
@ -93,5 +99,6 @@ class TrainStepWrapWithoutOpt(nn.Cell):
|
|
|
|
grads = self.grad(self.network, self.weights)(x, label)
|
|
|
|
grads = self.grad(self.network, self.weights)(x, label)
|
|
|
|
return grads
|
|
|
|
return grads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_step_without_opt(network):
|
|
|
|
def train_step_without_opt(network):
|
|
|
|
return TrainStepWrapWithoutOpt(NetWithLossClass(network))
|
|
|
|
return TrainStepWrapWithoutOpt(NetWithLossClass(network))
|
|
|
|