fix compile error

for_weibo
Qiao Longfei 6 years ago
parent fcde2b2725
commit 763e8fdf02

@ -117,17 +117,18 @@ __global__ static void ForRangeInElemwiseOp(Function func, T* vector,
template <> template <>
struct ForRangeIn<CUDADeviceContext> { struct ForRangeIn<CUDADeviceContext> {
ForRange(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range) ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
: dev_ctx_(dev_ctx), range_(range) {} : dev_ctx_(dev_ctx), range_(range) {}
template <typename Function> template <typename Function>
inline void operator()(Function func) const { inline void operator()(Function func) const {
constexpr int num_threads = 1024; constexpr int num_threads = 1024;
int block_size = range_.size() <= num_threads ? limit_ : num_threads; int range_size = range_.size();
int block_size = range_size <= num_threads ? range_size : num_threads;
int grid_size = (range_.size() + num_threads - 1) / num_threads; int grid_size = (range_.size() + num_threads - 1) / num_threads;
ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>( ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
func, range_.data(), range_.size()); func, range_.data(), range_size);
} }
const CUDADeviceContext& dev_ctx_; const CUDADeviceContext& dev_ctx_;

Loading…
Cancel
Save