|
|
|
@ -247,15 +247,15 @@ def fc_with_initialize(input_channels, out_channels):
|
|
|
|
|
class BNReshapeDenseBNNet(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(BNReshapeDenseBNNet, self).__init__()
|
|
|
|
|
self.batch_norm = bn_with_initialize(512)
|
|
|
|
|
self.batch_norm = bn_with_initialize(2)
|
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
|
self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
|
|
|
|
|
self.fc = fc_with_initialize(512 * 32 * 32, 512)
|
|
|
|
|
self.fc = fc_with_initialize(2 * 32 * 32, 512)
|
|
|
|
|
self.loss = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
|
|
|
|
|
|
|
|
|
|
def construct(self, x, label):
|
|
|
|
|
x = self.batch_norm(x)
|
|
|
|
|
x = self.reshape(x, (16, 512*32*32))
|
|
|
|
|
x = self.reshape(x, (16, 2*32*32))
|
|
|
|
|
x = self.fc(x)
|
|
|
|
|
x = self.batch_norm2(x)
|
|
|
|
|
loss = self.loss(x, label)
|
|
|
|
@ -266,7 +266,7 @@ def test_bn_reshape_dense_bn_train_loss():
|
|
|
|
|
batch_size = 16
|
|
|
|
|
device_num = 16
|
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, global_rank=0)
|
|
|
|
|
input = Tensor(np.ones([batch_size, 512, 32, 32]).astype(np.float32) * 0.01)
|
|
|
|
|
input = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
|
|
|
|
|
label = Tensor(np.ones([batch_size]), dtype=ms.int32)
|
|
|
|
|
|
|
|
|
|
net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
|
|
|
|
|