|
|
|
@ -42,6 +42,19 @@ struct PowerFunc {
|
|
|
|
|
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct PowerFunc<half, half> {
|
|
|
|
|
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
|
|
|
|
|
return __float2half(pow(__half2float(lhs), __half2float(rhs)));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct PowerFunc<half, bool> {
|
|
|
|
|
// invalid branch
|
|
|
|
|
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
|
|
|
|
|
|
|
|
|
|
template <typename T, typename S, typename Func>
|
|
|
|
@ -131,8 +144,20 @@ template void Broadcast(const int &l0, const int &l1, const int &l2, const int &
|
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
|
enum BroadcastOpType op, const float *input0, const float *input1, float *output,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
|
enum BroadcastOpType op, const half *input0, const half *input1, bool *output,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
|
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
|
|
|
|
enum BroadcastOpType op, const half *input0, const half *input1, half *output,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
|
|
|
|
|
bool *output, cudaStream_t stream);
|
|
|
|
|
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
|
|
|
|
|
float *output, cudaStream_t stream);
|
|
|
|
|
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
|
|
|
|
|
bool *output, cudaStream_t stream);
|
|
|
|
|
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
|
|
|
|
|
half *output, cudaStream_t stream);
|
|
|
|
|