|
|
|
@ -67,7 +67,7 @@ class CholeskyGPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_count);
|
|
|
|
|
auto* info_ptr = reinterpret_cast<int*>(info->ptr());
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9020
|
|
|
|
|
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
|
|
|
|
|
if (batch_count > 1) {
|
|
|
|
|
std::vector<T*> output_ptrs;
|
|
|
|
|
for (int i = 0; i < batch_count; i++) {
|
|
|
|
@ -93,7 +93,7 @@ class CholeskyGPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
Potrf(dev_ctx, uplo, m, out_data + i * m * m, m, info_ptr + i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9020
|
|
|
|
|
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
@ -126,7 +126,7 @@ class CholeskyGPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
FUNC_WITH_TYPES(POTRF_INSTANCE);
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9020
|
|
|
|
|
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
|
|
|
|
|
#define POTRF_BATCH_INSTANCE(T, C) \
|
|
|
|
|
template <> \
|
|
|
|
|
void CholeskyGPUKernel<T>::PotrfBatched( \
|
|
|
|
|