@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+
 """basic"""
+
 import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore.common.seed import get_seed
@@ -28,7 +30,6 @@ from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
 from mindspore import context
-from mindspore.ops import _selected_ops
 from ..cell import Cell
 from .activation import get_activation
 from ..._checkparam import Validator as validator
@@ -139,10 +140,8 @@ class Flatten(Cell):
         the product of the remaining dimensions.

     Examples:
-        >>> net = nn.Flatten()
         >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
-        >>> input.shape
-        (2, 2, 2)
+        >>> net = nn.Flatten()
         >>> net(input)
         [[1.2 1.2 2.1 2.1]
          [2.2 2.2 3.2 3.2]]
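For reference, the reordered doctest above only swaps the construction order; the flattening itself keeps dimension 0 (the batch) and collapses everything else into one axis. A minimal NumPy sketch of the equivalent behavior, using the same data as the doctest (NumPy stands in for the MindSpore Tensor here):

```python
import numpy as np

# Same data as the docstring example: shape (2, 2, 2).
x = np.array([[[1.2, 1.2], [2.1, 2.1]],
              [[2.2, 2.2], [3.2, 3.2]]], dtype=np.float32)

# Keep dim 0 and collapse the rest: (2, 2, 2) -> (2, 4),
# matching the doctest output.
flat = x.reshape(x.shape[0], -1)
print(flat)
# [[1.2 1.2 2.1 2.1]
#  [2.2 2.2 3.2 3.2]]
```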
@@ -157,9 +156,9 @@ class Flatten(Cell):

 class Dense(Cell):
     r"""
-    The fully connected layer.
+    The dense connected layer.

-    Applies dense-connected layer for the input. This layer implements the operation as:
+    Applies dense connected layer for the input. This layer implements the operation as:

     .. math::
         \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
@@ -190,8 +189,8 @@ class Dense(Cell):
         Tensor of shape :math:`(N, out\_channels)`.

     Examples:
-        >>> net = nn.Dense(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
+        >>> net = nn.Dense(3, 4)
         >>> net(input)
         [[ 2.5246444   2.2738023   0.5711005  -3.9399147 ]
          [ 1.0739875   4.0155234   0.94188046 -5.459526  ]]
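The math block and the doctest describe the same computation. As a sanity check, here is a hedged NumPy sketch of what Dense computes; the `[out_channels, in_channels]` weight layout and the transposed matmul mirror the `P.MatMul(transpose_b=True)` used in the implementation below, while the random initializer values are made up for illustration:

```python
import numpy as np

def dense_forward(inputs, kernel, bias=None, activation=None):
    """outputs = activation(inputs * kernel^T + bias), per the docstring."""
    out = inputs @ kernel.T          # kernel: (out_channels, in_channels)
    if bias is not None:
        out = out + bias             # bias: (out_channels,)
    if activation is not None:
        out = activation(out)
    return out

inputs = np.random.randint(0, 255, [2, 3]).astype(np.float32)   # (N=2, in_channels=3)
kernel = np.random.normal(0, 0.01, (4, 3)).astype(np.float32)   # (out_channels=4, in_channels=3)
bias = np.zeros(4, dtype=np.float32)
print(dense_forward(inputs, kernel, bias).shape)                # (2, 4)
```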
@@ -212,41 +211,36 @@ class Dense(Cell):
         if isinstance(weight_init, Tensor):
             if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
-                raise ValueError("weight_init shape error")
+                raise ValueError("Weight init shape error.")

         self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

+        self.bias = None
         if self.has_bias:
             if isinstance(bias_init, Tensor):
                 if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
-                    raise ValueError("bias_init shape error")
+                    raise ValueError("Bias init shape error.")

             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+            self.bias_add = P.BiasAdd()

         self.matmul = P.MatMul(transpose_b=True)
-        self.bias_add = _selected_ops.BiasAdd()

         self.activation = get_activation(activation)
         self.activation_flag = self.activation is not None

     def construct(self, x):
-        output = self.matmul(x, self.weight)
+        x = self.matmul(x, self.weight)
         if self.has_bias:
-            output = self.bias_add(output, self.bias)
+            x = self.bias_add(x, self.bias)
         if self.activation_flag:
-            return self.activation(output)
-        return output
+            x = self.activation(x)
+        return x

     def extend_repr(self):
-        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
-            .format(self.in_channels, self.out_channels, self.weight, self.has_bias)
+        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
         if self.has_bias:
-            str_info = str_info + ', bias={}'.format(self.bias)
+            s += ', has_bias={}'.format(self.has_bias)
         if self.activation_flag:
-            str_info = str_info + ', activation={}'.format(self.activation)
-        return str_info
+            s += ', activation={}'.format(self.activation)
+        return s

 @constexpr
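The `construct` rewrite above replaces the early return with a single variable threaded through matmul, bias add, and activation, and `extend_repr` now reports the layer's configuration rather than full parameter values. A small pure-Python mock of the refactored repr logic (the class and values here are illustrative, not MindSpore API):

```python
class DenseMock:
    """Illustrates the slimmed-down extend_repr of the new Dense."""
    def __init__(self, in_channels, out_channels, has_bias=True, activation=None):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.has_bias = has_bias
        self.activation = activation
        self.activation_flag = self.activation is not None

    def extend_repr(self):
        # Build the string incrementally, exactly like the '+' side of the hunk.
        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        if self.activation_flag:
            s += ', activation={}'.format(self.activation)
        return s

print(DenseMock(3, 4, activation='relu').extend_repr())
# input_channels=3, output_channels=4, has_bias=True, activation=relu
```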