Extension interface for dense

pull/7581/head
lilei 4 years ago
parent 08dad79529
commit 127f70ce40

@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.functional import identity
from mindspore.ops.operations import _inner_ops as inner
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import constexpr, Primitive
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore._checkparam import Rel, Validator
@@ -175,8 +175,8 @@ class Dense(Cell):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
-activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'.
-    Default: None.
+activation (Union[str, Cell, Primitive]): activate function applied to the output of the fully connected layer,
+    eg. 'ReLU'. Default: None.
Raises:
ValueError: If weight_init or bias_init shape is incorrect.
@@ -222,7 +222,9 @@ class Dense(Cell):
self.bias_add = P.BiasAdd()
self.matmul = P.MatMul(transpose_b=True)
-self.activation = get_activation(activation)
+self.activation = get_activation(activation) if isinstance(activation, str) else activation
+if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+    raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
self.activation_flag = self.activation is not None
def construct(self, x):

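With this change, nn.Dense takes the activation either as a string name, an nn.Cell instance, or an ops Primitive instance; any other non-None value now raises a TypeError. A minimal usage sketch, not part of the commit, with illustrative channel sizes and input shape:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.ones([2, 3]).astype(np.float32))
dense_str = nn.Dense(3, 4, activation='relu')       # string name resolved through get_activation
dense_cell = nn.Dense(3, 4, activation=nn.ReLU())   # nn.Cell instance used directly
dense_prim = nn.Dense(3, 4, activation=P.ReLU())    # ops Primitive used directly
out = dense_prim(x)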
@@ -19,6 +19,7 @@ from collections import namedtuple
import numpy as np
from mindspore import nn
import mindspore.common.dtype as mstype
+from mindspore.ops.primitive import Primitive
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
@@ -85,7 +86,7 @@ class Conv2dBnAct(Cell):
momentum (float): Momentum for moving average.Momentum value must be [0, 1].Default:0.9
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
1e-5.
-activation (Cell): Specifies activation type. The optional values are as following:
+activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
alpha (float): Slope of the activation function at x < 0. Default: 0.2.
@@ -143,7 +144,9 @@ class Conv2dBnAct(Cell):
if activation == "leakyrelu":
self.activation = LeakyReLU(alpha)
else:
-self.activation = get_activation(activation)
+self.activation = get_activation(activation) if isinstance(activation, str) else activation
+if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+    raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
def construct(self, x):
x = self.conv(x)
@@ -170,7 +173,7 @@ class DenseBnAct(Cell):
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (Cell): The regularization function applied to the output of the layer, eg. 'ReLU'. Default: None.
has_bn (bool): Specifies to use batchnorm or not. Default: False.
-activation (string): Specifies activation type. The optional values are as following:
+activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
'Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'PReLU', 'LeakyReLU', 'h-Swish', and 'h-Sigmoid'. Default: None.
after_fake(bool): Determin whether there must be a fake quantization operation after DenseBnAct.
@@ -208,7 +211,9 @@ class DenseBnAct(Cell):
self.after_fake = after_fake
if has_bn:
self.batchnorm = BatchNorm1d(out_channels)
-self.activation = get_activation(activation)
+self.activation = get_activation(activation) if isinstance(activation, str) else activation
+if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+    raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
def construct(self, x):
x = self.dense(x)
@@ -930,7 +935,8 @@ class DenseQuant(Cell):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
-activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
+activation (Union[str, Cell, Primitive]): The regularization function applied to the output of the layer,
+    eg. 'relu'. Default: None.
quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
@@ -979,7 +985,9 @@ class DenseQuant(Cell):
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
-self.activation = get_activation(activation)
+self.activation = get_activation(activation) if isinstance(activation, str) else activation
+if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+    raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
self.activation_flag = self.activation is not None
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,

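Conv2dBnAct, DenseBnAct, and DenseQuant follow the same pattern: a string is resolved through get_activation, a Cell or Primitive is used as-is, and any other non-None value raises a TypeError. A rough sketch for the first two, again with illustrative shapes and channel counts that are not taken from the commit:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

# Cell activation passed to the conv + batchnorm + activation block
conv = nn.Conv2dBnAct(3, 8, kernel_size=3, has_bn=True, activation=nn.ReLU())
feat = conv(Tensor(np.ones([1, 3, 32, 32]).astype(np.float32)))

# Primitive activation passed to the dense + activation block
fc = nn.DenseBnAct(16, 4, activation=P.ReLU())
out = fc(Tensor(np.ones([2, 16]).astype(np.float32)))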
@@ -19,6 +19,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
+from mindspore.ops import operations as P
from mindspore.common.api import _executor
from ..ut_filter import non_graph_engine
@@ -37,6 +38,24 @@ def test_dense_str_activation():
dense(input_data)
+@non_graph_engine
+def test_dense_nn_activation_():
+    dense = nn.Dense(1, 1, activation=nn.ReLU())
+    assert isinstance(dense.activation, nn.ReLU)
+    input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
+    dense(input_data)
+
+
+@non_graph_engine
+def test_dense_ops_activation_():
+    dense = nn.Dense(1, 1, activation=P.ReLU())
+    assert isinstance(dense.activation, P.ReLU)
+    input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
+    dense(input_data)
+
+
def test_dense_weight_error():
dim_error = Tensor(np.array([[[0.1], [0.3], [0.6]], [[0.4], [0.5], [0.2]]]))
with pytest.raises(ValueError):

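The added tests exercise the nn.Cell and Primitive paths. A negative case for the new TypeError check is not part of this commit; a hypothetical sketch could look like this:

import pytest
import mindspore.nn as nn

def test_dense_invalid_activation():
    # an activation that is neither a str, an nn.Cell, nor a Primitive should be rejected
    with pytest.raises(TypeError):
        nn.Dense(1, 1, activation=0)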