|
|
|
@ -17,6 +17,7 @@
|
|
|
|
|
from mindspore.ops.primitive import constexpr, Primitive
|
|
|
|
|
from mindspore.ops import Reshape, Transpose, Pack, Unpack
|
|
|
|
|
from mindspore.common import Tensor
|
|
|
|
|
from mindspore._checkparam import Validator
|
|
|
|
|
from ..cell import Cell
|
|
|
|
|
|
|
|
|
|
__all__ = ['TimeDistributed']
|
|
|
|
@ -69,13 +70,13 @@ class TimeDistributed(Cell):
|
|
|
|
|
Args:
|
|
|
|
|
layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped.
|
|
|
|
|
time_axis(int): The axis of time_step.
|
|
|
|
|
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: 'None'.
|
|
|
|
|
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: None.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(N, T, *)`.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor of shape: math:'(N, T, *)'
|
|
|
|
|
Tensor of shape :math:`(N, T, *)`
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend`` ``GPU`` ``CPU``
|
|
|
|
@ -97,6 +98,9 @@ class TimeDistributed(Cell):
|
|
|
|
|
raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
|
|
|
|
|
"mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))
|
|
|
|
|
super(TimeDistributed, self).__init__()
|
|
|
|
|
Validator.check_is_int(time_axis)
|
|
|
|
|
if reshape_with_axis is not None:
|
|
|
|
|
Validator.check_is_int(reshape_with_axis)
|
|
|
|
|
self.layer = layer
|
|
|
|
|
self.time_axis = time_axis
|
|
|
|
|
self.reshape_with_axis = reshape_with_axis
|
|
|
|
|