From 2637133242834a1bc37da7b1e3109ad1befb6e1d Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Fri, 12 Mar 2021 16:01:25 -0500 Subject: [PATCH] add float64 support to select --- .../kernel_compiler/gpu/arrays/select_gpu_kernel.cc | 9 ++++++++- .../backend/kernel_compiler/gpu/cuda_impl/select_impl.cu | 2 ++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc index 4572a3cd47..d0c9ba1943 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,13 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_ONE(Select, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + SelectGpuKernel, double) MS_REG_GPU_KERNEL_ONE(Select, KernelAttr() .AddInputAttr(kNumberTypeBool) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu index cacd0f844a..393dd683bf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu @@ -34,6 +34,8 @@ void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* i return; } +template void CalSelect(const size_t size, const bool* cond, const double* input_X, const double* input_y, + double* output, cudaStream_t cuda_stream); template void CalSelect(const size_t size, const bool* cond, const float* input_X, const float* input_y, float* output, cudaStream_t cuda_stream); template void CalSelect(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output,