Integrate NVRTC to support compiling CUDA kernel at runtime (#19422)
* Add the dynamic load of nvrtc, and support runtime compiling of CUDA kernel using nvrtc. test=develop
* Call CUDA driver api to launch the kernel compiled by nvrtc. test=develop
* Disable for mac and windows. test=develop
* Refine the codes to support manually specified num_threads and workload_per_thread. test=develop
* Refine the CUDA kernel to support large dims. test=develop
sigmoid_bug
parent
3ae939e48a
commit
42b5bec6f9
@ -0,0 +1,123 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/platform/device_code.h"
|
||||
#include <algorithm>
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
// Tells whether an NVRTC call failed, i.e. returned anything other than
// NVRTC_SUCCESS.
inline bool is_error(nvrtcResult stat) {
  const bool succeeded = (stat == NVRTC_SUCCESS);
  return !succeeded;
}
|
||||
|
||||
// Reports a failed NVRTC call. The text NVRTC associates with `stat` is
// placed directly in front of `msg` (no separator is inserted).
//
// By default the combined text is thrown as std::runtime_error; when
// REPLACE_ENFORCE_GLOG is defined it is emitted through LOG(FATAL) instead.
inline void throw_on_error(nvrtcResult stat, const std::string& msg) {
  const char* nvrtc_reason = dynload::nvrtcGetErrorString(stat);
#ifdef REPLACE_ENFORCE_GLOG
  LOG(FATAL) << nvrtc_reason << msg;
#else
  throw std::runtime_error(nvrtc_reason + msg);
#endif
}
|
||||
|
||||
// Stores the target place, the kernel entry-point name and the CUDA source
// text. Nothing is compiled here; that happens in Compile().
//
// place:  where the kernel will run; rejected through PADDLE_THROW unless
//         is_gpu_place(place) holds.
// name:   name later handed to NVRTC and used to look up the function in the
//         loaded module.
// kernel: CUDA source code to be compiled at runtime.
CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name,
                               const std::string& kernel) {
  const bool on_gpu = is_gpu_place(place);
  if (!on_gpu) {
    PADDLE_THROW("CUDADeviceCode can only launch on GPU place.");
  }

  kernel_ = kernel;
  name_ = name;
  place_ = place;
}
|
||||
|
||||
void CUDADeviceCode::Compile() {
|
||||
nvrtcProgram program;
|
||||
PADDLE_ENFORCE_EQ(dynload::nvrtcCreateProgram(&program,
|
||||
kernel_.c_str(), // buffer
|
||||
name_.c_str(), // name
|
||||
0, // numHeaders
|
||||
nullptr, // headers
|
||||
nullptr), // includeNames
|
||||
NVRTC_SUCCESS,
|
||||
"nvrtcCreateProgram failed.");
|
||||
|
||||
// Compile the program for specified compute_capability
|
||||
auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
|
||||
DeviceContextPool::Instance().Get(place_));
|
||||
int compute_capability = dev_ctx->GetComputeCapability();
|
||||
std::string compute_flag =
|
||||
"--gpu-architecture=compute_" + std::to_string(compute_capability);
|
||||
const std::vector<const char*> options = {"--std=c++11",
|
||||
compute_flag.c_str()};
|
||||
nvrtcResult compile_result =
|
||||
dynload::nvrtcCompileProgram(program, // program
|
||||
options.size(), // numOptions
|
||||
options.data()); // options
|
||||
if (compile_result == NVRTC_ERROR_COMPILATION) {
|
||||
// Obtain compilation log from the program
|
||||
size_t log_size;
|
||||
PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLogSize(program, &log_size),
|
||||
NVRTC_SUCCESS, "nvrtcGetProgramLogSize failed.");
|
||||
std::vector<char> log;
|
||||
log.resize(log_size + 1);
|
||||
PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLog(program, log.data()),
|
||||
NVRTC_SUCCESS, "nvrtcGetProgramLog failed.");
|
||||
LOG(FATAL) << "JIT compiling of CUDA code failed:\n" << log.data();
|
||||
}
|
||||
|
||||
// Obtain PTX from the program
|
||||
size_t ptx_size;
|
||||
PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTXSize(program, &ptx_size), NVRTC_SUCCESS,
|
||||
"nvrtcGetPTXSize failed.");
|
||||
ptx_.resize(ptx_size + 1);
|
||||
PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTX(program, ptx_.data()), NVRTC_SUCCESS,
|
||||
"nvrtcGetPTX failed.");
|
||||
|
||||
PADDLE_ENFORCE_EQ(dynload::nvrtcDestroyProgram(&program), NVRTC_SUCCESS,
|
||||
"nvrtcDestroyProgram failed.");
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
dynload::cuModuleLoadData(&module_, ptx_.data()), CUDA_SUCCESS,
|
||||
"Fail to load PTX of %s (in cuModuleLoadData.)", name_.c_str());
|
||||
PADDLE_ENFORCE_EQ(
|
||||
dynload::cuModuleGetFunction(&function_, module_, name_.c_str()),
|
||||
CUDA_SUCCESS, "Fail to get function of %s (in cuModuleGetFunction.)",
|
||||
name_.c_str());
|
||||
|
||||
max_threads_ = dev_ctx->GetMaxPhysicalThreadCount();
|
||||
}
|
||||
|
||||
void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
|
||||
int max_blocks = std::max(max_threads_ / num_threads_, 1);
|
||||
int workload_per_block = workload_per_thread_ * num_threads_;
|
||||
int num_blocks =
|
||||
std::min(max_blocks, (static_cast<int>(n) + workload_per_block - 1) /
|
||||
workload_per_block);
|
||||
|
||||
auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
|
||||
DeviceContextPool::Instance().Get(place_));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
dynload::cuLaunchKernel(function_, num_blocks, 1, 1, // grid dim
|
||||
num_threads_, 1, 1, // block dim
|
||||
0, // shared memory
|
||||
dev_ctx->stream(), // stream
|
||||
args->data(), // arguments
|
||||
nullptr),
|
||||
CUDA_SUCCESS, "Fail to launch kernel %s (in cuLaunchKernel.)",
|
||||
name_.c_str());
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
@ -0,0 +1,64 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include "paddle/fluid/platform/dynload/cuda_driver.h"
|
||||
#include "paddle/fluid/platform/dynload/nvrtc.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
|
||||
class DeviceCode {
|
||||
public:
|
||||
virtual ~DeviceCode() {}
|
||||
virtual void Compile() = 0;
|
||||
virtual void Launch(const size_t n, std::vector<void*>* args) const = 0;
|
||||
|
||||
protected:
|
||||
Place place_;
|
||||
std::string name_;
|
||||
std::string kernel_;
|
||||
};
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
class CUDADeviceCode : public DeviceCode {
|
||||
public:
|
||||
explicit CUDADeviceCode(const Place& place, const std::string& name,
|
||||
const std::string& kernel);
|
||||
void Compile() override;
|
||||
void Launch(const size_t n, std::vector<void*>* args) const override;
|
||||
|
||||
void SetNumThreads(int num_threads) { num_threads_ = num_threads; }
|
||||
void SetWorkloadPerThread(int workload_per_thread) {
|
||||
workload_per_thread_ = workload_per_thread;
|
||||
}
|
||||
|
||||
private:
|
||||
int max_threads_{0};
|
||||
int num_threads_{1024};
|
||||
int workload_per_thread_{1};
|
||||
std::vector<char> ptx_;
|
||||
CUmodule module_;
|
||||
CUfunction function_;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
@ -0,0 +1,78 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/platform/device_code.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/platform/init.h"
|
||||
|
||||
// CUDA source of the kernel compiled at runtime by the test below. It computes
// z[i] = a * x[i] + y[i] for i in [0, n); each thread starts at its global
// index and advances by the total number of launched threads, so any grid
// size covers all n elements. extern "C" keeps the symbol name unmangled so
// that it can be looked up as "saxpy_kernel".
constexpr auto saxpy_code = R"(
extern "C" __global__
void saxpy_kernel(float a, float *x, float* y, float* z, size_t n) {
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
       tid += blockDim.x * gridDim.x) {
    z[tid] = a * x[tid] + y[tid];
  }
}
)";
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
// End-to-end check of CUDADeviceCode: compiles the SAXPY kernel at runtime,
// runs it on GPU 0 over a 256 x 1024 float tensor with x[i] = i, y[i] = 0.5
// and a = 2, then copies the result back and compares every element with
// i * a + 0.5.
TEST(device_code, cuda) {
  paddle::framework::InitDevices(false, {0});
  paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
  paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);

  const auto cpu_place = paddle::platform::CPUPlace();

  paddle::framework::Tensor cpu_x;
  paddle::framework::Tensor cpu_y;
  paddle::framework::Tensor cpu_z;

  float scale = 2;
  auto dims = paddle::framework::make_ddim(
      {static_cast<int64_t>(256), static_cast<int64_t>(1024)});
  float* host_x = cpu_x.mutable_data<float>(dims, cpu_place);
  float* host_y = cpu_y.mutable_data<float>(dims, cpu_place);

  // Fill both host inputs in a single pass.
  size_t n = cpu_x.numel();
  for (size_t idx = 0; idx < n; ++idx) {
    host_x[idx] = static_cast<float>(idx);
    host_y[idx] = static_cast<float>(0.5);
  }

  paddle::framework::Tensor x;
  paddle::framework::Tensor y;
  paddle::framework::Tensor z;

  float* x_data = x.mutable_data<float>(dims, place);
  float* y_data = y.mutable_data<float>(dims, place);
  float* z_data = z.mutable_data<float>(dims, place);

  TensorCopySync(cpu_x, place, &x);
  TensorCopySync(cpu_y, place, &y);

  code.Compile();

  // One pointer per kernel parameter, in declaration order.
  std::vector<void*> args = {&scale, &x_data, &y_data, &z_data, &n};
  code.SetNumThreads(1024);
  code.SetWorkloadPerThread(1);
  code.Launch(n, &args);

  TensorCopySync(z, cpu_place, &cpu_z);
  const float* host_z = cpu_z.data<float>();
  for (size_t idx = 0; idx < n; ++idx) {
    PADDLE_ENFORCE_EQ(host_z[idx], static_cast<float>(idx) * scale + 0.5);
  }
}
|
||||
#endif
|
@ -0,0 +1,30 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/platform/dynload/cuda_driver.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
namespace dynload {
|
||||
|
||||
std::once_flag cuda_dso_flag;
|
||||
void* cuda_dso_handle = nullptr;
|
||||
|
||||
#define DEFINE_WRAP(__name) DynLoad__##__name __name
|
||||
|
||||
CUDA_ROUTINE_EACH(DEFINE_WRAP);
|
||||
|
||||
} // namespace dynload
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
@ -0,0 +1,79 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <mutex> // NOLINT
|
||||
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
|
||||
#include "paddle/fluid/platform/port.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
namespace dynload {
|
||||
|
||||
// State shared by all CUDA driver wrappers below: the once-flag guarding the
// library lookup and the resulting library handle.
extern std::once_flag cuda_dso_flag;
extern void* cuda_dso_handle;

#ifdef PADDLE_USE_DSO

// Declares a functor type DynLoad__<name> and an extern object <name> of that
// type. Calling the object obtains the CUDA driver library handle once (via
// GetCUDADsoHandle), resolves the symbol <name> with dlsym the first time this
// particular instantiation runs, and forwards the arguments to it.
// NOTE(review): the dlsym result is not checked; a missing symbol would be
// called through a null pointer.
#define DECLARE_DYNAMIC_LOAD_CUDA_WRAP(__name)                           \
  struct DynLoad__##__name {                                             \
    template <typename... Args>                                          \
    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {     \
      using cuda_func = decltype(&::__name);                             \
      std::call_once(cuda_dso_flag, []() {                               \
        cuda_dso_handle = paddle::platform::dynload::GetCUDADsoHandle(); \
      });                                                                \
      static void* p_##__name = dlsym(cuda_dso_handle, #__name);         \
      return reinterpret_cast<cuda_func>(p_##__name)(args...);           \
    }                                                                    \
  };                                                                     \
  extern struct DynLoad__##__name __name

#else

// Without PADDLE_USE_DSO the wrapper forwards straight to the driver function
// the program is linked against. The trailing return type mirrors the DSO
// variant above and keeps the declaration valid C++11 (return type deduction
// without it requires C++14).
#define DECLARE_DYNAMIC_LOAD_CUDA_WRAP(__name)  \
  struct DynLoad__##__name {                    \
    template <typename... Args>                 \
    inline auto operator()(Args... args)        \
        -> DECLARE_TYPE(__name, args...) {      \
      return ::__name(args...);                 \
    }                                           \
  };                                            \
  extern DynLoad__##__name __name

