From 2c6a9c8486004d2888d5291dcfe1e5f5f65e555a Mon Sep 17 00:00:00 2001 From: jzg Date: Mon, 12 Oct 2020 19:39:53 +0800 Subject: [PATCH] add fake-quant operators. --- mindspore/ops/_grad/grad_quant_ops.py | 25 ++ mindspore/ops/_op_impl/tbe/__init__.py | 4 + .../tbe/fake_quant_with_min_max_vars.py | 39 +++ .../fake_quant_with_min_max_vars_gradient.py | 43 ++++ ...ake_quant_with_min_max_vars_per_channel.py | 39 +++ ..._with_min_max_vars_per_channel_gradient.py | 43 ++++ mindspore/ops/operations/_quant_ops.py | 235 ++++++++++++++++++ tests/ut/python/ops/test_ops.py | 13 + 8 files changed, 441 insertions(+) create mode 100644 mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py create mode 100644 mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py create mode 100644 mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py create mode 100644 mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py index a2b0ba8d97..2c0ad42569 100644 --- a/mindspore/ops/_grad/grad_quant_ops.py +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -34,6 +34,31 @@ def get_bprop_fakequant_with_minmax(self): return bprop +@bprop_getters.register(Q.FakeQuantWithMinMaxVars) +def get_bprop_fakequant_with_minmax_vars(self): + """Generate bprop for FakeQuantWithMinMaxVars for Ascend""" + op = Q.FakeQuantWithMinMaxVarsGradient( + num_bits=self.num_bits, narrow_range=self.narrow_range) + + def bprop(x, x_min, x_max, out, dout): + dx = op(dout, x, x_min, x_max) + return dx, zeros_like(x_min), zeros_like(x_max) + + return bprop + + +@bprop_getters.register(Q.FakeQuantWithMinMaxVarsPerChannel) +def get_bprop_fakequant_with_minmax_vars_perchannel(self): + """Generate bprop for FakeQuantWithMinMaxVarsPerChannel for Ascend""" + op = Q.FakeQuantWithMinMaxVarsPerChannelGradient( + num_bits=self.num_bits, narrow_range=self.narrow_range) + + def bprop(x, x_min, x_max, out, dout): + dx = op(dout, x, x_min, x_max) + return dx, zeros_like(x_min), zeros_like(x_max) + + return bprop + @bprop_getters.register(Q.FakeQuantPerChannel) def get_bprop_fakequant_with_minmax_perchannel(self): diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 5b3c1fec2f..9b207f1b89 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -321,3 +321,7 @@ from .parallel_concat import _parallel_concat_tbe from .adam_apply_one_assign import _adam_apply_one_assign_tbe from .adam_apply_one_with_decay_assign import _adam_apply_one_with_decay_assign_tbe from .ifmr import _ifmr_tbe +from .fake_quant_with_min_max_vars import _fake_quant_with_min_max_vars_tbe +from .fake_quant_with_min_max_vars_gradient import _fake_quant_with_min_max_vars_gradient_tbe +from .fake_quant_with_min_max_vars_per_channel import _fake_quant_with_min_max_vars_per_channel_tbe +from .fake_quant_with_min_max_vars_per_channel_gradient import _fake_quant_with_min_max_vars_per_channel_gradient_tbe diff --git a/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py b/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py new file mode 100644 index 0000000000..847631e465 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FakeQuantWithMinMaxVars op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +fake_quant_with_min_max_vars_op_info = TBERegOp("FakeQuantWithMinMaxVars") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_quant_with_min_max_vars.so") \ + .compute_cost(10) \ + .kernel_name("fake_quant_with_min_max_vars") \ + .partial_flag(True) \ + .attr("num_bits", "optional", "int", "all") \ + .attr("narrow_range", "optional", "bool", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "min", False, "required", "all") \ + .input(2, "max", False, "required", "all") \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(fake_quant_with_min_max_vars_op_info) +def _fake_quant_with_min_max_vars_tbe(): + """FakeQuantWithMinMaxVar TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py b/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py new file mode 100644 index 0000000000..582a48789b --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""FakeQuantWithMinMaxVars op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +fake_quant_with_min_max_vars_gradient_op_info = TBERegOp("FakeQuantWithMinMaxVarsGradient") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_quant_with_min_max_vars_gradient.so") \ + .compute_cost(10) \ + .kernel_name("fake_quant_with_min_max_vars_gradient") \ + .partial_flag(True) \ + .attr("num_bits", "optional", "int", "all") \ + .attr("narrow_range", "optional", "bool", "all") \ + .input(0, "gradients", False, "required", "all") \ + .input(1, "x", False, "required", "all") \ + .input(2, "min", False, "required", "all") \ + .input(3, "max", False, "required", "all") \ + .output(0, "backprops_wrt_x", True, "required", "all") \ + .output(1, "backprops_wrt_min", True, "required", "all") \ + .output(2, "backprops_wrt_max", True, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(fake_quant_with_min_max_vars_gradient_op_info) +def _fake_quant_with_min_max_vars_gradient_tbe(): + """FakeQuantWithMinMaxVarsGradient TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py b/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py new file mode 100644 index 0000000000..957c7fb634 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""FakeQuantWithMinMaxVarsPerChannel op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +fake_quant_with_min_max_vars_per_channel_op_info = TBERegOp("FakeQuantWithMinMaxVarsPerChannel") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_quant_with_min_max_vars_per_channel.so") \ + .compute_cost(10) \ + .kernel_name("fake_quant_with_min_max_vars_per_channel") \ + .partial_flag(True) \ + .attr("num_bits", "optional", "int", "all") \ + .attr("narrow_range", "optional", "bool", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "min", False, "required", "all") \ + .input(2, "max", False, "required", "all") \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(fake_quant_with_min_max_vars_per_channel_op_info) +def _fake_quant_with_min_max_vars_per_channel_tbe(): + """FakeQuantWithMinMaxVarsPerChannel TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py b/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py new file mode 100644 index 0000000000..fc2438af94 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+"""FakeQuantWithMinMaxVarsPerChannelGradient op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+fake_quant_with_min_max_vars_per_channel_gradient_op_info = TBERegOp("FakeQuantWithMinMaxVarsPerChannelGradient") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("fake_quant_with_min_max_vars_per_channel_gradient.so") \
+    .compute_cost(10) \
+    .kernel_name("fake_quant_with_min_max_vars_per_channel_gradient") \
+    .partial_flag(True) \
+    .attr("num_bits", "optional", "int", "all") \
+    .attr("narrow_range", "optional", "bool", "all") \
+    .input(0, "gradients", False, "required", "all") \
+    .input(1, "x", False, "required", "all") \
+    .input(2, "min", False, "required", "all") \
+    .input(3, "max", False, "required", "all") \
+    .output(0, "backprops_wrt_x", True, "required", "all") \
+    .output(1, "backprops_wrt_min", True, "required", "all") \
+    .output(2, "backprops_wrt_max", True, "required", "all") \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(fake_quant_with_min_max_vars_per_channel_gradient_op_info)
+def _fake_quant_with_min_max_vars_per_channel_gradient_tbe():
+    """FakeQuantWithMinMaxVarsPerChannelGradient TBE register"""
+    return
diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py
index 863e214eec..de31a285e7 100644
--- a/mindspore/ops/operations/_quant_ops.py
+++ b/mindspore/ops/operations/_quant_ops.py
@@ -23,6 +23,10 @@ from ...common import dtype as mstype
 
 __all__ = ["MinMaxUpdatePerLayer",
            "MinMaxUpdatePerChannel",
+           "FakeQuantWithMinMaxVars",
+           "FakeQuantWithMinMaxVarsGradient",
+           "FakeQuantWithMinMaxVarsPerChannel",
+           "FakeQuantWithMinMaxVarsPerChannelGradient",
            "FakeQuantPerLayer",
            "FakeQuantPerLayerGrad",
            "FakeQuantPerChannel",
@@ -165,6 +169,237 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
         return min_type, max_type
 
 
+class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
+    r"""
+    Fake-quantize the input by the min and max values.
+
+    Args:
+        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
+        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
+            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
+            range is [0, 2^num_bits-1]. Default: False.
+
+    Inputs:
+        - **x** (Tensor) - Float32 tensor holding the data to be quantized.
+        - **min** (Tensor) - Minimum value of the quantization range for the input data x.
+        - **max** (Tensor) - Maximum value of the quantization range for the input data x.
+
+    Outputs:
+        - Tensor, with the same data type and shape as the input x.
+
+    Examples:
+        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
+        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
+        >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
+        >>>     input_tensor, min_tensor, max_tensor)
+        >>> output_tensor shape: (3, 16, 5, 5)  data type: mstype.float32
+    """
+    @prim_attr_register
+    def __init__(self,
+                 num_bits=8,
+                 narrow_range=False):
+        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
+        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
+        self.narrow_range = validator.check_value_type(
+            'narrow_range', narrow_range, (bool,), self.name)
+
+    def check_broadcast(self, min_shape, input_shape):
+        shape_val = 1
+        for shape in input_shape:
+            shape_val = shape_val * shape
+        if min_shape[0] > 1 and min_shape[0] != shape_val:
+            raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
+
+    def infer_shape(self, x_shape, min_shape, max_shape):
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
+        self.check_broadcast(min_shape, x_shape)
+        return x_shape
+
+    def infer_dtype(self, x_type, min_type, max_type):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
+        return x_type
+
+
+class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
+    r"""
+    Performs the gradient of the FakeQuantWithMinMaxVars operation.
+
+    Args:
+        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
+        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
+            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
+            range is [0, 2^num_bits-1]. Default: False.
+
+    Inputs:
+        - **gradients** (Tensor) - The gradient with respect to the output of FakeQuantWithMinMaxVars.
+        - **x** (Tensor) - Float32 tensor holding the data that was quantized.
+        - **min** (Tensor) - Minimum value of the quantization range for the input data x.
+        - **max** (Tensor) - Maximum value of the quantization range for the input data x.
+
+    Outputs:
+        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
+        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
+        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.
+
+    Examples:
+        >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
+        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
+        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8, narrow_range=False)(
+        >>>     gradients, input_tensor, min_tensor, max_tensor)
+        >>> x_gradient shape: (3, 16, 5, 5)  data type: mstype.float32
+        >>> min_gradient shape: (1,)  data type: mstype.float32
+        >>> max_gradient shape: (1,)  data type: mstype.float32
+    """
+    @prim_attr_register
+    def __init__(self,
+                 num_bits=8,
+                 narrow_range=False):
+        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
+        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
+        self.narrow_range = validator.check_value_type(
+            'narrow_range', narrow_range, (bool,), self.name)
+
+    def check_broadcast(self, min_shape, input_shape):
+        shape_val = 1
+        for shape in input_shape:
+            shape_val = shape_val * shape
+        if min_shape[0] > 1 and min_shape[0] != shape_val:
+            raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
+
+    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
+        self.check_broadcast(min_shape, x_shape)
+        return x_shape, min_shape, max_shape
+
+    def infer_dtype(self, dout_type, x_type, min_type, max_type):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
+        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
+        return x_type, min_type, max_type
+
+
+class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
+    r"""
+    Fake-quantize the input, whose shape is one of [d], [b, d] or [b, h, w, d], by per-channel min and max values.
+
+    Args:
+        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
+        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
+            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
+            range is [0, 2^num_bits-1]. Default: False.
+
+    Inputs:
+        - **x** (Tensor) - Float32 tensor holding the data to be quantized.
+        - **min** (Tensor) - Per-channel minimum values of the quantization range for the input data x;
+          its length must match the last dimension of x.
+        - **max** (Tensor) - Per-channel maximum values of the quantization range for the input data x;
+          its length must match the last dimension of x.
+
+    Outputs:
+        - Tensor, with the same data type and shape as the input x.
+
+    Examples:
+        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
+        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
+        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
+        >>> output_tensor = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)(
+        >>>     input_tensor, min_tensor, max_tensor)
+        >>> output_tensor shape: (3, 16, 3, 4)  data type: mstype.float32
+    """
+    @prim_attr_register
+    def __init__(self,
+                 num_bits=8,
+                 narrow_range=False):
+        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
+        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
+        self.narrow_range = validator.check_value_type(
+            'narrow_range', narrow_range, (bool,), self.name)
+
+    def infer_shape(self, x_shape, min_shape, max_shape):
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
+        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_type, min_type, max_type):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
+        return x_type
+
+
+class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
+    r"""
+    Performs the gradient of the FakeQuantWithMinMaxVarsPerChannel operation.
+
+    Args:
+        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
+        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
+            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
+            range is [0, 2^num_bits-1]. Default: False.
+
+    Inputs:
+        - **gradients** (Tensor) - The gradient with respect to the output of FakeQuantWithMinMaxVarsPerChannel.
+        - **x** (Tensor) - Float32 tensor holding the data that was quantized.
+        - **min** (Tensor) - Per-channel minimum values of the quantization range for the input data x.
+        - **max** (Tensor) - Per-channel maximum values of the quantization range for the input data x.
+
+    Outputs:
+        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
+        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
+        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.
+ + Examples: + >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32) + >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32) + >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32) + >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32) + >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient( + >>> num_bits=8, narrow_range=False)( + >>> gradients, input_tensor, min_tensor, max_tensor) + >>> x_gradient shape: (3, 16, 3, 4) data type: mstype.float32 + >>> min_gradient shape: (4,) data type: mstype.float32 + >>> max_gradient shape: (4,) data type: mstype.float32 + """ + @prim_attr_register + def __init__(self, + num_bits=8, + narrow_range=False): + self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) + self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name) + self.narrow_range = validator.check_value_type( + 'narrow_range', narrow_range, (bool,), self.name) + + def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): + validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) + validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) + validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name) + validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name) + return x_shape, min_shape, max_shape + + def infer_dtype(self, dout_type, x_type, min_type, max_type): + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) + validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) + validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) + validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) + return x_type, min_type, max_type + + class FakeQuantPerLayer(PrimitiveWithInfer): r""" Simulates the quantize and dequantize operations in training time. 
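Note: the NumPy sketch below illustrates the per-layer fake-quantization that FakeQuantWithMinMaxVars is generally understood to perform (TensorFlow-style quantize/dequantize with a nudged min/max so that zero stays exactly representable). It is not part of this patch; the Ascend TBE kernels registered above are the authoritative implementation, and the function name here is purely illustrative.

import numpy as np

def fake_quant_with_min_max_vars_ref(x, x_min, x_max, num_bits=8, narrow_range=False):
    """Reference sketch: quantize x to num_bits levels inside [x_min, x_max], then dequantize."""
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    # Step size of one quantization level.
    scale = (x_max - x_min) / (quant_max - quant_min)
    # Nudge the range so that zero maps exactly onto an integer level.
    zero_point = np.clip(np.round(quant_min - x_min / scale), quant_min, quant_max)
    nudged_min = (quant_min - zero_point) * scale
    nudged_max = (quant_max - zero_point) * scale
    # Clamp, quantize and dequantize.
    clipped = np.clip(x, nudged_min, nudged_max)
    return np.round((clipped - nudged_min) / scale) * scale + nudged_min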
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index f39334f63e..0ba7e9f8f8 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -26,6 +26,7 @@ from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _inner_ops as inner +from mindspore.ops.operations._quant_ops import FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsPerChannel from ..ut_filter import non_graph_engine from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.pipeline.forward.compile_forward \ @@ -1029,6 +1030,18 @@ test_case_math_ops = [ 'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])), [2, 3], [2, 3]], 'desc_bprop': [[2, 3]]}), + ('FakeQuantWithMinMaxVars', { + 'block': FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False), + 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32), + Tensor(np.array([-6]), mstype.float32), + Tensor(np.array([6]), mstype.float32)], + 'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)]}), + ('FakeQuantWithMinMaxVarsPerChannel', { + 'block': FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False), + 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32), + Tensor(np.array([-6, -1, -2, -3]), mstype.float32), + Tensor(np.array([6, 1, 2, 3]), mstype.float32)], + 'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)]}), ('Rank', { 'block': P.Rank(), 'desc_inputs': [[2, 3]],
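A minimal usage sketch (not included in this patch) of how the new per-channel primitive is expected to be invoked once this change is merged. It mirrors the docstring example and the test case above, imports the primitive from the internal _quant_ops module as the test does, and assumes an Ascend environment is available.

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops.operations._quant_ops import FakeQuantWithMinMaxVarsPerChannel

# PyNative mode lets the primitive be called directly outside a Cell.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

fake_quant = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)
x = Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)
# One min/max pair per channel along the last axis (4 channels here).
channel_min = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
channel_max = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
y = fake_quant(x, channel_min, channel_max)  # same shape and dtype as x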