|
|
|
@ -87,6 +87,14 @@ struct FloorDivFunc<half, bool> {
|
|
|
|
|
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Element-wise functor for the gradient of abs().
//
// Template parameters:
//   T - type of both operands.
//   S - type of the returned value; the selected value is converted to S on return.
//
// operator()(lhs, rhs):
//   Returns -rhs when lhs compares strictly less than zero, otherwise rhs
//   (so lhs == 0 passes rhs through unchanged).
//   NOTE(review): presumably lhs is the forward input and rhs the incoming
//   gradient -- confirm against the caller.
template <typename T, typename S>
struct AbsGradFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
    const T zero = 0.0;
    // Only strictly negative inputs flip the sign of rhs.
    const bool lhs_is_negative = lhs < zero;
    return lhs_is_negative ? -rhs : rhs;
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct PowerFunc<half, bool> {
|
|
|
|
@ -149,6 +157,9 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const
|
|
|
|
|
case BROADCAST_TYPE_FLOORDIV:
|
|
|
|
|
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
|
|
|
|
output);
|
|
|
|
|
case BROADCAST_TYPE_ABSGRAD:
  // ABSGRAD must be evaluated with AbsGradFunc (sign-selects rhs based on lhs).
  // Dispatching to AddFunc here would silently compute lhs + rhs instead.
  return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                    output);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -192,6 +203,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
|
|
|
|
|
return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output);
|
|
|
|
|
case BROADCAST_TYPE_FLOORDIV:
|
|
|
|
|
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
|
|
|
|
|
case BROADCAST_TYPE_ABSGRAD:
  // ABSGRAD must be evaluated with AbsGradFunc (sign-selects rhs based on lhs).
  // Dispatching to FloorDivFunc here would silently compute a floor division instead.
  return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|