@@ -117,17 +117,18 @@ __global__ static void ForRangeInElemwiseOp(Function func, T* vector,
 
 template <>
 struct ForRangeIn<CUDADeviceContext> {
-  ForRange(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
+  ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
       : dev_ctx_(dev_ctx), range_(range) {}
 
   template <typename Function>
   inline void operator()(Function func) const {
     constexpr int num_threads = 1024;
-    int block_size = range_.size() <= num_threads ? limit_ : num_threads;
+    int range_size = range_.size();
+    int block_size = range_size <= num_threads ? range_size : num_threads;
     int grid_size = (range_.size() + num_threads - 1) / num_threads;
 
     ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
-        func, range_.data(), range_.size());
+        func, range_.data(), range_size);
   }
 
   const CUDADeviceContext& dev_ctx_;
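The removed line referenced `limit_`, which is not a member of this struct; the replacement derives the block size from `range_size` instead. The added lines are the standard CUDA launch-size pattern: cap the threads per block at 1024 and round the block count up with ceiling division so the tail elements still get a thread. For reference, a minimal standalone sketch of that pattern under the same arithmetic (the kernel name ScaleInPlace and all host-side names are hypothetical, not Paddle code):

// Standalone sketch of the cap-and-ceil launch arithmetic used above.
// ScaleInPlace and every host-side name here are illustrative only.
#include <cstdint>
#include <cstdio>
#include <vector>

__global__ void ScaleInPlace(float* data, const int64_t* indices, int n,
                             float factor) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {  // guard: the last block may be only partially filled
    data[indices[idx]] *= factor;
  }
}

int main() {
  constexpr int num_threads = 1024;
  std::vector<int64_t> h_indices = {0, 2, 4, 6};
  std::vector<float> h_data(8, 1.0f);
  int n = static_cast<int>(h_indices.size());

  int64_t* d_indices = nullptr;
  float* d_data = nullptr;
  cudaMalloc(&d_indices, n * sizeof(int64_t));
  cudaMalloc(&d_data, h_data.size() * sizeof(float));
  cudaMemcpy(d_indices, h_indices.data(), n * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_data, h_data.data(), h_data.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  // Same arithmetic as the patched operator(): at most 1024 threads per
  // block, and the block count rounded up rather than truncated.
  int block_size = n <= num_threads ? n : num_threads;
  int grid_size = (n + num_threads - 1) / num_threads;
  ScaleInPlace<<<grid_size, block_size>>>(d_data, d_indices, n, 2.0f);

  cudaMemcpy(h_data.data(), d_data, h_data.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("data[2] = %.1f\n", h_data[2]);  // prints 2.0
  cudaFree(d_indices);
  cudaFree(d_data);
  return 0;
}

Hoisting `range_.size()` into `range_size` also keeps the block size, the grid size, and the kernel argument computed from a single value, instead of calling `size()` three times.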