|
|
|
@ -199,6 +199,10 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiations of ElewiseCmp for integer element types. No template
// arguments are spelled out; they follow from the pointer types of x0/x1. In
// every instantiation below the output pointer y is bool*.
// int inputs.
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, bool *y,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// int8_t inputs.
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int8_t *x0, const int8_t *x1, bool *y,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// uint8_t inputs.
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, bool *y,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
// Element-wise arithmetic
|
|
|
|
|
template <typename T, typename Func>
|
|
|
|
@ -261,6 +265,10 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiations of ElewiseArith for integer element types. No
// template arguments are spelled out; they follow from the pointer types. In
// every instantiation below the output pointer y has the same element type as
// the inputs x0/x1.
// int inputs and output.
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, int *y,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// int8_t inputs and output.
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int8_t *x0, const int8_t *x1, int8_t *y,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// uint8_t inputs and output.
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, uint8_t *y,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
// Broadcast comparison
|
|
|
|
|
// Picks the position to use along a single axis: an axis whose extent (dim)
// is exactly 1 always yields 0, every other extent passes index through
// unchanged. NOTE(review): presumably this is what lets a size-1 axis be
// reused for every output position when broadcasting - confirm against callers.
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) {
  if (dim != 1) {
    return index;
  }
  return 0;
}
|
|
|
|
@ -333,6 +341,12 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
|
|
|
|
|
// Explicit instantiations of BroadcastCmp for integer element types. Each takes
// the dimension vectors of both inputs (x0_dims, x1_dims) and of the output
// (y_dims); the output pointer y is bool* in every instantiation below.
// int inputs.
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
|
|
|
|
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
|
|
|
|
|
bool *y, cudaStream_t stream);
|
|
|
|
|
// int8_t inputs.
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
|
|
|
|
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int8_t *x0,
|
|
|
|
|
const int8_t *x1, bool *y, cudaStream_t stream);
|
|
|
|
|
// uint8_t inputs.
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
|
|
|
|
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
|
|
|
|
|
const uint8_t *x1, bool *y, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
// Broadcast Arithmetic
|
|
|
|
|
template <typename T, typename Func>
|
|
|
|
@ -448,6 +462,12 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
|
|
|
|
|
// Explicit instantiations of BroadcastArith for integer element types. Each
// takes the dimension vectors of both inputs (x0_dims, x1_dims) and of the
// output (y_dims); the output pointer y has the same element type as x0/x1.
// int inputs and output.
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
|
|
|
|
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
|
|
|
|
|
int *y, cudaStream_t stream);
|
|
|
|
|
// int8_t inputs and output.
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
|
|
|
|
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int8_t *x0,
|
|
|
|
|
const int8_t *x1, int8_t *y, cudaStream_t stream);
|
|
|
|
|
// uint8_t inputs and output.
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
|
|
|
|
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
|
|
|
|
|
const uint8_t *x1, uint8_t *y, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
// BroadcastTo
|
|
|
|
|
template <typename T>
|
|
|
|
|