You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
223 lines
7.4 KiB
223 lines
7.4 KiB
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
|
|
import pytest
|
|
import numpy as np
|
|
import mindspore.nn as nn
|
|
import mindspore.common.dtype as mstype
|
|
|
|
from mindspore import Tensor
|
|
from mindspore import context
|
|
from mindspore import ParameterTuple
|
|
from mindspore.nn import Momentum
|
|
from mindspore.nn import WithLossCell
|
|
from mindspore.ops import composite as C
|
|
from mindspore.ops import operations as P
|
|
from mindspore.common.initializer import TruncatedNormal
|
|
|
|
# Module-wide setup: every test in this file runs in PyNative mode with the
# Ascend device target.
context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)

# Shared GradOperation instance (built with get_all=True) used by the
# custom-bprop tests in this module.
grad_all = C.GradOperation(get_all=True)
|
|
|
|
|
|
def weight_variable():
    """Build the initializer used for layer weights in this module.

    Returns:
        A ``TruncatedNormal`` initializer constructed with ``0.02``.
    """
    initializer = TruncatedNormal(0.02)
    return initializer
|
|
|
|
|
|
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Create a bias-free ``nn.Conv2d`` cell that uses the "valid" pad mode.

    Args:
        in_channels: Passed to ``nn.Conv2d`` as its first positional argument.
        out_channels: Passed to ``nn.Conv2d`` as its second positional argument.
        kernel_size: Forwarded as ``kernel_size``.
        stride: Forwarded as ``stride``. Defaults to 1.
        padding: Forwarded as ``padding``. Defaults to 0.

    Returns:
        The constructed ``nn.Conv2d`` cell, with ``weight_init`` taken from
        ``weight_variable()``.
    """
    # NOTE(review): ``padding`` is forwarded together with pad_mode="valid";
    # MindSpore may reject a non-zero padding in that mode -- confirm before
    # calling this helper with padding != 0.
    conv_kwargs = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        weight_init=weight_variable(),
        has_bias=False,
        pad_mode="valid",
    )
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
|
|
|
|
|
def fc_with_initialize(input_channels, out_channels):
    """Create an ``nn.Dense`` cell with explicitly initialised weight and bias.

    Args:
        input_channels: Passed to ``nn.Dense`` as its first positional argument.
        out_channels: Passed to ``nn.Dense`` as its second positional argument.

    Returns:
        The constructed ``nn.Dense`` cell. Its third and fourth positional
        arguments are two separate initializers from ``weight_variable()``.
    """
    return nn.Dense(input_channels, out_channels, weight_variable(), weight_variable())
|
|
|
|
|
|
class test_custom_hook_function_base:
    """Stateless helper that hands a pair of hook callables back to the caller."""

    def __init__(self):
        """Nothing to initialise; instances carry no state."""

    def test_custom_hook_function(self, hook_function, cell_hook_function):
        """Return both arguments unchanged.

        Args:
            hook_function: First value to hand back.
            cell_hook_function: Second value to hand back.

        Returns:
            The tuple ``(hook_function, cell_hook_function)``.
        """
        hooks = (hook_function, cell_hook_function)
        return hooks
|
|
|
|
|
|
def cell_hook_function_print_grad(cell_id, grad_input, grad_output):
    """Cell backward hook that checks the shapes of the first gradients.

    Args:
        cell_id: Not used by this hook.
        grad_input: Indexable; ``grad_input[0].asnumpy()`` must have shape
            ``(32, 16, 10, 10)``.
        grad_output: Indexable; ``grad_output[0].asnumpy()`` must have shape
            ``(32, 6, 14, 14)``.

    Raises:
        AssertionError: If either shape differs from the expected one.
    """
    expected_grad_output_shape = (32, 6, 14, 14)
    expected_grad_input_shape = (32, 16, 10, 10)
    assert grad_output[0].asnumpy().shape == expected_grad_output_shape
    assert grad_input[0].asnumpy().shape == expected_grad_input_shape
|
|
|
|
|
|
def custom_hook_function_print_and_save_grad(grad_out):
    """Backward hook that checks the shape of the first incoming gradient.

    Args:
        grad_out: Indexable; ``grad_out[0].asnumpy()`` must have shape
            ``(32, 6, 28, 28)``.

    Raises:
        AssertionError: If the shape differs from the expected one.
    """
    expected_shape = (32, 6, 28, 28)
    first_grad = grad_out[0].asnumpy()
    assert first_grad.shape == expected_shape
|
|
|
|
|
|
class LeNet5(nn.Cell):
    """LeNet-5 style network instrumented with two backward hooks.

    ``cell_hook_function`` is registered as a backward hook on ``conv1`` and
    ``hook_function`` is wrapped in ``P.HookBackward`` and applied to the
    activation that follows the first ReLU.

    Args:
        hook_function: Callable handed to ``P.HookBackward``.
        cell_hook_function: Callable handed to ``conv1.register_backward_hook``.
        num_class: Output width of the last dense layer. Defaults to 10.
    """

    def __init__(self, hook_function, cell_hook_function, num_class=10):
        super().__init__()
        self.num_class = num_class
        # construct() reshapes with this value, so inputs must use this batch size.
        self.batch_size = 32

        # Convolutional part.
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # NOTE(review): the hook is attached to conv1, while the shapes asserted
        # by this file's cell hook -- (32, 6, 14, 14) and (32, 16, 10, 10) --
        # look like conv2's input/output. Confirm which layer is intended.
        self.conv1.register_backward_hook(cell_hook_function)

        # Fully connected part.
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)

        # Parameter-free operators.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.hook = P.HookBackward(hook_function)

    def construct(self, x):
        """Run the forward pass and return the output of ``fc3``."""
        # First stage: conv -> relu -> backward hook -> pool.
        features = self.hook(self.relu(self.conv1(x)))
        features = self.max_pool2d(features)

        # Second stage: conv -> relu -> pool.
        features = self.max_pool2d(self.relu(self.conv2(features)))

        # Flatten to (batch_size, -1) before the dense layers.
        flat = self.reshape(features, (self.batch_size, -1))
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)
|
|
|
|
|
|
class GradWrap(nn.Cell):
    """Cell that applies ``C.GradOperation(get_by_list=True)`` to a network.

    Args:
        network: The cell to differentiate. Its parameters with
            ``requires_grad`` set are collected into ``self.weights``.
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.network = network
        trainable = [param for param in network.get_parameters() if param.requires_grad]
        self.weights = ParameterTuple(trainable)

    def construct(self, x, label):
        """Return the result of the gradient function called with ``(x, label)``."""
        grad_fn = C.GradOperation(get_by_list=True)(self.network, self.weights)
        return grad_fn(x, label)
