Conv Shift Operator (#4591)
* conv_shift_op: initial implementation using Eigen.
  Limitations:
  - both gradient outputs must be specified and are always computed
  - explicit for loops => could be optimized in various ways (e.g., different memory layout)
* conv shift: gradient fixes; fix the case when not all output gradients are desired
* conv shift: minor cleanup
* conv shift: more minor cleanup
* conv shift: clean up & initial GPU implementation
* fix rebase issue
parent 7973d3a0ad
commit a281b38393
@@ -0,0 +1,206 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/conv_shift_op.h"
#include "paddle/framework/eigen.h"

namespace paddle {
namespace operators {

using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

class ConvShiftOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
                      "The 1st dimension of Input(X) and Input(Y) should "
                      "be equal.");
    PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
                      "The 2nd dimension of Input(Y) should be odd.");
    PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
                      "The 2nd dimension of Input(Y) should be less than or "
                      "equal to the 2nd dimension of Input(X).");
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class ConvShiftGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should be not null.");

    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      auto x_dims = ctx->GetInputDim("X");
      ctx->SetOutputDim(x_grad_name, x_dims);
    }

    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(y_grad_name)) {
      auto y_dims = ctx->GetInputDim("Y");
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }
};

class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ConvShiftOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
             "where B is the batch size and M is the data dimension.");
    AddInput("Y",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape B x N, "
             "where B is the batch size and N is the data dimension. N must "
             "be odd.");
    AddOutput("Out",
              "(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
              "i.e., the same shape as X.");
    AddComment(R"DOC(
ConvShift Operator.

A layer for circular convolution of two vectors,
as used in the Neural Turing Machine: https://arxiv.org/abs/1410.5401

The equation is:

\f[
  Out[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} X_{i+j} * Y_{j}
\f]

where X's index is computed modulo M, and Y's index is computed modulo N.

Both of the inputs `X` and `Y` can carry LoD (Level of Details) information.
However, the output only shares the LoD information with input `X`.
)DOC");
  }
};

template <typename T>
class ConvShiftKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *X = context.Input<Tensor>("X");
    auto *Y = context.Input<Tensor>("Y");
    auto *Out = context.Output<Tensor>("Out");
    Out->mutable_data<T>(context.GetPlace());

    auto x = EigenMatrix<T>::From(*X);
    auto y = EigenMatrix<T>::From(*Y);
    auto out = EigenMatrix<T>::From(*Out);
    out.setZero();

    size_t batch_size = X->dims()[0];
    size_t x_width = X->dims()[1];
    size_t y_width = Y->dims()[1];
    size_t y_half_width = (y_width - 1) / 2;

    for (size_t k = 0; k < batch_size; ++k) {
      for (size_t i = 0; i < x_width; ++i) {
        for (size_t j = 0; j < y_width; ++j) {
          int index = (i + j - y_half_width + x_width) % x_width;
          out(k, i) += x(k, index) * y(k, j);
        }
      }
    }
  }
};

template <typename T>
class ConvShiftGradKernel<platform::CPUPlace, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *X = context.Input<Tensor>("X");
    auto *Y = context.Input<Tensor>("Y");
    auto *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
    auto *dX = context.Output<Tensor>(framework::GradVarName("X"));
    auto *dY = context.Output<Tensor>(framework::GradVarName("Y"));

    auto x = EigenMatrix<T>::From(*X);
    auto y = EigenMatrix<T>::From(*Y);
    auto dout = EigenMatrix<T>::From(*dOut);

    auto x_dims = X->dims();
    auto y_dims = Y->dims();
    size_t batch_size = x_dims[0];
    size_t x_width = x_dims[1];
    size_t y_width = y_dims[1];
    size_t y_half_width = (y_width - 1) / 2;

    // The below trades code duplication for efficiency (keeping the if
    // statement outside of the loop).
    if (dX) {
      dX->mutable_data<T>(context.GetPlace());
      auto dx = EigenMatrix<T>::From(*dX);
      dx.setZero();
      for (size_t k = 0; k < batch_size; ++k) {
        for (size_t i = 0; i < x_width; ++i) {
          for (size_t j = 0; j < y_width; ++j) {
            int index = (i + j - y_half_width + x_width) % x_width;
            dx(k, index) += dout(k, i) * y(k, j);
          }
        }
      }
    }

    if (dY) {
      dY->mutable_data<T>(context.GetPlace());
      auto dy = EigenMatrix<T>::From(*dY);
      dy.setZero();
      for (size_t k = 0; k < batch_size; ++k) {
        for (size_t i = 0; i < x_width; ++i) {
          for (size_t j = 0; j < y_width; ++j) {
            int index = (i + j - y_half_width + x_width) % x_width;
            dy(k, j) += x(k, index) * dout(k, i);
          }
        }
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
            conv_shift_grad, ops::ConvShiftGradOp);
REGISTER_OP_CPU_KERNEL(conv_shift,
                       ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    conv_shift_grad,
    ops::ConvShiftGradKernel<paddle::platform::CPUPlace, float>);
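For reference, the DOC equation above and the kernels use the same indexing: the signed offset j in [-(N-1)/2, (N-1)/2] becomes the 0-based position (i + j - y_half_width) mod M in the loops. A minimal NumPy sketch of the forward pass for a single batch row (hypothetical helper name, not part of the operator):

import numpy as np

def conv_shift_row(x, y):
    # x: shape (M,), y: shape (N,) with N odd and N <= M.
    M, N = x.shape[0], y.shape[0]
    half = (N - 1) // 2
    out = np.zeros_like(x)
    for i in range(M):
        for j in range(N):
            # y[j] carries offset j - half; x's index wraps around modulo M.
            out[i] += x[(i + j - half) % M] * y[j]
    return out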
@@ -0,0 +1,194 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/conv_shift_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using framework::Tensor;

namespace {

inline int div_up(int x, int y) { return (x + y - 1) / y; }

// Some notes on the design:
//
// Each thread is responsible for computing a single output out[k, i].
// Thread blocks are based on tiles of x with height 1 in the batch dimension.
//
// This design is based on the typical use case where the filter
// y is fairly small. For large y, it would probably be more efficient
// to also tile across y.
template <typename T>
__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
                                   int y_width, int y_half_width,
                                   int batch_size) {
  extern __shared__ T mem[];

  int tx = threadIdx.x;
  int i = blockIdx.x * blockDim.x + tx;  // global x index
  int k = blockIdx.y;                    // batch index

  // Check if we are in a boundary block with fewer x's to process than
  // blockDim.x.
  int num_x = (blockIdx.x == gridDim.x - 1)
                  ? (x_width - blockIdx.x * blockDim.x)
                  : blockDim.x;

  T *sx = mem;
  T *sx_pad = &mem[num_x];
  T *sy = &mem[blockDim.x + y_width];

  // Collaboratively load y[k, :] and length-y padding of x into shared memory.
  int pad_start = blockIdx.x * blockDim.x + num_x + x_width - y_half_width;
  for (int j = tx; j < y_width; j += blockDim.x) {
    sy[j] = y[k * y_width + j];
    sx_pad[j] = x[k * x_width + (pad_start + j) % x_width];
  }

  // Load a cyclically shifted slice of x into shared memory.
  if (tx < num_x) {
    int load_i = (i - y_half_width + x_width) % x_width;
    sx[tx] = x[k * x_width + load_i];
  }
  // Every thread in the block reaches the barrier; boundary threads exit
  // only afterwards.
  __syncthreads();
  if (tx >= num_x) {
    return;
  }

  // Compute dot product of sx[tx:tx + y_width] and sy.
  T sum = 0;
  for (int j = 0; j < y_width; ++j) {
    sum += sx[tx + j] * sy[j];
  }

  // Save to out[k, i].
  out[k * x_width + i] = sum;
}

// Compute x gradient - initial naive implementation with atomic add.
template <typename T>
__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
                              int y_width, int y_half_width, int batch_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
  int j = blockIdx.y;                             // y index
  int k = blockIdx.z;                             // batch index

  if (i < x_width) {
    int index = (i + j - y_half_width + x_width) % x_width;
    atomicAdd(&dx[k * x_width + index],
              dout[k * x_width + i] * y[k * y_width + j]);
  }
}

// Compute y gradient - initial naive implementation with atomic add.
template <typename T>
__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width,
                              int y_width, int y_half_width, int batch_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
  int j = blockIdx.y;                             // y index
  int k = blockIdx.z;                             // batch index

  if (i < x_width) {
    int index = (i + j - y_half_width + x_width) % x_width;
    atomicAdd(&dy[k * y_width + j],
              x[k * x_width + index] * dout[k * x_width + i]);
  }
}
}  // namespace

template <typename T>
class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *X = context.Input<Tensor>("X");
    const Tensor *Y = context.Input<Tensor>("Y");
    Tensor *Out = context.Output<Tensor>("Out");
    const T *x_data = X->data<T>();
    const T *y_data = Y->data<T>();
    T *out_data = Out->mutable_data<T>(context.GetPlace());

    int batch_size = X->dims()[0];
    int x_width = X->dims()[1];
    int y_width = Y->dims()[1];
    int y_half_width = (y_width - 1) / 2;

    const int x_per_block = 256;
    int num_x_blocks = div_up(x_width, x_per_block);
    int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T);

    dim3 grid_dim(num_x_blocks, batch_size);

    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
                      context.device_context())
                      .stream();

    conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
        x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
  }
};

template <typename T>
class ConvShiftGradKernel<platform::GPUPlace, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *X = context.Input<Tensor>("X");
    const Tensor *Y = context.Input<Tensor>("Y");
    const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
    const T *x_data = X->data<T>();
    const T *y_data = Y->data<T>();
    const T *dout_data = dOut->data<T>();

    Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
    Tensor *dY = context.Output<Tensor>(framework::GradVarName("Y"));

    int batch_size = X->dims()[0];
    int x_width = X->dims()[1];
    int y_width = Y->dims()[1];
    int y_half_width = (y_width - 1) / 2;

    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
                      context.device_context())
                      .stream();

    const int x_per_block = 256;
    int num_x_blocks = div_up(x_width, x_per_block);
    dim3 grid_dim(num_x_blocks, y_width, batch_size);

    if (dX) {
      T *dx_data = dX->mutable_data<T>(context.GetPlace());
      cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
      conv_shift_dx<T><<<grid_dim, x_per_block, 0, stream>>>(
          dout_data, y_data, dx_data, x_width, y_width, y_half_width,
          batch_size);
    }
    if (dY) {
      T *dy_data = dY->mutable_data<T>(context.GetPlace());
      cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream);
      conv_shift_dy<T><<<grid_dim, x_per_block, 0, stream>>>(
          x_data, dout_data, dy_data, x_width, y_width, y_half_width,
          batch_size);
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv_shift,
                       ops::ConvShiftKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    conv_shift_grad,
    ops::ConvShiftGradKernel<paddle::platform::GPUPlace, float>);
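The launch configuration in the GPU forward kernel above follows the design notes: one thread per output element out[k, i], blocks tiling x with height 1 in the batch dimension, and shared memory holding the x tile, its circular padding, and y. A small arithmetic sketch of that geometry, assuming hypothetical sizes batch_size=4, x_width=1000, y_width=5 and float32 (4-byte) elements:

batch_size, x_width, y_width = 4, 1000, 5
x_per_block = 256                                           # threads per block, as in the kernel above
num_x_blocks = (x_width + x_per_block - 1) // x_per_block   # div_up(1000, 256) == 4
grid_dim = (num_x_blocks, batch_size)                       # (4, 4): one batch row per grid row
mem_per_block = (x_per_block + 2 * y_width) * 4             # sx + sx_pad + sy == 1064 bytes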
@@ -0,0 +1,33 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class ConvShiftKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override;
};

template <typename Place, typename T>
class ConvShiftGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override;
};
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,47 @@
import unittest
import numpy as np
from op_test import OpTest


def conv_shift_forward(x, y):
    out = np.zeros_like(x)
    M = x.shape[1]
    N = y.shape[1]
    y_half_width = (N - 1) / 2
    for i in xrange(M):
        for j in xrange(N):
            out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j]
    return out


class TestConvShiftOp(OpTest):
    def setUp(self):
        self.op_type = "conv_shift"

        batch_size = 4
        x_dim = 17
        y_dim = 3  # must be odd and <= x_dim
        x = np.random.random((batch_size, x_dim)).astype("float32")
        y = np.random.random((batch_size, y_dim)).astype("float32")
        self.inputs = {'X': x, 'Y': y}

        out = conv_shift_forward(x, y)
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)

    def test_check_grad_ignore_x(self):
        self.check_grad(
            ['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X"))

    def test_check_grad_ignore_y(self):
        self.check_grad(
            ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))


if __name__ == '__main__':
    unittest.main()
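The gradient tests above compare the operator's gradients against numerical differentiation. For reference, a NumPy sketch of the analytic gradients that the CPU and GPU gradient kernels accumulate (hypothetical helper name, mirroring the kernel loops):

import numpy as np

def conv_shift_grad_reference(x, y, dout):
    # For out[:, i] += x[:, idx] * y[:, j] with idx = (i + j - half) % M,
    # the chain rule gives the two accumulations below.
    M, N = x.shape[1], y.shape[1]
    half = (N - 1) // 2
    dx, dy = np.zeros_like(x), np.zeros_like(y)
    for i in range(M):
        for j in range(N):
            idx = (i + j - half) % M
            dx[:, idx] += dout[:, i] * y[:, j]
            dy[:, j] += x[:, idx] * dout[:, i]
    return dx, dy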