|
|
|
@ -22,29 +22,6 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace platform {
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext>
|
|
|
|
|
struct ForRangeIn {
|
|
|
|
|
ForRangeIn(const DeviceContext& dev_ctx, std::vector<int64_t> range);
|
|
|
|
|
|
|
|
|
|
template <typename Function>
|
|
|
|
|
void operator()(Function func) const;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ForRangeIn<CPUDeviceContext> {
|
|
|
|
|
ForRangeIn(const CPUDeviceContext& dev_ctx, std::vector<int64_t> range)
|
|
|
|
|
: range_(range) {}
|
|
|
|
|
|
|
|
|
|
template <typename Function>
|
|
|
|
|
void operator()(Function func) const {
|
|
|
|
|
for (auto i : range_) {
|
|
|
|
|
func(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> range_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext>
|
|
|
|
|
struct ForRange {
|
|
|
|
|
ForRange(const DeviceContext& dev_ctx, size_t limit);
|
|
|
|
@ -106,35 +83,6 @@ struct ForRange<CUDADeviceContext> {
|
|
|
|
|
int limit_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T, typename Function>
|
|
|
|
|
__global__ static void ForRangeInElemwiseOp(Function func, T* vector,
|
|
|
|
|
int vector_size) {
|
|
|
|
|
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
|
|
|
|
|
if (idx < vector_size) {
|
|
|
|
|
func(vector[idx]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ForRangeIn<CUDADeviceContext> {
|
|
|
|
|
ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
|
|
|
|
|
: dev_ctx_(dev_ctx), range_(range) {}
|
|
|
|
|
|
|
|
|
|
template <typename Function>
|
|
|
|
|
inline void operator()(Function func) const {
|
|
|
|
|
constexpr int num_threads = 1024;
|
|
|
|
|
int range_size = range_.size();
|
|
|
|
|
int block_size = range_size <= num_threads ? range_size : num_threads;
|
|
|
|
|
int grid_size = (range_.size() + num_threads - 1) / num_threads;
|
|
|
|
|
|
|
|
|
|
ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
|
|
|
|
|
func, range_.CUDAData(dev_ctx_.GetPlace()), range_size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const CUDADeviceContext& dev_ctx_;
|
|
|
|
|
framework::Vector<int64_t> range_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
} // namespace platform
|
|
|
|
|