|
|
|
|
|
|
class test_custom_cell_base:
    """Stateless helper that hands a cell back to the caller unchanged."""

    def __init__(self):
        """Nothing to initialise; instances carry no state."""

    def test_custom_cell_function(self, cell):
        """Return ``cell`` itself, without copying or modifying it.

        Args:
            cell: Any object.

        Returns:
            The same object that was passed in.
        """
        same_cell = cell
        return same_cell
|
|
|
|
|
|
class MulAdd(nn.Cell):
    """Cell computing ``2 * x + y`` with a hand-written ``bprop``."""

    def construct(self, x, y):
        """Return ``2 * x + y``."""
        doubled = 2 * x
        return doubled + y

    def bprop(self, x, y, out, dout):
        """Check the values seen during backprop and return ``(dout, y)``.

        The asserted constants match the inputs used by this module's MulAdd
        test: x=1 and y=2, which give out=4, with dout=1.

        Raises:
            AssertionError: If any of the four values differs.
        """
        # Forward inputs.
        assert x.asnumpy() == 1.0
        assert y.asnumpy() == 2.0
        # Forward output and incoming gradient.
        assert out.asnumpy() == 4.0
        assert dout.asnumpy() == 1.0
        return dout, y
|
|
|
|
class Ms_Cell(nn.Cell):
    """ReLU cell whose custom ``bprop`` always returns a scalar zero tensor."""

    def __init__(self):
        super().__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        """Apply ReLU to ``x``."""
        activated = self.relu(x)
        return activated

    def bprop(self, x, out, dout):
        """Discard the incoming ``dout`` and return a float32 scalar zero tensor."""
        zero_grad = Tensor(np.float32(0.0))
        assert zero_grad.shape == ()
        return zero_grad
