From 3366d66034342ed10050bf439ae3fa26c4192da1 Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Wed, 2 Dec 2020 11:03:34 +0800
Subject: [PATCH] fix isnan isfinite isinf infer type

---
 .../gpu/arrays/select_gpu_kernel.cc           |  7 ++++
 .../gpu/arrays/slice_gpu_kernel.cc            |  2 ++
 .../gpu/cuda_impl/assign_add_impl.cu          |  2 ++
 .../gpu/cuda_impl/broadcast_grad_impl.cu      |  7 ++++
 .../gpu/cuda_impl/broadcast_impl.cu           | 13 ++++++++
 .../gpu/cuda_impl/select_impl.cu              |  3 ++
 .../gpu/cuda_impl/slice_impl.cu               | 16 +++++++++
 .../kernel_compiler/gpu/kernel_constants.h    |  2 +-
 .../gpu/math/addn_gpu_kernel.h                | 12 +++----
 .../gpu/math/assign_add_gpu_kernel.cc         |  3 ++
 .../gpu/math/broadcast_gpu_kernel.cc          | 33 +++++++++++++++++++
 .../gpu/math/broadcast_grad_gpu_kernel.cc     | 16 +++++++++
 .../gpu/nn/activation_gpu_kernel.cc           |  2 --
 mindspore/ops/operations/math_ops.py          |  6 ++--
 14 files changed, 112 insertions(+), 12 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc
index 3c1323de07..4572a3cd47 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc
@@ -39,5 +39,12 @@ MS_REG_GPU_KERNEL_ONE(Select,
                         .AddInputAttr(kNumberTypeInt32)
                         .AddOutputAttr(kNumberTypeInt32),
                       SelectGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(Select,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeBool)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64),
+                      SelectGpuKernel, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc
index 073a31918e..3ad2a6a264 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc
@@ -26,6 +26,8 @@ MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOu
                       SliceGpuFwdKernel, half)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                       SliceGpuFwdKernel, int16_t)
+MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      SliceGpuFwdKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                       SliceGpuFwdKernel, uchar)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu
index 604391ccf3..61e8d84ba8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu
@@ -38,3 +38,5 @@ template void CalAssignAdd(const size_t size, float* ref, const float* va
 template void CalAssignAdd(const size_t size, half* ref, const half* value, half* output,
                            cudaStream_t cuda_stream);
 template void CalAssignAdd(const size_t size, int* ref, const int* value, int* output, cudaStream_t cuda_stream);
+template void CalAssignAdd(const size_t size, int64_t* ref, const int64_t* value, int64_t* output,
+                           cudaStream_t cuda_stream);
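[Editor's note, not part of the patch] The registrations above give Select and Slice an int64 path on GPU, and the CUDA instantiation adds the matching AssignAdd device code. A minimal smoke test, assuming a GPU build of MindSpore from around this change (r1.x operations API):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    cond = Tensor(np.array([True, False, True]))
    x = Tensor(np.array([1, 2, 3], np.int64))
    y = Tensor(np.array([10, 20, 30], np.int64))
    # Select keeps the int64 dtype instead of requiring a cast to int32.
    print(P.Select()(cond, x, y))    # expected: [ 1 20  3]
    # Slice now has an int64 kernel too: take 2 elements starting at index 1.
    print(P.Slice()(x, (1,), (2,)))  # expected: [2 3]
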
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
index fee1e3eb3b..150fb31e8e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
@@ -121,6 +121,9 @@ template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &
 template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
                               const half *x1, const half *x2, const half *dy, half *dx1, half *dx2,
                               cudaStream_t stream);
+template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                              const int64_t *x1, const int64_t *x2, const int64_t *dy, int64_t *dx1, int64_t *dx2,
+                              cudaStream_t stream);
 template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1,
@@ -133,3 +136,7 @@ template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const i
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const half *x1,
                             const half *x2, const half *dy, half *dx1, half *dx2, cudaStream_t stream);
+template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
+                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
+                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const int64_t *x1,
+                            const int64_t *x2, const int64_t *dy, int64_t *dx1, int64_t *dx2, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
index 4642412824..6f825e460b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
@@ -203,6 +203,8 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int8_t
                          cudaStream_t stream);
 template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, bool *y,
                          cudaStream_t stream);
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, bool *y,
+                         cudaStream_t stream);

 // Element-wise ArithMetic
 template <typename T>
@@ -269,6 +271,8 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int8_
                            cudaStream_t stream);
 template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, uint8_t *y,
                            cudaStream_t stream);
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, int64_t *y,
+                           cudaStream_t stream);

 // Broadcast comparation
 __device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
@@ -347,6 +351,9 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
 template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                            const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
                            const uint8_t *x1, bool *y, cudaStream_t stream);
+template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int64_t *x0,
+                           const int64_t *x1, bool *y, cudaStream_t stream);

 // Broadcast Arithmetic
 template <typename T>
@@ -468,6 +475,9 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
 template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                              const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
                              const uint8_t *x1, uint8_t *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int64_t *x0,
+                             const int64_t *x1, int64_t *y, cudaStream_t stream);

 // BroadcastTo
 template <typename T>
@@ -500,3 +510,6 @@ template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2,
 template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
                           const size_t &o1, const size_t &o2, const size_t &o3, const half *input_addr,
                           half *output_addr, cudaStream_t stream);
+template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                          const size_t &o1, const size_t &o2, const size_t &o3, const int64_t *input_addr,
+                          int64_t *output_addr, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu
index f7086f8093..cacd0f844a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu
@@ -40,3 +40,6 @@ template void CalSelect(const size_t size, const bool* cond, const int* inp
                         cudaStream_t cuda_stream);
 template void CalSelect(const size_t size, const bool* cond, const half* input_X, const half* input_y, half* output,
                         cudaStream_t cuda_stream);
+template void CalSelect(const size_t size, const bool* cond, const int64_t* input_X, const int64_t* input_y,
+                        int64_t* output, cudaStream_t cuda_stream);
+
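[Editor's note, not part of the patch] The instantiations above are what the int64 kernel registrations later in this patch link against: ElewiseCmp/BroadcastCmp back the comparison ops (bool output), ElewiseArith/BroadcastArith back the arithmetic ops (int64 output). A hypothetical check of the broadcast arithmetic path, same assumptions as the earlier sketch:

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    x = Tensor(np.arange(6, dtype=np.int64).reshape(2, 3))
    y = Tensor(np.array([[10], [100]], np.int64))  # broadcast over the last axis
    print(P.TensorAdd()(x, y))  # int64 in, int64 out (BroadcastArith path)
    print(P.Mul()(x, y))        # same broadcast rules, Mul kernel
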
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
index 62ca154c18..329fc97ef4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
@@ -204,6 +204,16 @@ template void CalSliceGrad(const size_t input_size, const unsigne
                            const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
                            const std::vector<int64_t> size, unsigned char *output,
                            cudaStream_t cuda_stream);
+template void FillDeviceArray(const size_t input_size, int64_t *addr, const float value,
+                              cudaStream_t cuda_stream);
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
+                            cudaStream_t stream);
+template void CalSliceGrad(const size_t input_size, const int64_t *dy, const std::vector<size_t> in_shape,
+                           const std::vector<int64_t> begin, const std::vector<int64_t> size, int64_t *output,
+                           cudaStream_t cuda_stream);
+
 template void FillDeviceArray(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
@@ -230,6 +240,9 @@ template void StridedSlice(const std::vector<size_t> &input_shape, const std::ve
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
                            const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
                            const bool *input, bool *output, cudaStream_t cuda_stream);
+template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
+                           const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
+                           const int64_t *input, int64_t *output, cudaStream_t cuda_stream);

 template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
                                const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const float *dy,
@@ -249,3 +262,6 @@ template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::v
 template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
                                const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
                                bool *dx, cudaStream_t cuda_stream);
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
+                               const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const int64_t *dy,
+                               int64_t *dx, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
index 371f27437d..9dce244774 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
@@ -45,7 +45,7 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F;
 // Used by mixprecision, cudnn dtype select
 static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {{"kNumberTypeFloat32", CUDNN_DATA_FLOAT},
                                                                 {"kNumberTypeFloat16", CUDNN_DATA_HALF},
-                                                                {"kNumberTypeInt64", CUDNN_DATA_DOUBLE},
+                                                                {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE},
                                                                 {"kNumberTypeInt32", CUDNN_DATA_INT32}};
 // Used by mixprecision, cuda dtype select
 static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h
index c1a878c1df..a5f484ce94 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h
@@ -51,7 +51,7 @@ class AddNGpuFwdKernel : public GpuKernel {
     }
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     auto work_addr = output_addr;
-    for (size_t i = 0; i < IntToSize(num_input_); i++) {
+    for (size_t i = 0; i < num_input_; i++) {
       if (output_addr == GetDeviceAddress<T>(inputs, i)) {
         work_addr = GetDeviceAddress<T>(workspace, 0);
         break;
@@ -63,7 +63,7 @@ class AddNGpuFwdKernel : public GpuKernel {
     }
     const float alpha = 1;
     const float beta = 0;
-    for (size_t i = 0; i < IntToSize(num_input_); i++) {
+    for (size_t i = 0; i < num_input_; i++) {
       T *input_addr = GetDeviceAddress<T>(inputs, i);
       if (cudnn_data_type_ == CUDNN_DATA_INT32) {
         ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
@@ -85,8 +85,8 @@ class AddNGpuFwdKernel : public GpuKernel {
     InitResource();
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    num_input_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "n"));
-    if (IntToSize(num_input_) != input_num) {
+    num_input_ = GetAttr<size_t>(kernel_node, "n");
+    if (num_input_ != input_num) {
       MS_LOG(ERROR) << "Input number is " << num_input_ << " in attr, but got " << input_num << "input.";
       return false;
     }
@@ -137,7 +137,7 @@ class AddNGpuFwdKernel : public GpuKernel {
       CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
                                   "cudnnGetTensorSizeInBytes failed");
     }
-    for (int i = 0; i < num_input_; i++) {
+    for (size_t i = 0; i < num_input_; i++) {
       input_size_list_.push_back(input_size_);
     }
     output_size_list_.push_back(input_size_);
@@ -157,7 +157,7 @@ class AddNGpuFwdKernel : public GpuKernel {
   size_t output_size_;
   size_t workspace_size_;
   bool is_null_input_;
-  int num_input_;
+  size_t num_input_;
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc
index bffcca158b..09ed79d14a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc
@@ -21,6 +21,9 @@ namespace kernel {
 MS_REG_GPU_KERNEL_ONE(
   AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   AssignAddGpuFwdKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  AssignAddGpuFwdKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(
   AssignAdd,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
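[Editor's note, not part of the patch] With the registration above, AssignAdd can maintain an int64 counter on GPU; the AddN change in the same group is a type cleanup (num_input_ becomes size_t, so the IntToSize conversions disappear). A hypothetical usage sketch, same assumptions as the earlier examples:

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    class Counter(nn.Cell):
        def __init__(self):
            super(Counter, self).__init__()
            self.step = Parameter(Tensor(np.zeros(1, np.int64)), name="step")
            self.assign_add = P.AssignAdd()

        def construct(self, inc):
            return self.assign_add(self.step, inc)

    net = Counter()
    one = Tensor(np.ones(1, np.int64))
    print(net(one))  # expected: [1]
    print(net(one))  # expected: [2]
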
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
index 55f9456a8d..8eee77fb5b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
@@ -148,6 +148,39 @@ MS_REG_GPU_KERNEL_ONE(
   DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int)

+// int64
+MS_REG_GPU_KERNEL_ONE(
+  Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  FloorDiv,
+  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+
 // int8
 MS_REG_GPU_KERNEL_ONE(
   DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc
index a7e0eeebfd..4d9ef3b184 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc
@@ -66,5 +66,21 @@ MS_REG_GPU_KERNEL_ONE(MaximumGrad,
                         .AddOutputAttr(kNumberTypeInt32)
                         .AddOutputAttr(kNumberTypeInt32),
                       BroadcastOpGradGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(MinimumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64),
+                      BroadcastOpGradGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(MaximumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64),
+                      BroadcastOpGradGpuKernel, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
index 627a71de8c..f094bd064d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
@@ -24,8 +24,6 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOut
                       ActivationGpuFwdKernel, half)
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       ActivationGpuFwdKernel, int32_t)
-MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-                      ActivationGpuFwdKernel, int64_t)

 MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       ActivationGpuFwdKernel, float)
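[Editor's note, not part of the patch] These registrations are the consumer side of the CUDA instantiations earlier in the patch: the comparison kernels (Greater, Less) emit bool, the rest stay int64, and MinimumGrad/MaximumGrad get matching grad kernels. The int64 ReLU registration is dropped, presumably because the cuDNN activation path has no int64 mode, which is consistent with the kCudnnDtypeMap correction above. A hypothetical check, same assumptions:

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    x = Tensor(np.array([1, 5, 3], np.int64))
    y = Tensor(np.array([4, 2, 3], np.int64))
    print(P.Greater()(x, y))  # expected: [False  True False] (bool output)
    print(P.Minimum()(x, y))  # expected: [1 2 3] (int64 output)
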
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index c6258639ee..38de5e2140 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -2961,7 +2961,7 @@ class IsNan(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        return mstype.bool_
+        return mstype.tensor_type(mstype.bool_)


 class IsInf(PrimitiveWithInfer):
@@ -2992,7 +2992,7 @@ class IsInf(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        return mstype.bool_
+        return mstype.tensor_type(mstype.bool_)


 class IsFinite(PrimitiveWithInfer):
@@ -3026,7 +3026,7 @@ class IsFinite(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype):
         validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
-        return mstype.bool_
+        return mstype.tensor_type(mstype.bool_)


 class FloatStatus(PrimitiveWithInfer):
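[Editor's note, not part of the patch] This last hunk is the fix named in the subject line: IsNan, IsInf, and IsFinite produce a Tensor of bool, so infer_dtype must report mstype.tensor_type(mstype.bool_) rather than the scalar mstype.bool_. A quick demonstration, same assumptions as the earlier sketches:

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    x = Tensor(np.array([1.0, float("nan"), float("inf")], np.float32))
    print(P.IsNan()(x))     # expected: [False  True False]
    print(P.IsInf()(x))     # expected: [False False  True]
    print(P.IsFinite()(x))  # expected: [ True False False]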