|
|
|
@ -22,15 +22,22 @@ template <typename T>
|
|
|
|
|
__global__ void L2LossKernel(const size_t input_size, const T *input , T *output) {
  // Accumulates the L2 loss contribution, sum(input[i]^2) / 2, into *output.
  // output[0] must be zeroed before launch (see ClearOutputMem); each element
  // folds in via an atomic add, so the kernel works for any grid/block shape.
  //
  // Grid-stride loop: cast the first factor to size_t so the global index and
  // the stride are computed in 64 bits (the raw 32-bit products can overflow
  // for very large launches).
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t id = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; id < input_size; id += stride) {
    // Per-element contribution: x^2 / 2.
    T ret = input[id] * input[id];
    ret /= static_cast<T>(2);
    // MsAtomicAdd: project-provided atomic-add wrapper (handles types plain
    // atomicAdd does not). One atomic per element; contention is on a single
    // scalar accumulator.
    MsAtomicAdd(output, ret);
  }
  return;
}
|
|
|
|
|
|
|
|
|
|
// Zeroes the single-element accumulator so a subsequent reduction kernel can
// atomically add into it. Safe to launch with any number of threads: every
// participating thread stores the same value to the same slot, so redundant
// writes are harmless.
template <typename T>
__global__ void ClearOutputMem(T *output) {
  *output = static_cast<T>(0);
}
|
|
|
|
|
|
|
|
|
|
// Host-side launcher for the L2 loss reduction on `stream`:
// first resets the scalar accumulator at output[0], then folds
// sum(input[i]^2) / 2 into it over all `input_size` elements.
// Both launches are asynchronous with respect to the host.
template <typename T>
void L2Loss(const size_t input_size, const T *input , T *output, cudaStream_t stream) {
  // One-element clear: a minimal grid is enough.
  ClearOutputMem<T><<<GET_BLOCKS(1), GET_THREADS, 0, stream>>>(output);
  // Stream-ordered after the clear, so the accumulator starts at zero.
  L2LossKernel<T><<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input, output);
}
|
|
|
|
|
|
|
|
|
|