#endif
|
||||
|
||||
/**
|
||||
* include all needed cuda driver functions
|
||||
**/
|
||||
#define CUDA_ROUTINE_EACH(__macro) \
|
||||
__macro(cuGetErrorString); \
|
||||
__macro(cuModuleLoadData); \
|
||||
__macro(cuModuleGetFunction); \
|
||||
__macro(cuModuleUnload); \
|
||||
__macro(cuOccupancyMaxActiveBlocksPerMultiprocessor); \
|
||||
__macro(cuLaunchKernel); \
|
||||
__macro(cuCtxCreate); \
|
||||
__macro(cuCtxGetCurrent); \
|
||||
__macro(cuDeviceGet); \
|
||||
__macro(cuDevicePrimaryCtxGetState)
|
||||
|
||||
CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
|
||||
|
||||
#undef DECLARE_DYNAMIC_LOAD_CUDA_WRAP
|
||||
|
||||
} // namespace dynload
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
@ -0,0 +1,30 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/platform/dynload/nvrtc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
namespace dynload {
|
||||
|
||||
std::once_flag nvrtc_dso_flag;
|
||||
void* nvrtc_dso_handle = nullptr;
|
||||
|
||||
#define DEFINE_WRAP(__name) DynLoad__##__name __name
|
||||
|
||||
NVRTC_ROUTINE_EACH(DEFINE_WRAP);
|
||||
|
||||
} // namespace dynload
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
@ -0,0 +1,77 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <nvrtc.h>
|
||||
#include <mutex> // NOLINT
|
||||
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
|
||||
#include "paddle/fluid/platform/port.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
namespace dynload {
|
||||
|
||||
// State shared by all NVRTC wrappers below: the once-flag guarding the
// library lookup and the resulting library handle.
extern std::once_flag nvrtc_dso_flag;
extern void* nvrtc_dso_handle;

#ifdef PADDLE_USE_DSO

// Declares a functor type DynLoad__<name> and an extern object <name> of that
// type. Calling the object obtains the NVRTC library handle once (via
// GetNVRTCDsoHandle), resolves the symbol <name> with dlsym the first time
// this particular instantiation runs, and forwards the arguments to it.
// NOTE(review): the dlsym result is not checked; a missing symbol would be
// called through a null pointer.
#define DECLARE_DYNAMIC_LOAD_NVRTC_WRAP(__name)                            \
  struct DynLoad__##__name {                                               \
    template <typename... Args>                                            \
    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {       \
      using nvrtc_func = decltype(&::__name);                              \
      std::call_once(nvrtc_dso_flag, []() {                                \
        nvrtc_dso_handle = paddle::platform::dynload::GetNVRTCDsoHandle(); \
      });                                                                  \
      static void* p_##__name = dlsym(nvrtc_dso_handle, #__name);          \
      return reinterpret_cast<nvrtc_func>(p_##__name)(args...);            \
    }                                                                      \
  };                                                                       \
  extern struct DynLoad__##__name __name

#else

// Without PADDLE_USE_DSO the wrapper forwards straight to the NVRTC function
// the program is linked against. The trailing return type mirrors the DSO
// variant above and keeps the declaration valid C++11 (return type deduction
// without it requires C++14).
#define DECLARE_DYNAMIC_LOAD_NVRTC_WRAP(__name) \
  struct DynLoad__##__name {                    \
    template <typename... Args>                 \
    inline auto operator()(Args... args)        \
        -> DECLARE_TYPE(__name, args...) {      \
      return ::__name(args...);                 \
    }                                           \
  };                                            \
  extern DynLoad__##__name __name

#endif
|
||||
|
||||
/**
|
||||
* include all needed nvrtc functions
|
||||
**/
|
||||
#define NVRTC_ROUTINE_EACH(__macro) \
|
||||
__macro(nvrtcGetErrorString); \
|
||||
__macro(nvrtcCompileProgram); \
|
||||
__macro(nvrtcCreateProgram); \
|
||||
__macro(nvrtcDestroyProgram); \
|
||||
__macro(nvrtcGetPTX); \
|
||||
__macro(nvrtcGetPTXSize); \
|
||||
__macro(nvrtcGetProgramLog); \
|
||||
__macro(nvrtcGetProgramLogSize)
|
||||
|
||||
NVRTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NVRTC_WRAP);
|
||||
|
||||
#undef DECLARE_DYNAMIC_LOAD_NVRTC_WRAP
|
||||
|
||||
} // namespace dynload
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
Loading…
Reference in new issue