|
|
|
@ -138,6 +138,7 @@ class TrainOneStepCell(nn.Cell):
|
|
|
|
|
def __init__(self, network, optimizer, sens=1.0):
|
|
|
|
|
super(TrainOneStepCell, self).__init__(auto_prefix=True)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.network.set_grad()
|
|
|
|
|
self.network.add_flags(defer_inline=True)
|
|
|
|
|
self.weights = ParameterTuple(network.trainable_params())
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
@ -167,7 +168,6 @@ class TrainGAT(nn.Cell):
|
|
|
|
|
def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff):
|
|
|
|
|
super(TrainGAT, self).__init__(auto_prefix=False)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.network.set_grad()
|
|
|
|
|
loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff)
|
|
|
|
|
optimizer = nn.Adam(loss_net.trainable_params(),
|
|
|
|
|
learning_rate=learning_rate)
|
|
|
|
|