|
|
|
@ -387,7 +387,38 @@ def test_switch_layer():
|
|
|
|
|
ret = F.switch_layer(index, self.layers)(x) * self.z3
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
index = Tensor(0)
|
|
|
|
|
net = SwitchLayerCell()
|
|
|
|
|
net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|
C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|
|
|
|
|
|
def test_index_to_switch_layer():
|
|
|
|
|
class Layer1(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(Layer1, self).__init__()
|
|
|
|
|
self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
return x * self.z1
|
|
|
|
|
|
|
|
|
|
class Layer2(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(Layer2, self).__init__()
|
|
|
|
|
self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
return x * self.z2
|
|
|
|
|
|
|
|
|
|
class SwitchLayerCell(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(SwitchLayerCell, self).__init__()
|
|
|
|
|
self.layers = (Layer1(), Layer2())
|
|
|
|
|
self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
|
|
|
|
|
def construct(self, index, x):
|
|
|
|
|
ret = self.layers[index](x) * self.z3
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
index = Tensor(0)
|
|
|
|
|
net = SwitchLayerCell()
|
|
|
|
|
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
|
|
|