!3372 Fix data type bug of ApplyPowerSign' inputs.

Merge pull request !3372 from liuxiao93/fix-ApplyPowerSign
pull/3372/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0874b8768d

@ -2934,7 +2934,8 @@ class Tan(PrimitiveWithInfer):
Computes tangent of `input_x` element-wise. Computes tangent of `input_x` element-wise.
Inputs: Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. Data type should be
float16, float32 or int32.
Outputs: Outputs:
Tensor, has the same shape as `input_x`. Tensor, has the same shape as `input_x`.
@ -2953,7 +2954,8 @@ class Tan(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) valid_types = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
return x_type return x_type

@ -4281,6 +4281,7 @@ class ApplyPowerSign(PrimitiveWithInfer):
Inputs: Inputs:
- **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type. - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
If data type of `var` is float16, all inputs must have the same data type as `var`.
- **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. - **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar. - **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar.
With float32 or float16 data type. With float32 or float16 data type.
@ -4323,11 +4324,11 @@ class ApplyPowerSign(PrimitiveWithInfer):
__mindspore_signature__ = ( __mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), ('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, ('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
sig_dtype.T3), sig_dtype.T),
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4), ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
) )

Loading…
Cancel
Save