|
|
|
@ -16,28 +16,83 @@
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
import mindspore.context as context
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
import mindspore.ops.operations as P
|
|
|
|
|
from mindspore import context, Tensor
|
|
|
|
|
from mindspore.nn import Cell
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
|
|
|
|
|
class Net_Pool(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(Net_Pool, self).__init__()
|
|
|
|
|
self.maxpool_fun = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="VALID")
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
return self.maxpool_fun(x)
|
|
|
|
|
class MaxPoolWithArgMax_Net(Cell):
|
|
|
|
|
def __init__(self, padding, ksize, strides):
|
|
|
|
|
super(MaxPoolWithArgMax_Net, self).__init__()
|
|
|
|
|
self.maxpool_with_argmax = P.MaxPoolWithArgmax(padding=padding, ksize=ksize, strides=strides)
|
|
|
|
|
|
|
|
|
|
def construct(self, input_data):
|
|
|
|
|
output, argmax = self.maxpool_with_argmax(input_data)
|
|
|
|
|
return output, argmax
|
|
|
|
|
|
|
|
|
|
class Net_Pool2(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(Net_Pool2, self).__init__()
|
|
|
|
|
self.maxpool_fun = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
return self.maxpool_fun(x)
|
|
|
|
|
class Grad(Cell):
|
|
|
|
|
def __init__(self, network, argmax):
|
|
|
|
|
super(Grad, self).__init__()
|
|
|
|
|
self.grad = C.GradOperation(get_all=True, sens_param=True)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.sens = (Tensor(np.ones(argmax.shape).astype(np.float32)),
|
|
|
|
|
Tensor(np.ones(argmax.shape).astype(np.int32)))
|
|
|
|
|
|
|
|
|
|
def construct(self, input_data):
|
|
|
|
|
gout = self.grad(self.network)(input_data, self.sens)
|
|
|
|
|
return gout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_x86_gpu_training
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_train_forward_backward():
|
|
|
|
|
x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
|
|
|
|
|
expect_output = np.array([[[[5, 6, 7, 7],
|
|
|
|
|
[9, 10, 11, 11],
|
|
|
|
|
[9, 10, 11, 11]],
|
|
|
|
|
[[17, 18, 19, 19],
|
|
|
|
|
[21, 22, 23, 23],
|
|
|
|
|
[21, 22, 23, 23]],
|
|
|
|
|
[[29, 30, 31, 31],
|
|
|
|
|
[33, 34, 35, 35],
|
|
|
|
|
[33, 34, 35, 35]]]]).astype(np.float32)
|
|
|
|
|
expect_argmax = np.array([[[[5, 6, 7, 7],
|
|
|
|
|
[9, 10, 11, 11],
|
|
|
|
|
[9, 10, 11, 11]],
|
|
|
|
|
[[17, 18, 19, 19],
|
|
|
|
|
[21, 22, 23, 23],
|
|
|
|
|
[21, 22, 23, 23]],
|
|
|
|
|
[[29, 30, 31, 31],
|
|
|
|
|
[33, 34, 35, 35],
|
|
|
|
|
[33, 34, 35, 35]]]]).astype(np.int32)
|
|
|
|
|
expect_dx = np.array([[[[0, 0, 0, 0],
|
|
|
|
|
[0, 1, 1, 2],
|
|
|
|
|
[0, 2, 2, 4]],
|
|
|
|
|
[[0, 0, 0, 0],
|
|
|
|
|
[0, 1, 1, 2],
|
|
|
|
|
[0, 2, 2, 4]],
|
|
|
|
|
[[0, 0, 0, 0],
|
|
|
|
|
[0, 1, 1, 2],
|
|
|
|
|
[0, 2, 2, 4]]]]).astype(np.float32)
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
|
|
|
|
net = MaxPoolWithArgMax_Net(padding="SAME", ksize=2, strides=1)
|
|
|
|
|
output_tensor, argmax_tensor = net(Tensor(x))
|
|
|
|
|
assert output_tensor.shape == expect_output.shape
|
|
|
|
|
assert argmax_tensor.shape == expect_argmax.shape
|
|
|
|
|
|
|
|
|
|
error = np.ones(shape=expect_output.shape) * 1.0e-5
|
|
|
|
|
diff_output = output_tensor.asnumpy() - expect_output
|
|
|
|
|
assert np.all(diff_output < error)
|
|
|
|
|
|
|
|
|
|
net_grad = Grad(net, argmax_tensor)
|
|
|
|
|
dx = net_grad(Tensor(x))[0].asnumpy()
|
|
|
|
|
assert dx.shape == expect_dx.shape
|
|
|
|
|
diff = dx - expect_dx
|
|
|
|
|
assert np.all(diff < error)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_x86_gpu_training
|
|
|
|
@ -73,8 +128,8 @@ def test_maxpool_with_argmax_2d():
|
|
|
|
|
]]]))
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
|
|
|
|
maxpool2d = Net_Pool()
|
|
|
|
|
maxpool2d2 = Net_Pool2()
|
|
|
|
|
maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
|
|
|
|
|
maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
|
|
|
|
|
output2, index2 = maxpool2d2(x)
|
|
|
|
|
output, index = maxpool2d(x)
|
|
|
|
|
assert (output.asnumpy() == expect_result).all()
|
|
|
|
@ -83,8 +138,8 @@ def test_maxpool_with_argmax_2d():
|
|
|
|
|
assert (index2.asnumpy() == expect__index_result2).all()
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
|
|
|
|
maxpool2d = Net_Pool()
|
|
|
|
|
maxpool2d2 = Net_Pool2()
|
|
|
|
|
maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
|
|
|
|
|
maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
|
|
|
|
|
output2, index2 = maxpool2d2(x)
|
|
|
|
|
output, index = maxpool2d(x)
|
|
|
|
|
assert (output.asnumpy() == expect_result).all()
|
|
|
|
@ -126,8 +181,8 @@ def test_maxpool_with_argmax_2d_fp16():
|
|
|
|
|
]]]))
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
|
|
|
|
maxpool2d = Net_Pool()
|
|
|
|
|
maxpool2d2 = Net_Pool2()
|
|
|
|
|
maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
|
|
|
|
|
maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
|
|
|
|
|
output2, index2 = maxpool2d2(x)
|
|
|
|
|
output, index = maxpool2d(x)
|
|
|
|
|
assert (output.asnumpy() == expect_result).all()
|
|
|
|
@ -136,12 +191,11 @@ def test_maxpool_with_argmax_2d_fp16():
|
|
|
|
|
assert (index2.asnumpy() == expect__index_result2).all()
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
|
|
|
|
maxpool2d = Net_Pool()
|
|
|
|
|
maxpool2d2 = Net_Pool2()
|
|
|
|
|
maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
|
|
|
|
|
maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
|
|
|
|
|
output2, index2 = maxpool2d2(x)
|
|
|
|
|
output, index = maxpool2d(x)
|
|
|
|
|
assert (output.asnumpy() == expect_result).all()
|
|
|
|
|
assert (output2.asnumpy() == expect_result2).all()
|
|
|
|
|
assert (index.asnumpy() == expect_index_result).all()
|
|
|
|
|
assert (index2.asnumpy() == expect__index_result2).all()
|
|
|
|
|
|