Extension interface for dense

pull/7581/head
lilei 4 years ago
parent 08dad79529
commit 127f70ce40

@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops.functional import identity
 from mindspore.ops.operations import _inner_ops as inner
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import constexpr, Primitive
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore._checkparam import Rel, Validator
@@ -175,8 +175,8 @@ class Dense(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
             same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
-        activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'.
-            Default: None.
+        activation (Union[str, Cell, Primitive]): activate function applied to the output of the fully connected layer,
+            eg. 'ReLU'. Default: None.
 
     Raises:
         ValueError: If weight_init or bias_init shape is incorrect.
@@ -222,7 +222,9 @@ class Dense(Cell):
         self.bias_add = P.BiasAdd()
 
         self.matmul = P.MatMul(transpose_b=True)
-        self.activation = get_activation(activation)
+        self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
         self.activation_flag = self.activation is not None
 
     def construct(self, x):
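
With this change, the activation argument of nn.Dense can be a name string (resolved through get_activation), an nn.Cell instance, or an ops Primitive instance. A minimal usage sketch, with input shape and values chosen only for illustration:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P

    x = Tensor(np.ones([2, 3]).astype(np.float32))
    dense_str = nn.Dense(3, 4, activation='relu')       # string name, resolved via get_activation
    dense_cell = nn.Dense(3, 4, activation=nn.ReLU())   # nn.Cell instance, used as-is
    dense_prim = nn.Dense(3, 4, activation=P.ReLU())    # ops Primitive instance, used as-is
    print(dense_str(x).shape, dense_cell(x).shape, dense_prim(x).shape)

All three variants apply ReLU to the [2, 4] output of the matmul and bias-add.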

@@ -19,6 +19,7 @@ from collections import namedtuple
 import numpy as np
 from mindspore import nn
 import mindspore.common.dtype as mstype
+from mindspore.ops.primitive import Primitive
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
@ -85,7 +86,7 @@ class Conv2dBnAct(Cell):
momentum (float): Momentum for moving average.Momentum value must be [0, 1].Default:0.9 momentum (float): Momentum for moving average.Momentum value must be [0, 1].Default:0.9
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default: eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
1e-5. 1e-5.
activation (Cell): Specifies activation type. The optional values are as following: activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
alpha (float): Slope of the activation function at x < 0. Default: 0.2. alpha (float): Slope of the activation function at x < 0. Default: 0.2.
@@ -143,7 +144,9 @@ class Conv2dBnAct(Cell):
         if activation == "leakyrelu":
             self.activation = LeakyReLU(alpha)
         else:
-            self.activation = get_activation(activation)
+            self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
 
     def construct(self, x):
         x = self.conv(x)
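
Conv2dBnAct applies the same extension, so the fused conv block can take a Cell or Primitive directly, while the 'leakyrelu' string still goes through LeakyReLU(alpha). A rough sketch, with channel counts and input shape chosen only for illustration:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P

    net_cell = nn.Conv2dBnAct(3, 8, 3, has_bn=True, activation=nn.ReLU())  # Cell activation
    net_prim = nn.Conv2dBnAct(3, 8, 3, activation=P.ReLU())                # Primitive activation
    x = Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))
    print(net_cell(x).shape, net_prim(x).shape)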
@@ -170,7 +173,7 @@ class DenseBnAct(Cell):
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
         activation (Cell): The regularization function applied to the output of the layer, eg. 'ReLU'. Default: None.
         has_bn (bool): Specifies to use batchnorm or not. Default: False.
-        activation (string): Specifies activation type. The optional values are as following:
+        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
             'Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
             'PReLU', 'LeakyReLU', 'h-Swish', and 'h-Sigmoid'. Default: None.
         after_fake(bool): Determin whether there must be a fake quantization operation after DenseBnAct.
@@ -208,7 +211,9 @@ class DenseBnAct(Cell):
         self.after_fake = after_fake
         if has_bn:
             self.batchnorm = BatchNorm1d(out_channels)
-        self.activation = get_activation(activation)
+        self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
 
     def construct(self, x):
         x = self.dense(x)
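
DenseBnAct follows the same pattern. A small sketch, with arbitrary dimensions and the remaining keyword arguments left at their defaults:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P

    block = nn.DenseBnAct(3, 4, has_bn=True, activation=P.ReLU())  # Primitive activation after dense + batchnorm
    x = Tensor(np.ones([2, 3]).astype(np.float32))
    print(block(x).shape)  # (2, 4)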
@ -930,7 +935,8 @@ class DenseQuant(Cell):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None. activation (Union[str, Cell, Primitive]): The regularization function applied to the output of the layer,
eg. 'relu'. Default: None.
quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
@@ -979,7 +985,9 @@ class DenseQuant(Cell):
         self.matmul = P.MatMul(transpose_b=True)
         self.bias_add = P.BiasAdd()
-        self.activation = get_activation(activation)
+        self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
         self.activation_flag = self.activation is not None
         self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                      max_init=6,

@@ -19,6 +19,7 @@ import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore.ops import operations as P
 from mindspore.common.api import _executor
 from ..ut_filter import non_graph_engine
@@ -37,6 +38,24 @@ def test_dense_str_activation():
     dense(input_data)
 
 
+@non_graph_engine
+def test_dense_nn_activation_():
+    dense = nn.Dense(1, 1, activation=nn.ReLU())
+    assert isinstance(dense.activation, nn.ReLU)
+    input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
+    dense(input_data)
+
+
+@non_graph_engine
+def test_dense_ops_activation_():
+    dense = nn.Dense(1, 1, activation=P.ReLU())
+    assert isinstance(dense.activation, P.ReLU)
+    input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
+    dense(input_data)
+
+
 def test_dense_weight_error():
     dim_error = Tensor(np.array([[[0.1], [0.3], [0.6]], [[0.4], [0.5], [0.2]]]))
     with pytest.raises(ValueError):
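
The commit only adds positive tests for the new Cell and Primitive paths; a complementary negative case (not part of this diff) could check that the added validation rejects unsupported activation types, roughly:

    def test_dense_invalid_activation():
        # an int is neither str nor Cell nor Primitive, so __init__ should raise
        with pytest.raises(TypeError):
            nn.Dense(1, 1, activation=1)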
