From 37ee246c705178b98780bbd33649c768e5dd9c75 Mon Sep 17 00:00:00 2001 From: liangchenghui Date: Wed, 4 Nov 2020 19:27:49 +0800 Subject: [PATCH] Add quant ops. --- mindspore/ops/_grad/grad_quant_ops.py | 24 +++ mindspore/ops/_op_impl/tbe/__init__.py | 5 + .../_op_impl/tbe/act_ulq_clamp_max_grad.py | 38 ++++ .../_op_impl/tbe/act_ulq_clamp_min_grad.py | 38 ++++ mindspore/ops/_op_impl/tbe/acts_ulq.py | 45 ++++ .../ops/_op_impl/tbe/acts_ulq_input_grad.py | 38 ++++ mindspore/ops/_op_impl/tbe/wts_arq.py | 40 ++++ mindspore/ops/operations/_quant_ops.py | 202 ++++++++++++++++++ 8 files changed, 430 insertions(+) create mode 100644 mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py create mode 100644 mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py create mode 100644 mindspore/ops/_op_impl/tbe/acts_ulq.py create mode 100644 mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py create mode 100644 mindspore/ops/_op_impl/tbe/wts_arq.py diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py index 2c0ad42569..9abbd4119d 100644 --- a/mindspore/ops/_grad/grad_quant_ops.py +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -176,3 +176,27 @@ def get_bprop_fakequant_with_minmax_per_channel_update(self): return zeros_like(x), zeros_like(x_min), zeros_like(x_max) return bprop + + +@bprop_getters.register(Q.ActsULQ) +def get_bprop_acts_ulq(self): + """Grad definition for 'ActsULQ' operation""" + op = Q.ActsULQInputGrad() + op1 = Q.ActULQClampMinGrad() + op2 = Q.ActULQClampMaxGrad() + def bprop(x, clamp_min, clamp_max, out, dout): + dx = op(dout[0], out[1], out[2]) + dx1 = op1(dout[0], out[1], out[3]) + dx2 = op2(dout[0], out[2], out[3]) + return (dx, dx1, dx2) + + return bprop + + +@bprop_getters.register(Q.WtsARQ) +def get_bprop_wts_arq(self): + """Grad definition for 'WtsArq' operation""" + def bprop(w, w_min, w_max, out, dout): + return (dout, zeros_like(w_min), zeros_like(w_max)) + + return bprop diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index a27f14412a..e81dcae6c8 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -325,6 +325,11 @@ from .parallel_concat import _parallel_concat_tbe from .adam_apply_one_assign import _adam_apply_one_assign_tbe from .adam_apply_one_with_decay_assign import _adam_apply_one_with_decay_assign_tbe from .ifmr import _ifmr_tbe +from .acts_ulq import _acts_ulq_tbe +from .acts_ulq_input_grad import _acts_ulq_input_grad_tbe +from .act_ulq_clamp_min_grad import _act_ulq_clamp_min_grad_tbe +from .act_ulq_clamp_max_grad import _act_ulq_clamp_max_grad_tbe +from .wts_arq import _wts_arq_tbe from .fake_quant_with_min_max_vars import _fake_quant_with_min_max_vars_tbe from .fake_quant_with_min_max_vars_gradient import _fake_quant_with_min_max_vars_gradient_tbe from .fake_quant_with_min_max_vars_per_channel import _fake_quant_with_min_max_vars_per_channel_tbe diff --git a/mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py b/mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py new file mode 100644 index 0000000000..74ed7bfb3c --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ActULQClampMaxGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +act_ulq_clamp_max_grad_op_info = TBERegOp("ActULQClampMaxGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("act_ulq_clamp_max_grad.so") \ + .compute_cost(10) \ + .kernel_name("act_ulq_clamp_max_grad") \ + .partial_flag(True) \ + .input(0, "input_x", False, "required", "all") \ + .input(1, "input_y", False, "required", "all") \ + .input(2, "input_z", False, "required", "all") \ + .output(0, "output", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(act_ulq_clamp_max_grad_op_info) +def _act_ulq_clamp_max_grad_tbe(): + """ActULQClampMaxGrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py b/mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py new file mode 100644 index 0000000000..4d16863a54 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""ActULQClampMinGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +act_ulq_clamp_min_grad_op_info = TBERegOp("ActULQClampMinGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("act_ulq_clamp_min_grad.so") \ + .compute_cost(10) \ + .kernel_name("act_ulq_clamp_min_grad") \ + .partial_flag(True) \ + .input(0, "input_x", False, "required", "all") \ + .input(1, "input_y", False, "required", "all") \ + .input(2, "input_z", False, "required", "all") \ + .output(0, "output", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(act_ulq_clamp_min_grad_op_info) +def _act_ulq_clamp_min_grad_tbe(): + """ActULQClampMinGrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/acts_ulq.py b/mindspore/ops/_op_impl/tbe/acts_ulq.py new file mode 100644 index 0000000000..40e72ac863 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/acts_ulq.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""ActsULQ op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +acts_ulq_op_info = TBERegOp("ActsULQ") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("acts_ulq.so") \ + .compute_cost(10) \ + .kernel_name("acts_ulq") \ + .partial_flag(True) \ + .attr("fixed_min", "optional", "bool", "all") \ + .attr("num_bits", "optional", "int", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "clamp_min", False, "required", "all") \ + .input(2, "clamp_max", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .output(1, "clamp_min_mask", False, "required", "all") \ + .output(2, "clamp_max_mask", False, "required", "all") \ + .output(3, "x_clamped_loss", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(acts_ulq_op_info) +def _acts_ulq_tbe(): + """ActsULQ TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py b/mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py new file mode 100644 index 0000000000..eef47cd4f2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""ActsULQInputGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +acts_ulq_input_grad_op_info = TBERegOp("ActsULQInputGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("acts_ulq_input_grad.so") \ + .compute_cost(10) \ + .kernel_name("acts_ulq_input_grad") \ + .partial_flag(True) \ + .input(0, "y_grad", False, "required", "all") \ + .input(1, "clamp_min_mask", False, "required", "all") \ + .input(2, "clamp_max_mask", False, "required", "all") \ + .output(0, "x_grad", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(acts_ulq_input_grad_op_info) +def _acts_ulq_input_grad_tbe(): + """ActsULQInputGrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/wts_arq.py b/mindspore/ops/_op_impl/tbe/wts_arq.py new file mode 100644 index 0000000000..a7d2638480 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/wts_arq.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""WtsARQ op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +wts_arq_op_info = TBERegOp("WtsARQ") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("wts_arq.so") \ + .compute_cost(10) \ + .kernel_name("wts_arq") \ + .partial_flag(True) \ + .attr("num_bits", "optional", "int", "all") \ + .attr("offset_flag", "optional", "bool", "all") \ + .input(0, "w", False, "required", "all") \ + .input(1, "w_min", False, "required", "all") \ + .input(2, "w_max", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(wts_arq_op_info) +def _wts_arq_tbe(): + """WtsARQ TBE register""" + return diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index f6df84f18e..3a41a68125 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -1197,3 +1197,205 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): def infer_dtype(self, dout_type, x_type): validator.check("dout type", dout_type, "x type", x_type) return dout_type, dout_type + + +class ActsULQ(PrimitiveWithInfer): + """ + The ActsULQ(Activation universal learnable quantization). + + Args: + fixed_min (bool): whether fix clamp min to zero. + num_bits (int): The bits num used for quantize. + + Inputs: + - **x** (Tensor) - A Tensor of feature map. 
With float16 or float32 data type.
+        - **clamp_min** (Tensor) - A Tensor of the clamp minimum, with the same type as `x`.
+        - **clamp_max** (Tensor) - A Tensor of the clamp maximum, with the same type as `x`.
+
+    Outputs:
+        - **y** (Tensor) - A tensor of the fake-quantized feature map, with the same type as `x`.
+        - **clamp_min_mask** (Tensor) - A tensor of boolean masks indicating whether elements of `x` are >= clamp_min.
+        - **clamp_max_mask** (Tensor) - A tensor of boolean masks indicating whether elements of `x` are <= clamp_max.
+        - **x_clamped_loss** (Tensor) - A tensor of the clamping loss.
+
+    Examples:
+        >>> data_type = np.float32
+        >>> x = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
+        >>> clamp_max = np.array([[0.7 * np.max(x)]], dtype=data_type)
+        >>> clamp_min = np.array([[0.7 * np.min(x)]], dtype=data_type)
+        >>> acts_ulq = Q.ActsULQ(fixed_min=True, num_bits=8)
+        >>> quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(Tensor(x), Tensor(clamp_min),
+        ...                                                                    Tensor(clamp_max))
+    """
+    @prim_attr_register
+    def __init__(self, fixed_min=False, num_bits=8):
+        validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
+        validator.check_value_type("num_bits", num_bits, [int], self.name)
+        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
+
+    def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
+        """infer shape of primitive"""
+        validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
+        validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)
+
+        x_shape_len = len(x_shape)
+        for i in range(x_shape_len):
+            validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
+            validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)
+
+        return x_shape, x_shape, x_shape, x_shape
+
+    def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
+        """infer dtype of primitive"""
+        valid_types = [mstype.float32, mstype.float16]
+        validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name)
+        validator.check_tensor_type_same({"clamp_min": clamp_min_dtype}, valid_types, self.name)
+        validator.check_tensor_type_same({"clamp_max": clamp_max_dtype}, valid_types, self.name)
+
+        return x_dtype, mstype.bool_, mstype.bool_, x_dtype
+
+
+class ActsULQInputGrad(PrimitiveWithInfer):
+    """
+    The ActsULQInputGrad (gradient of ActsULQ with respect to `x`).
+
+    Inputs:
+        - **y_grad** (Tensor) - A Tensor of gradient. With float16 or float32 data type.
+        - **clamp_min_mask** (Tensor) - A Tensor of boolean masks produced by ActsULQ.
+        - **clamp_max_mask** (Tensor) - A Tensor of boolean masks produced by ActsULQ.
+
+    Outputs:
+        - **x_grad** (Tensor) - A tensor of the data gradient, with the same type as `y_grad`.
+    """
+    @prim_attr_register
+    def __init__(self):
+        pass
+
+    def infer_shape(self, y_grad_shape, clamp_min_mask_shape, clamp_max_mask_shape):
+        return y_grad_shape
+
+    def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
+        valid_types = [mstype.float32, mstype.float16]
+        validator.check_tensor_type_same({"y_grad": y_grad_type}, valid_types, self.name)
+        return y_grad_type
+
+
+class ActULQClampMinGrad(PrimitiveWithInfer):
+    """
+    The ActULQClampMinGrad (Activation Universal Linear Quantization on Clamp Minimum Gradient).
+
+    Inputs:
+        - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 data type.
+        - **clamp_min_mask** (Tensor) - A tensor of boolean masks produced by ActsULQ.
+        - **x_clamped_loss** (Tensor) - A tensor of loss, with the same type as `y_grad`.
+
+    Outputs:
+        - **clamp_min_grad** (Tensor) - A tensor of the clamp minimum gradient, with the same type as `y_grad`.
+          It contains a single element.
+
+    Examples:
+        >>> data_type = np.float32
+        >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
+        >>> clamp_min_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
+        >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
+        >>> act_ulq_clamp_min_grad = Q.ActULQClampMinGrad()
+        >>> clamp_min_grad = act_ulq_clamp_min_grad(Tensor(y_grad), Tensor(clamp_min_mask, mindspore.bool_),
+        ...                                         Tensor(x_clamped_loss))
+    """
+    @prim_attr_register
+    def __init__(self):
+        pass
+
+    def infer_shape(self, input_x, input_y, input_z):
+        input_x_len = len(input_x)
+        output_shape = []
+        for _ in range(input_x_len):
+            output_shape.append(1)
+        return tuple(output_shape)
+
+    def infer_dtype(self, input_x, input_y, input_z):
+        return input_x
+
+
+class ActULQClampMaxGrad(PrimitiveWithInfer):
+    """
+    The ActULQClampMaxGrad (Activation Universal Linear Quantization on Clamp Maximum Gradient).
+
+    Inputs:
+        - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 data type.
+        - **clamp_max_mask** (Tensor) - A tensor of boolean masks produced by ActsULQ.
+        - **x_clamped_loss** (Tensor) - A tensor of loss, with the same type as `y_grad`.
+
+    Outputs:
+        - **clamp_max_grad** (Tensor) - A tensor of the clamp maximum gradient, with the same type as `y_grad`.
+          It contains a single element.
+
+    Examples:
+        >>> data_type = np.float32
+        >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
+        >>> clamp_max_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
+        >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
+        >>> act_ulq_clamp_max_grad = Q.ActULQClampMaxGrad()
+        >>> clamp_max_grad = act_ulq_clamp_max_grad(Tensor(y_grad), Tensor(clamp_max_mask, mindspore.bool_),
+        ...                                         Tensor(x_clamped_loss))
+    """
+    @prim_attr_register
+    def __init__(self):
+        pass
+
+    def infer_shape(self, input_x, input_y, input_z):
+        input_x_len = len(input_x)
+        output_shape = []
+        for _ in range(input_x_len):
+            output_shape.append(1)
+        return tuple(output_shape)
+
+    def infer_dtype(self, input_x, input_y, input_z):
+        return input_x
+
+
+class WtsARQ(PrimitiveWithInfer):
+    """
+    The WtsARQ (Weights Adaptive Range Quantization).
+
+    The quantization scale (and the offset, when `offset_flag` is True) is derived from
+    `w_min` and `w_max`, and `w` is fake-quantized with `num_bits` bits.
+
+    Args:
+        num_bits (int): The number of bits used for quantization. Only 8 is supported.
+        offset_flag (bool): Whether to use an offset when quantizing.
+
+    Inputs:
+        - **w** (Tensor) - A Tensor of weights. With float16 or float32 data type.
+        - **w_min** (Tensor) - A Tensor of weights minimum, with the same type and rank as `w`.
+        - **w_max** (Tensor) - A Tensor of weights maximum, with the same type and rank as `w`.
+
+    Outputs:
+        - **y** (Tensor) - A tensor of fake-quantized weights, with the same type and shape as `w`.
+
+    Examples:
+        >>> w = np.random.uniform(-10, 10, (1, 3, 6, 4)).astype(np.float32)
+        >>> w_min = np.min(w, keepdims=True)
+        >>> w_max = np.max(w, keepdims=True)
+        >>> wts_arq = Q.WtsARQ(num_bits=8, offset_flag=False)
+        >>> y = wts_arq(Tensor(w), Tensor(w_min), Tensor(w_max))
+    """
+    @prim_attr_register
+    def __init__(self, num_bits, offset_flag):
+        validator.check_value_type("num_bits", num_bits, [int], self.name)
+        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
+        validator.check_value_type("offset_flag", offset_flag, [bool], self.name)
+
+    def infer_shape(self, w_shape, w_min_shape, w_max_shape):
+        validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
+        validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
+        return w_shape
+
+    def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
+        valid_types = [mstype.float32, mstype.float16]
+        validator.check_tensor_type_same({"w": w_dtype}, valid_types, self.name)
+        validator.check_tensor_type_same({"w_min": w_min_dtype}, valid_types, self.name)
+        validator.check_tensor_type_same({"w_max": w_max_dtype}, valid_types, self.name)
+        return w_dtype
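For reviewers, a minimal end-to-end sketch of how the new primitives are expected to be driven (illustrative only, not part of the diff). It assumes an Ascend target where the TBE kernels registered above are available, and it imports the private module the same way mindspore/ops/_grad/grad_quant_ops.py does; the variable names are hypothetical.

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.ops.operations import _quant_ops as Q

    # These ops are only registered for the Ascend/TBE backend.
    context.set_context(device_target="Ascend")

    # ActsULQ: fake-quantize a feature map against learnable clamp bounds.
    # clamp_min/clamp_max must have the same rank as x with every dimension equal to 1.
    x = np.random.uniform(-10, 10, (32, 120)).astype(np.float32)
    clamp_min = np.array([[0.7 * np.min(x)]], dtype=np.float32)
    clamp_max = np.array([[0.7 * np.max(x)]], dtype=np.float32)
    acts_ulq = Q.ActsULQ(fixed_min=False, num_bits=8)
    quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(
        Tensor(x), Tensor(clamp_min), Tensor(clamp_max))

    # WtsARQ: fake-quantize weights; w_min/w_max must have the same rank as w.
    w = np.random.uniform(-1, 1, (1, 3, 6, 4)).astype(np.float32)
    w_min = np.min(w, keepdims=True)  # shape (1, 1, 1, 1)
    w_max = np.max(w, keepdims=True)
    wts_arq = Q.WtsARQ(num_bits=8, offset_flag=False)
    w_q = wts_arq(Tensor(w), Tensor(w_min), Tensor(w_max))

Because bprops are registered in grad_quant_ops.py, both primitives can sit inside a training graph: for ActsULQ the gradients flow to x, clamp_min and clamp_max (via ActsULQInputGrad, ActULQClampMinGrad and ActULQClampMaxGrad), and for WtsARQ the incoming gradient passes straight through to w.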