|
|
|
@ -195,6 +195,12 @@ class Flatten(Cell):
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
return F.reshape(x, (F.shape(x)[0], -1))
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_dense_input_shape(x):
|
|
|
|
|
if len(x) < 2:
|
|
|
|
|
raise ValueError('For Dense, the dimension of input should not be less than 2, while the input dimension is '
|
|
|
|
|
+ f'{len(x)}.')
|
|
|
|
|
|
|
|
|
|
class Dense(Cell):
|
|
|
|
|
r"""
|
|
|
|
|
The dense connected layer.
|
|
|
|
@ -278,6 +284,7 @@ class Dense(Cell):
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
x_shape = self.shape_op(x)
|
|
|
|
|
check_dense_input_shape(x_shape)
|
|
|
|
|
if len(x_shape) != 2:
|
|
|
|
|
x = self.reshape(x, (-1, x_shape[-1]))
|
|
|
|
|
x = self.matmul(x, self.weight)
|
|
|
|
|