!8280 add export lenet grad example

Merge pull request !8280 from changzherui/add_export_grad_example
pull/8280/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c253d874e0

@ -0,0 +1,97 @@
import os
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.train.serialization import export
def weight_variable():
return TruncatedNormal(0.02)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class WithLossCell(nn.Cell):
def __init__(self, network):
super(WithLossCell, self).__init__(auto_prefix=False)
self.loss = nn.SoftmaxCrossEntropyWithLogits()
self.network = network
def construct(self, x, label):
predict = self.network(x)
return self.loss(predict, label)
class TrainOneStepCell(nn.Cell):
def __init__(self, network):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True)
def construct(self, x, label):
weights = self.weights
grads = self.grad(self.network, weights)(x, label)
return self.optimizer(grads)
def test_export_lenet_grad_mindir():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
network = LeNet5()
network.set_train()
predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([32, 10]).astype(np.float32))
net = TrainOneStepCell(WithLossCell(network))
file_name = "lenet_grad.mindir"
export(net, predict, label, file_name=file_name, file_format='MINDIR')
assert os.path.exists(file_name)
os.remove(file_name)
Loading…
Cancel
Save