|
|
|
@ -204,3 +204,35 @@ def test_reshape_unexpand_6():
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
|
|
|
|
net.set_auto_parallel()
|
|
|
|
|
_executor.compile(net, x)
|
|
|
|
|
|
|
|
|
|
def test_reshape_unexpand_7():
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
|
|
|
|
|
mul_size=(32, 1, 220, 220)):
|
|
|
|
|
super().__init__()
|
|
|
|
|
mul_np = np.full(mul_size, 0.5, dtype=np.float32)
|
|
|
|
|
self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
|
|
|
|
|
self.mul = P.Mul()
|
|
|
|
|
self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
|
|
|
|
|
kernel_size=5, has_bias=True, weight_init='ones',
|
|
|
|
|
bias_init='ones', pad_mode='valid')
|
|
|
|
|
self.softmax = nn.Softmax(axis=axis)
|
|
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
|
self.input_shape = input_shape
|
|
|
|
|
|
|
|
|
|
def construct(self, inputs):
|
|
|
|
|
x = self.conv(inputs)
|
|
|
|
|
x = self.softmax(x)
|
|
|
|
|
x = self.relu(x)
|
|
|
|
|
x = self.mul(x, self.mul_weight)
|
|
|
|
|
x = self.reshape(x, self.input_shape)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
size = 8
|
|
|
|
|
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
|
|
|
|
x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
|
|
|
|
|
net = GradWrap(NetWithLoss(Net()))
|
|
|
|
|
net.set_auto_parallel()
|
|
|
|
|
_executor.compile(net, x)
|
|
|
|
|