|
|
|
|
class Ms_Cell_Change_Shape(nn.Cell):
    """ReLU cell whose custom ``bprop`` returns a fixed 5x5 tensor of ones."""

    def __init__(self):
        super().__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        """Apply ReLU to ``x``."""
        activated = self.relu(x)
        return activated

    def bprop(self, x, out, dout):
        """Discard the incoming ``dout`` and return a float32 ``(5, 5)`` ones tensor."""
        reshaped_grad = Tensor(np.ones([5, 5]).astype(np.float32))
        assert reshaped_grad.shape == (5, 5)
        return reshaped_grad
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_lenet_train_hook_function_print_and_save_grad():
    """Run one LeNet5 gradient/optimizer step with both backward hooks installed.

    The hooks carry their own shape assertions; this test only asserts that
    the optimizer call returns a truthy value.
    """
    hook_fn, cell_hook_fn = test_custom_hook_function_base().test_custom_hook_function(
        custom_hook_function_print_and_save_grad,
        cell_hook_function_print_grad,
    )
    net = LeNet5(hook_function=hook_fn, cell_hook_function=cell_hook_fn)

    trainable_params = filter(lambda param: param.requires_grad, net.get_parameters())
    optimizer = Momentum(trainable_params, 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    train_network = GradWrap(WithLossCell(net, criterion))
    train_network.set_train()

    batch = net.batch_size
    input_data = Tensor(np.ones([batch, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([batch, net.num_class]).astype(np.float32))

    # Forward pass and loss on their own; the loss value is not inspected.
    output = net(Tensor(input_data))
    criterion(output, label)

    # Gradient computation followed by one optimizer update.
    grads = train_network(input_data, label)
    assert optimizer(grads)
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_MulAdd():
    """Check that ``grad_all`` on ``MulAdd`` returns what its custom bprop returns."""
    mul_add = test_custom_cell_base().test_custom_cell_function(MulAdd())
    mul_add.bprop_debug = True

    # First invocation: result is intentionally discarded.
    grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32))

    # Second invocation: compare against the (dout, y) pair from MulAdd.bprop.
    actual = grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
    expected = (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32))
    assert actual == expected
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_Ms_Cell_Change_Shape():
    """Expect a ``RuntimeError`` when the custom bprop returns a (5, 5) gradient."""
    cell = test_custom_cell_base().test_custom_cell_function(Ms_Cell_Change_Shape())
    cell.bprop_debug = True

    with pytest.raises(RuntimeError) as exc_info:
        grad_all(cell)(Tensor(1, mstype.float32))

    # Checked after the context manager exits, once the exception is captured.
    expected_fragment = "Shapes of input and parameter are different, input index"
    assert expected_fragment in str(exc_info.value)
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_Ms_Cell():
    """Check that ``grad_all`` on ``Ms_Cell`` yields a single zero gradient."""
    cell = test_custom_cell_base().test_custom_cell_function(Ms_Cell())
    cell.bprop_debug = True

    actual = grad_all(cell)(Tensor(1, mstype.float32))
    expected = (Tensor(0.0, mstype.float32),)
    assert actual == expected
|
|
|