add float64 support to select

pull/13263/head
Peilin Wang 4 years ago
parent 483bb9de60
commit 2637133242

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,13 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Select,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
SelectGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Select,
KernelAttr()
.AddInputAttr(kNumberTypeBool)

@ -34,6 +34,8 @@ void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* i
return;
}
template void CalSelect<double>(const size_t size, const bool* cond, const double* input_X, const double* input_y,
double* output, cudaStream_t cuda_stream);
template void CalSelect<float>(const size_t size, const bool* cond, const float* input_X, const float* input_y,
float* output, cudaStream_t cuda_stream);
template void CalSelect<int>(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output,

Loading…
Cancel
Save