conv shift: fix return before syncthreads

updateWriteDocsCN
Markus Kliegl 7 years ago
parent 3dc8834209
commit 42dd5da0fd

@ -62,11 +62,10 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
if (tx < num_x) {
int load_i = (i - y_half_width + x_width) % x_width;
sx[tx] = x[k * x_width + load_i];
} else {
return;
}
__syncthreads();
if (tx < num_x) {
// Compute dot product of sx[tx:tx + y_width] and sy.
T sum = 0;
for (int j = 0; j < y_width; ++j) {
@ -75,6 +74,7 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
// Save to out[k, i].
out[k * x_width + i] = sum;
}
}
// Compute x gradient - initial naive implementation with atomic add.

Loading…
Cancel
Save