@@ -29,45 +29,71 @@
 namespace paddle {
 namespace platform {
 
 // Transform on host or device. It provides the same API in std library.
-template <typename InputIter, typename OutputIter, typename UnaryOperation>
-void Transform(const DeviceContext& context, InputIter first, InputIter last,
-               OutputIter result, UnaryOperation op) {
-  auto place = context.GetPlace();
-  if (is_cpu_place(place)) {
+template <typename Place>
+struct Transform {
+  template <typename InputIter, typename OutputIter, typename UnaryOperation>
+  void operator()(const DeviceContext& context, InputIter first, InputIter last,
+                  OutputIter result, UnaryOperation op);
+
+  template <typename InputIter1, typename InputIter2, typename OutputIter,
+            typename BinaryOperation>
+  void operator()(const DeviceContext& context, InputIter1 first1,
+                  InputIter1 last1, InputIter2 first2, OutputIter result,
+                  BinaryOperation op);
+};
+
+template <>
+struct Transform<platform::CPUPlace> {
+  template <typename InputIter, typename OutputIter, typename UnaryOperation>
+  void operator()(const DeviceContext& context, InputIter first, InputIter last,
+                  OutputIter result, UnaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
     std::transform(first, last, result, op);
-  } else {
-#ifdef __NVCC__
-    auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
-    using namespace details;
-    thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first),
-                      DevPtrCast(last), DevPtrCast(result), op);
-#else
-    PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
-#endif
   }
-}
 
-template <typename InputIter1, typename InputIter2, typename OutputIter,
-          typename BinaryOperation>
-void Transform(const DeviceContext& context, InputIter1 first1,
-               InputIter1 last1, InputIter2 first2, OutputIter result,
-               BinaryOperation op) {
-  auto place = context.GetPlace();
-  if (is_cpu_place(place)) {
+  template <typename InputIter1, typename InputIter2, typename OutputIter,
+            typename BinaryOperation>
+  void operator()(const DeviceContext& context, InputIter1 first1,
+                  InputIter1 last1, InputIter2 first2, OutputIter result,
+                  BinaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
     std::transform(first1, last1, first2, result, op);
-  } else {
+  }
+};
+
 #ifdef __NVCC__
+template <>
+struct Transform<platform::GPUPlace> {
+  template <typename InputIter, typename OutputIter, typename UnaryOperation>
+  void operator()(const DeviceContext& context, InputIter first, InputIter last,
+                  OutputIter result, UnaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
     auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
-    using namespace details;
-    thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first1),
-                      DevPtrCast(last1), DevPtrCast(first2), DevPtrCast(result),
+    thrust::transform(thrust::cuda::par.on(ctx.stream()),
+                      details::DevPtrCast(first), details::DevPtrCast(last),
+                      details::DevPtrCast(result), op);
+  }
+
+  template <typename InputIter1, typename InputIter2, typename OutputIter,
+            typename BinaryOperation>
+  void operator()(const DeviceContext& context, InputIter1 first1,
+                  InputIter1 last1, InputIter2 first2, OutputIter result,
+                  BinaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
+    auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
+    thrust::transform(thrust::cuda::par.on(ctx.stream()),
+                      details::DevPtrCast(first1), details::DevPtrCast(last1),
+                      details::DevPtrCast(first2), details::DevPtrCast(result),
                       op);
-#else
-    PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
-#endif
   }
-}
+};
+#endif
 
 }  // namespace platform
 }  // namespace paddle
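For illustration only (not part of the patch): after this change, callers select the specialization by place and invoke the functor rather than the old overloaded free function. A minimal sketch of the CPU path, assuming the paddle/platform/transform.h and device_context.h headers from this tree; the helper name ScaleOnCpu and the sample data are made up for the example:

#include "paddle/platform/device_context.h"
#include "paddle/platform/transform.h"

void ScaleOnCpu() {
  using paddle::platform::CPUDeviceContext;
  using paddle::platform::CPUPlace;
  using paddle::platform::Transform;

  float buf[4] = {0.1f, 0.2f, 0.3f, 0.4f};
  // A default-constructed CPUDeviceContext reports a CPU place from GetPlace(),
  // so the PADDLE_ENFORCE(is_cpu_place(place), ...) check above passes.
  CPUDeviceContext ctx;

  // Unary form: scale every element by 10 in place.
  Transform<CPUPlace> trans;
  trans(ctx, buf, buf + 4, buf, [](float v) { return v * 10.0f; });
}

The GPU path is the same call shape with Transform<platform::GPUPlace> and a CUDADeviceContext, but it must be compiled in a .cu file so the thrust branch under __NVCC__ is available.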