!11845 check wrapper layer timedistributed input type

From: @dinglongwei
Reviewed-by: @c_34,@liangchenghui
Signed-off-by: @c_34
pull/11845/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 35d0634291

@ -17,6 +17,7 @@
from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops import Reshape, Transpose, Pack, Unpack
from mindspore.common import Tensor
from mindspore._checkparam import Validator
from ..cell import Cell
__all__ = ['TimeDistributed']
@ -69,13 +70,13 @@ class TimeDistributed(Cell):
Args:
layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped.
time_axis(int): The axis of time_step.
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: 'None'.
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, T, *)`.
Outputs:
Tensor of shape: math:'(N, T, *)'
Tensor of shape :math:`(N, T, *)`
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -97,6 +98,9 @@ class TimeDistributed(Cell):
raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
"mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))
super(TimeDistributed, self).__init__()
Validator.check_is_int(time_axis)
if reshape_with_axis is not None:
Validator.check_is_int(reshape_with_axis)
self.layer = layer
self.time_axis = time_axis
self.reshape_with_axis = reshape_with_axis

Loading…
Cancel
Save