add cast cuda kernel (#29352)

revert-31562-mean
Zhang Ting 5 years ago committed by GitHub
parent c1a26e2a05
commit 30d9589afe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,6 +14,39 @@ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
template <typename InT, typename OutT>
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
}
template <typename InT>
struct CastOpFunctor<platform::CUDADeviceContext, InT> {
const framework::Tensor* in_;
framework::Tensor* out_;
const platform::CUDADeviceContext& ctx_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const platform::CUDADeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto* in = in_->data<InT>();
auto size = in_->numel();
auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx_, size);
CastCUDAKernel<InT, OutT><<<config.block_per_grid, config.thread_per_block,
0, ctx_.stream()>>>(in, size, out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

Loading…
Cancel
Save