|
|
|
@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "RowConvOp.h"
|
|
|
|
|
#include "hl_base.h"
|
|
|
|
|
#include "paddle/cuda/include/hl_base.h"
|
|
|
|
|
#include "paddle/function/RowConvOp.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
@ -94,7 +94,7 @@ __global__ void KeRowConv2(real* y,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
|
|
|
|
|
void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out, // NOLINT
|
|
|
|
|
const GpuMatrix& in,
|
|
|
|
|
const GpuMatrix& filter,
|
|
|
|
|
const GpuIVector& seq) {
|
|
|
|
@ -144,6 +144,10 @@ __global__ void KeRowConvBwWeight(real* dw,
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
// NOTE(zcd): temporary solution
|
|
|
|
|
unsigned mask = 0u;
|
|
|
|
|
CREATE_SHFL_MASK(mask, true);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < numSeq; ++i) {
|
|
|
|
|
const int start = starts[i];
|
|
|
|
|
const int end = starts[i + 1];
|
|
|
|
@ -170,11 +174,10 @@ __global__ void KeRowConvBwWeight(real* dw,
|
|
|
|
|
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t];
|
|
|
|
|
__syncthreads();
|
|
|
|
|
// Warp-level sum: assumes warp size and blockDim.x are both 32.
|
|
|
|
|
val += __shfl_down(val, 16);
|
|
|
|
|
val += __shfl_down(val, 8);
|
|
|
|
|
val += __shfl_down(val, 4);
|
|
|
|
|
val += __shfl_down(val, 2);
|
|
|
|
|
val += __shfl_down(val, 1);
|
|
|
|
|
|
|
|
|
|
for (int offset = 16; offset > 0; offset /= 2)
|
|
|
|
|
val += __shfl_down_sync(mask, val, offset);
|
|
|
|
|
|
|
|
|
|
__syncthreads();
|
|
|
|
|
if (tidx == 0) {
|
|
|
|
|
sh_dw[t][tidy] += val;
|
|
|
|
@ -205,6 +208,10 @@ __global__ void KeRowConvBwWeight2(real* dw,
|
|
|
|
|
__shared__ real sh_x[BLOCK_H][BLOCK_W];
|
|
|
|
|
__shared__ real sh_dy[BLOCK_H][BLOCK_W];
|
|
|
|
|
|
|
|
|
|
// NOTE(zcd): temporary solution
|
|
|
|
|
unsigned mask = 0u;
|
|
|
|
|
CREATE_SHFL_MASK(mask, true);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < numSeq; ++i) {
|
|
|
|
|
const int start = starts[i];
|
|
|
|
|
const int end = starts[i + 1];
|
|
|
|
@ -230,11 +237,9 @@ __global__ void KeRowConvBwWeight2(real* dw,
|
|
|
|
|
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
|
|
|
|
|
__syncthreads();
|
|
|
|
|
// Warp-level sum: assumes warp size and blockDim.x are both 32.
|
|
|
|
|
val += __shfl_down(val, 16);
|
|
|
|
|
val += __shfl_down(val, 8);
|
|
|
|
|
val += __shfl_down(val, 4);
|
|
|
|
|
val += __shfl_down(val, 2);
|
|
|
|
|
val += __shfl_down(val, 1);
|
|
|
|
|
for (int offset = 16; offset > 0; offset /= 2)
|
|
|
|
|
val += __shfl_down_sync(mask, val, offset);
|
|
|
|
|
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
if (tidx == 0 && (gidx + tidy) < width) {
|
|
|
|
@ -323,8 +328,8 @@ template <>
|
|
|
|
|
void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
|
|
|
|
|
const GpuMatrix& in,
|
|
|
|
|
const GpuMatrix& filter,
|
|
|
|
|
GpuMatrix& inG,
|
|
|
|
|
GpuMatrix& filterG,
|
|
|
|
|
GpuMatrix& inG, // NOLINT
|
|
|
|
|
GpuMatrix& filterG, // NOLINT
|
|
|
|
|
const GpuIVector& seq) {
|
|
|
|
|
const size_t numSeq = seq.getSize() - 1;
|
|
|
|
|
const size_t contextLength = filter.getHeight();
|
|
|
|
|