From 127f70ce406e6715ee608e1a15fece217a3c10d5 Mon Sep 17 00:00:00 2001
From: lilei
Date: Wed, 21 Oct 2020 19:13:03 +0800
Subject: [PATCH] Extension interface for dense

---
 mindspore/nn/layer/basic.py      | 10 ++++++----
 mindspore/nn/layer/quant.py      | 20 ++++++++++++++------
 tests/ut/python/nn/test_dense.py | 19 +++++++++++++++++++
 3 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index b9f18d5494..d29c9f56e9 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops.functional import identity
 from mindspore.ops.operations import _inner_ops as inner
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import constexpr, Primitive
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore._checkparam import Rel, Validator
@@ -175,8 +175,8 @@ class Dense(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The
             dtype is same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
-        activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'.
-            Default: None.
+        activation (Union[str, Cell, Primitive]): activation function applied to the output of the fully connected
+            layer, eg. 'ReLU'. Default: None.
 
     Raises:
         ValueError: If weight_init or bias_init shape is incorrect.
@@ -222,7 +222,9 @@ class Dense(Cell):
             self.bias_add = P.BiasAdd()
 
         self.matmul = P.MatMul(transpose_b=True)
-        self.activation = get_activation(activation)
+        self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
         self.activation_flag = self.activation is not None
 
     def construct(self, x):
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 17ae5bcd52..bbdffa3d1b 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -19,6 +19,7 @@ from collections import namedtuple
 import numpy as np
 from mindspore import nn
 import mindspore.common.dtype as mstype
+from mindspore.ops.primitive import Primitive
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
@@ -85,7 +86,7 @@ class Conv2dBnAct(Cell):
         momentum (float): Momentum for moving average.Momentum value must be [0, 1].Default:0.9
         eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
             Default: 1e-5.
-        activation (Cell): Specifies activation type. The optional values are as following:
+        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
             'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
             'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
         alpha (float): Slope of the activation function at x < 0. Default: 0.2.
@@ -143,7 +144,9 @@ class Conv2dBnAct(Cell):
         if activation == "leakyrelu":
             self.activation = LeakyReLU(alpha)
         else:
-            self.activation = get_activation(activation)
+            self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
 
     def construct(self, x):
         x = self.conv(x)
@@ -170,7 +173,7 @@ class DenseBnAct(Cell):
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
         activation (Cell): The regularization function applied to the output of the layer, eg. 'ReLU'. Default: None.
         has_bn (bool): Specifies to use batchnorm or not. Default: False.
-        activation (string): Specifies activation type. The optional values are as following:
+        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
             'Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
             'PReLU', 'LeakyReLU', 'h-Swish', and 'h-Sigmoid'. Default: None.
         after_fake(bool): Determin whether there must be a fake quantization operation after DenseBnAct.
@@ -208,7 +211,9 @@ class DenseBnAct(Cell):
         self.after_fake = after_fake
         if has_bn:
             self.batchnorm = BatchNorm1d(out_channels)
-        self.activation = get_activation(activation)
+        self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
 
     def construct(self, x):
         x = self.dense(x)
@@ -930,7 +935,8 @@ class DenseQuant(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The
             dtype is same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
-        activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
+        activation (Union[str, Cell, Primitive]): The regularization function applied to the output of the layer,
+            eg. 'relu'. Default: None.
         quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default.
         quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
 
@@ -979,7 +985,9 @@ class DenseQuant(Cell):
         self.matmul = P.MatMul(transpose_b=True)
         self.bias_add = P.BiasAdd()
 
-        self.activation = get_activation(activation)
+        self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
         self.activation_flag = self.activation is not None
         self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                      max_init=6,
diff --git a/tests/ut/python/nn/test_dense.py b/tests/ut/python/nn/test_dense.py
index 3972f48b4d..57f4ed8083 100644
--- a/tests/ut/python/nn/test_dense.py
+++ b/tests/ut/python/nn/test_dense.py
@@ -19,6 +19,7 @@ import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore.ops import operations as P
 from mindspore.common.api import _executor
 from ..ut_filter import non_graph_engine
 
@@ -37,6 +38,24 @@ def test_dense_str_activation():
     dense(input_data)
 
 
+@non_graph_engine
+def test_dense_nn_activation_():
+    dense = nn.Dense(1, 1, activation=nn.ReLU())
+    assert isinstance(dense.activation, nn.ReLU)
+
+    input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
+    dense(input_data)
+
+
+@non_graph_engine
+def test_dense_ops_activation_():
+    dense = nn.Dense(1, 1, activation=P.ReLU())
+    assert isinstance(dense.activation, P.ReLU)
+
+    input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
+    dense(input_data)
+
+
 def test_dense_weight_error():
     dim_error = Tensor(np.array([[[0.1], [0.3], [0.6]], [[0.4], [0.5], [0.2]]]))
     with pytest.raises(ValueError):
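
A minimal usage sketch of the extended interface, mirroring the new unit tests above (input shape
and values are illustrative, not part of the patch): with this change applied, the activation
argument of nn.Dense accepts a string name, an nn.Cell instance, or an ops Primitive, and any
other type raises a TypeError at construction time.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # All three forms build a fully connected layer with a ReLU applied to its output.
    dense_str = nn.Dense(1, 1, activation='relu')      # string name, resolved via get_activation
    dense_cell = nn.Dense(1, 1, activation=nn.ReLU())  # nn.Cell instance, passed through unchanged
    dense_prim = nn.Dense(1, 1, activation=P.ReLU())   # ops Primitive, passed through unchanged

    x = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
    out = dense_cell(x)

    # Anything that is not a str, Cell or Primitive is rejected:
    # nn.Dense(1, 1, activation=1)  -> TypeError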