!2943 [quant] export bug fix

Merge pull request !2943 from vlne-v1/quant_export_bugfix
Committed by mindspore-ci-bot via Gitee
commit d3ec05d716

@@ -328,6 +328,9 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
      x = cnode->input(1);
      count += 1;
    }
+   if (x->isa<Parameter>()) {
+     fake_quant_table[weight_name] = std::make_pair(nullptr, "input");
+   }
    // get the fakequant parameter minq's name
    if (!is_quant_cnode(x)) {
      continue;
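The new branch covers layers whose activation input is a bare Parameter (a network input) rather than a fake-quant node: there are no per-layer minq/maxq parameters to look up, so the table records the sentinel pair (nullptr, "input"). A minimal Python sketch of how that sentinel is meant to be consumed (names assumed for illustration; the real consumer is the ExportToQuantInferNetwork hunk below):

# Sketch only: a `fake_quant_table` entry as produced by the C++ fix above.
fake_quant_table = {"conv1.w_minq": (None, "input")}

def resolve_input_quant(table, key, input_scale=0.0078, input_zero_point=128):
    _op, minq_name = table[key]
    if minq_name == "input":
        # no per-layer minq/maxq exist; fall back to the network-level
        # input scale/zero-point
        return input_scale, input_zero_point
    raise KeyError("expected per-layer parameters for " + minq_name)

scale_a_in, zp_a_in = resolve_input_quant(fake_quant_table, "conv1.w_minq")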

@@ -1169,9 +1169,9 @@ class QuantBlock(Cell):
         return x

     def extend_repr(self):
-        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
+        str_info = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
         if self.has_bias:
-            str_info = str_info + f', bias={self.bias}'
+            str_info = str_info + f', bias=shape[{self.bias.shape}]'
         if self.has_act:
             str_info = str_info + f', activation={self.activation}'
         str_info = str_info + f', dequant={self.dequant}'

@@ -237,12 +237,14 @@ class PrimitiveWithInfer(Primitive):
         """
         Infer output shape based on input shape.

-        Args:
-            inputs (tuple(int)): dimensions of input tensors.
-            outputs (tuple(int)): dimensions of output tensors.
-
         Note:
             The shape of scalar is an empty tuple.
+
+        Args:
+            args (tuple(int)): shapes of input tensors.
+
+        Return:
+            `tuple(int)`, shapes of output tensors.
         """
         return None
@@ -251,8 +253,10 @@ class PrimitiveWithInfer(Primitive):
         Infer output dtype based on input dtype.

         Args:
-            inputs (mstype): data type of inputs.
-            outputs (mstype): data type of outputs.
+            args (:class:`mindspore.dtype`): data type of inputs.
+
+        Return:
+            :class:`mindspore.dtype`, data type of outputs.
         """
         return None
@@ -261,8 +265,10 @@ class PrimitiveWithInfer(Primitive):
         Infer output value based on input value at compile time.

         Args:
-            inputs (any): value of inputs.
-            outputs (any): value of outputs.
+            args (Any): value of inputs.
+
+        Return:
+            Value of outputs. Return `None` if the value cannot be inferred at compile time.
         """
         return None
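To make the revised contract concrete, here is a minimal sketch of a primitive implementing all three infer methods (the op itself is hypothetical and for illustration only):

from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register


class DoubleIt(PrimitiveWithInfer):
    """Hypothetical element-wise op, used only to illustrate the infer contract."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, x_shape):
        # element-wise: output shape equals the input shape (tuple(int))
        return x_shape

    def infer_dtype(self, x_dtype):
        # a mindspore.dtype in, a mindspore.dtype out
        return x_dtype

    def infer_value(self, x_value):
        # `None` means the value cannot be inferred at compile time
        if x_value is None:
            return None
        return x_value * 2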

@@ -318,9 +318,12 @@ class ExportToQuantInferNetwork:
         info = self.quant_info_table.get(w_minq_name, None)
         if info:
             fack_quant_a_in_op, minq_name = info
-            maxq = self.all_parameters[minq_name[:-4] + "maxq"]
-            minq = self.all_parameters[minq_name]
-            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
+            if minq_name == 'input':
+                scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
+            else:
+                maxq = self.all_parameters[minq_name[:-4] + "maxq"]
+                minq = self.all_parameters[minq_name]
+                scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
         else:
             logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None
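For context on what scale_zp_from_data returns in the non-input branch: a typical asymmetric 8-bit scheme derives both values from the recorded min/max range, as in the sketch below. This is illustrative only; the actual computation lives in quant_utils and may differ (symmetric mode, narrow range, etc.):

import numpy as np

def scale_zp_from_minmax(minq, maxq, num_bits=8):
    # affine quantization: real_value = scale * (quantized - zero_point)
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (maxq - minq) / (qmax - qmin)
    zero_point = int(np.round(qmin - minq / scale))
    return scale, zero_point

# an activation observed in [-1, 1] gives scale = 2/255 and zero_point = 128
scale, zp = scale_zp_from_minmax(-1.0, 1.0)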

@@ -104,19 +104,20 @@ def weight2int(data, scale, zero_point):
         raise ValueError("`scale` and `zero_point` should have the same shape.")
     if scale.shape[0] < 0:
         raise ValueError("`scale` and `zero_point` shape should greater than zero.")
-    if scale.shape[0] == data.shape[0]:
-        # `Conv2d` or `Dense` op weight
-        shape_list = [-1] + [1] * len(data.shape[1:])
-        scale = scale.reshape(shape_list)
-        zero_point = zero_point.reshape(shape_list)
-    elif scale.shape[0] == data.shape[1]:
-        # `DepthwiseConv2d` op weight
-        shape_list = [1, -1] + [1] * len(data.shape[2:])
-        scale = scale.reshape(shape_list)
-        zero_point = zero_point.reshape(shape_list)
-    else:
-        raise ValueError("Unsupported weight shape({})".format(data.shape))
+    if len(scale.shape) > 1:
+        # for perchannel
+        if scale.shape[0] == data.shape[0]:
+            # `Conv2d` or `Dense` op weight
+            shape_list = [-1] + [1] * len(data.shape[1:])
+            scale = scale.reshape(shape_list)
+            zero_point = zero_point.reshape(shape_list)
+        elif scale.shape[0] == data.shape[1]:
+            # `DepthwiseConv2d` op weight
+            shape_list = [1, -1] + [1] * len(data.shape[2:])
+            scale = scale.reshape(shape_list)
+            zero_point = zero_point.reshape(shape_list)
+        else:
+            raise ValueError("Unsupported weight shape({})".format(data.shape))

     return np.round((data / scale) + zero_point)
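The restored len(scale.shape) > 1 guard separates per-layer scales (a single element, which broadcasts as-is) from per-channel ones, which must be reshaped so NumPy broadcasting lines up with the weight layout. A short sketch of the `Conv2d`/`Dense` branch (shapes are assumptions for illustration):

import numpy as np

data = np.random.randn(16, 8, 3, 3).astype(np.float32)  # Conv2d weight, out-channels first
# per-channel minq/maxq parameters typically carry a shape like (16, 1, 1, 1),
# which is why rank > 1 signals the per-channel case
scale = (np.abs(data).reshape(16, -1).max(axis=1) / 127.0).reshape(16, 1, 1, 1)
zero_point = np.zeros_like(scale)

shape_list = [-1] + [1] * len(data.shape[1:])  # -> (16, 1, 1, 1)
quantized = np.round(data / scale.reshape(shape_list) + zero_point.reshape(shape_list))
assert quantized.shape == data.shape  # every channel divided by its own scale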

@@ -1,115 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2"""
from mindspore import nn
from mindspore.ops import operations as P


def make_divisible(input_x, div_by=8):
    return int((input_x + div_by) // div_by)


def _conv_bn(in_channel,
             out_channel,
             ksize,
             stride=1):
    """Get a conv2d batchnorm and relu layer."""
    return nn.SequentialCell(
        [nn.Conv2d(in_channel,
                   out_channel,
                   kernel_size=ksize,
                   stride=stride),
         nn.BatchNorm2d(out_channel)])


class InvertedResidual(nn.Cell):
    def __init__(self, inp, oup, stride, expend_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expend_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        if expend_ratio == 1:
            self.conv = nn.SequentialCell([
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(),
                nn.Conv2d(hidden_dim, oup, 1, 1),
                nn.BatchNorm2d(oup)
            ])
        else:
            self.conv = nn.SequentialCell([
                nn.Conv2d(inp, hidden_dim, 1, 1),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(),
                nn.Conv2d(hidden_dim, oup, 1, 1),
                nn.BatchNorm2d(oup)
            ])

    def construct(self, input_x):
        out = self.conv(input_x)
        if self.use_res_connect:
            out = input_x + out
        return out


class MobileNetV2(nn.Cell):
    def __init__(self, num_class=1000, input_size=224, width_mul=1.):
        super(MobileNetV2, self).__init__()
        _ = input_size
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 230, 1, 1],
        ]
        if width_mul > 1.0:
            last_channel = make_divisible(last_channel * width_mul)
        self.last_channel = last_channel
        features = [_conv_bn(3, input_channel, 3, 2)]
        for t, c, n, s in inverted_residual_setting:
            out_channel = make_divisible(c * width_mul) if t > 1 else c
            for i in range(n):
                if i == 0:
                    features.append(block(input_channel, out_channel, s, t))
                else:
                    features.append(block(input_channel, out_channel, 1, t))
                input_channel = out_channel
        features.append(_conv_bn(input_channel, self.last_channel, 1))
        self.features = nn.SequentialCell(features)
        self.mean = P.ReduceMean(keep_dims=False)
        self.classifier = nn.Dense(self.last_channel, num_class)

    def construct(self, input_x):
        out = input_x
        out = self.features(out)
        out = self.mean(out, (2, 3))
        out = self.classifier(out)
        return out

@@ -1,122 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mobile net v2"""
from mindspore import nn
from mindspore.ops import operations as P


def make_divisible(input_x, div_by=8):
    return int((input_x + div_by) // div_by)


def _conv_bn(in_channel,
             out_channel,
             ksize,
             stride=1):
    """Get a conv2d batchnorm and relu layer."""
    return nn.SequentialCell(
        [nn.Conv2dBnAct(in_channel,
                        out_channel,
                        kernel_size=ksize,
                        stride=stride,
                        has_bn=True)])


class InvertedResidual(nn.Cell):
    def __init__(self, inp, oup, stride, expend_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expend_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        if expend_ratio == 1:
            self.conv = nn.SequentialCell([
                nn.Conv2dBnAct(hidden_dim,
                               hidden_dim,
                               3,
                               stride,
                               group=hidden_dim,
                               has_bn=True,
                               activation='relu6'),
                nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
                               has_bn=True)
            ])
        else:
            self.conv = nn.SequentialCell([
                nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
                               has_bn=True,
                               activation='relu6'),
                nn.Conv2dBnAct(hidden_dim,
                               hidden_dim,
                               3,
                               stride,
                               group=hidden_dim,
                               has_bn=True,
                               activation='relu6'),
                nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
                               has_bn=True)
            ])
        self.add = P.TensorAdd()

    def construct(self, input_x):
        out = self.conv(input_x)
        if self.use_res_connect:
            out = self.add(input_x, out)
        return out


class MobileNetV2(nn.Cell):
    def __init__(self, num_class=1000, input_size=224, width_mul=1.):
        super(MobileNetV2, self).__init__()
        _ = input_size
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 230, 1, 1],
        ]
        if width_mul > 1.0:
            last_channel = make_divisible(last_channel * width_mul)
        self.last_channel = last_channel
        features = [_conv_bn(3, input_channel, 3, 2)]
        for t, c, n, s in inverted_residual_setting:
            out_channel = make_divisible(c * width_mul) if t > 1 else c
            for i in range(n):
                if i == 0:
                    features.append(block(input_channel, out_channel, s, t))
                else:
                    features.append(block(input_channel, out_channel, 1, t))
                input_channel = out_channel
        features.append(_conv_bn(input_channel, self.last_channel, 1))
        self.features = nn.SequentialCell(features)
        self.mean = P.ReduceMean(keep_dims=False)
        self.classifier = nn.DenseBnAct(self.last_channel, num_class)

    def construct(self, input_x):
        out = input_x
        out = self.features(out)
        out = self.mean(out, (2, 3))
        out = self.classifier(out)
        return out

@@ -20,7 +20,7 @@ import mindspore.context as context
 from mindspore import Tensor
 from mindspore import nn
 from mindspore.train.quant import quant as qat
-from mobilenetv2_combined import MobileNetV2
+from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2

 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
         self.num_class = num_class
-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu', pad_mode="valid")
         self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
@@ -67,20 +67,19 @@ def test_qat_lenet():
     img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
     net = LeNet5()
     net = qat.convert_quant_network(
-        net, freeze_bn=10000, num_bits=8)
+        net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
     # should load the checkpoint. mock here
     for param in net.get_parameters():
         param.init_data()
-    qat.export_geir(net, img, file_name="quant.pb")
+    qat.export(net, img, file_name="quant.pb")


 @pytest.mark.skip(reason="no `te.lang.cce` in ut env")
 def test_qat_mobile():
-    net = MobileNetV2()
+    network = mobilenetV2(num_classes=1000)
     img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
-    net = qat.convert_quant_network(
-        net, quant_delay=0, bn_fold=True, freeze_bn=10000, num_bits=8)
+    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
     # should load the checkpoint. mock here
-    for param in net.get_parameters():
+    for param in network.get_parameters():
         param.init_data()
-    qat.export_geir(net, img, file_name="quant.pb")
+    qat.export(network, img, file_name="quant.pb")
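Taken together, the tests now exercise the renamed export entry point and the new converter flags. The end-to-end pattern, mirroring the diff (reading the two-element per_channel/symmetric lists as [weights, activations] is an assumption, not stated in this PR):

import numpy as np
from mindspore import Tensor
from mindspore.train.quant import quant as qat
from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2

network = mobilenetV2(num_classes=1000)
# bn_fold folds batchnorm into the preceding conv during quantization;
# the lists are assumed to configure weights and activations independently
network = qat.convert_quant_network(network, bn_fold=True,
                                    per_channel=[True, False],
                                    symmetric=[True, False])
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
qat.export(network, img, file_name="quant.pb")  # replaces the old export_geir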
