|
|
|
@ -28,35 +28,35 @@ __device__ __forceinline__ void CastBase(const S *input_addr, T *output_addr) {
|
|
|
|
|
|
|
|
|
|
// half --> integer
|
|
|
|
|
// Casts one half-precision value to uint64_t.
//   input_addr:  pointer to the half value to read.
//   output_addr: pointer that receives the converted value.
// The conversion uses the round-down intrinsic. Only this single store is
// kept: a preceding round-to-nearest store to the same address would be
// overwritten immediately and has no observable effect.
__device__ __forceinline__ void CastBase(const half *input_addr, uint64_t *output_addr) {
  *output_addr = __half2ull_rd(*input_addr);
}
|
|
|
|
|
|
|
|
|
|
// Casts one half-precision value to int64_t.
//   input_addr:  pointer to the half value to read.
//   output_addr: pointer that receives the converted value.
// The conversion uses the round-down intrinsic. Only this single store is
// kept: a preceding round-to-nearest store to the same address would be
// overwritten immediately and has no observable effect.
__device__ __forceinline__ void CastBase(const half *input_addr, int64_t *output_addr) {
  *output_addr = __half2ll_rd(*input_addr);
}
|
|
|
|
|
|
|
|
|
|
// Casts one half-precision value to uint32_t.
//   input_addr:  pointer to the half value to read.
//   output_addr: pointer that receives the converted value.
// The conversion uses the round-down intrinsic. Only this single store is
// kept: a preceding round-to-nearest store to the same address would be
// overwritten immediately and has no observable effect.
__device__ __forceinline__ void CastBase(const half *input_addr, uint32_t *output_addr) {
  *output_addr = __half2uint_rd(*input_addr);
}
|
|
|
|
|
|
|
|
|
|
// Casts one half-precision value to int32_t.
//   input_addr:  pointer to the half value to read.
//   output_addr: pointer that receives the converted value.
// The conversion uses the round-down intrinsic. Only this single store is
// kept: a preceding round-to-nearest store to the same address would be
// overwritten immediately and has no observable effect.
__device__ __forceinline__ void CastBase(const half *input_addr, int32_t *output_addr) {
  *output_addr = __half2int_rd(*input_addr);
}
|
|
|
|
|
|
|
|
|
|
// Casts one half-precision value to uint16_t.
//   input_addr:  pointer to the half value to read.
//   output_addr: pointer that receives the converted value.
// The conversion uses the round-down intrinsic. Only this single store is
// kept: a preceding round-to-nearest store to the same address would be
// overwritten immediately and has no observable effect.
__device__ __forceinline__ void CastBase(const half *input_addr, uint16_t *output_addr) {
  *output_addr = __half2ushort_rd(*input_addr);
}
|
|
|
|
|
|
|
|
|
|
// Casts one half-precision value to int16_t.
//   input_addr:  pointer to the half value to read.
//   output_addr: pointer that receives the converted value.
// The conversion uses the round-down intrinsic. Only this single store is
// kept: a preceding round-to-nearest store to the same address would be
// overwritten immediately and has no observable effect.
__device__ __forceinline__ void CastBase(const half *input_addr, int16_t *output_addr) {
  *output_addr = __half2short_rd(*input_addr);
}
|
|
|
|
|
|
|
|
|
|
// Casts one half-precision value to uint8_t.
//   input_addr:  pointer to the half value to read.
//   output_addr: pointer that receives the converted value.
// The value is first converted to unsigned short with the round-down
// intrinsic and then narrowed by static_cast; the cast itself does not
// saturate to the 8-bit range. Only this single store is kept: a preceding
// round-to-nearest store to the same address would be overwritten
// immediately and has no observable effect.
__device__ __forceinline__ void CastBase(const half *input_addr, uint8_t *output_addr) {
  const unsigned short widened = __half2ushort_rd(*input_addr);
  *output_addr = static_cast<uint8_t>(widened);
}
|
|
|
|
|
|
|
|
|
|
// Casts one half-precision value to int8_t.
//   input_addr:  pointer to the half value to read.
//   output_addr: pointer that receives the converted value.
// The value is first converted to short with the round-down intrinsic and
// then narrowed by static_cast; the cast itself does not saturate to the
// 8-bit range. Only this single store is kept: a preceding round-to-nearest
// store to the same address would be overwritten immediately and has no
// observable effect.
__device__ __forceinline__ void CastBase(const half *input_addr, int8_t *output_addr) {
  const short widened = __half2short_rd(*input_addr);
  *output_addr = static_cast<int8_t>(widened);
}
|
|
|
|
|
|
|
|
|
|
// integer --> half
|
|
|
|
|