|
|
@ -26,15 +26,15 @@ namespace operators {
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
using DDim = framework::DDim;
|
|
|
|
using DDim = framework::DDim;
|
|
|
|
|
|
|
|
|
|
|
|
// Define Op classes in .h file so that other deconv
|
|
|
|
// Define Op classes in .h file so that other conv transpose
|
|
|
|
// operator implementations can reuse the code.
|
|
|
|
// operator implementations can reuse the code.
|
|
|
|
class Deconv2DOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
Deconv2DOpMaker(framework::OpProto* proto,
|
|
|
|
Conv2DTransposeOpMaker(framework::OpProto* proto,
|
|
|
|
framework::OpAttrChecker* op_checker);
|
|
|
|
framework::OpAttrChecker* op_checker);
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class Deconv2DOp : public framework::OperatorWithKernel {
|
|
|
|
class Conv2DTransposeOp : public framework::OperatorWithKernel {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
@ -42,7 +42,7 @@ class Deconv2DOp : public framework::OperatorWithKernel {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override;
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class Deconv2DOpGrad : public framework::OperatorWithKernel {
|
|
|
|
class Conv2DTransposeOpGrad : public framework::OperatorWithKernel {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
@ -51,7 +51,7 @@ class Deconv2DOpGrad : public framework::OperatorWithKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
template <typename Place, typename T>
|
|
|
|
class GemmDeconv2DKernel : public framework::OpKernel<T> {
|
|
|
|
class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
const Tensor* input = context.Input<Tensor>("Input");
|
|
|
|
const Tensor* input = context.Input<Tensor>("Input");
|
|
|
@ -64,27 +64,27 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
|
|
// no paddings and groups allowed in deconv
|
|
|
|
// no paddings and groups allowed in deconv
|
|
|
|
|
|
|
|
|
|
|
|
int N = input->dims()[0];
|
|
|
|
const int batch_size = input->dims()[0];
|
|
|
|
int M = input->dims()[1];
|
|
|
|
const int m = input->dims()[1];
|
|
|
|
int H = input->dims()[2];
|
|
|
|
const int h = input->dims()[2];
|
|
|
|
int W = input->dims()[3];
|
|
|
|
const int w = input->dims()[3];
|
|
|
|
|
|
|
|
|
|
|
|
int K_H = filter.dims()[2];
|
|
|
|
const int k_h = filter.dims()[2];
|
|
|
|
int K_W = filter.dims()[3];
|
|
|
|
const int k_w = filter.dims()[3];
|
|
|
|
|
|
|
|
|
|
|
|
int C = output->dims()[1]; // output channels
|
|
|
|
const int c = output->dims()[1]; // output channels
|
|
|
|
int O_H = output->dims()[2];
|
|
|
|
const int o_h = output->dims()[2];
|
|
|
|
int O_W = output->dims()[3];
|
|
|
|
const int o_w = output->dims()[3];
|
|
|
|
|
|
|
|
|
|
|
|
paddle::operators::math::Col2ImFunctor<
|
|
|
|
paddle::operators::math::Col2ImFunctor<
|
|
|
|
paddle::operators::math::ColFormat::kCFO, Place, T>
|
|
|
|
paddle::operators::math::ColFormat::kCFO, Place, T>
|
|
|
|
col2im;
|
|
|
|
col2im;
|
|
|
|
|
|
|
|
|
|
|
|
// use col_shape in the im2col and col2im calculation
|
|
|
|
// use col_shape in the im2col and col2im calculation
|
|
|
|
DDim col_shape = {C, K_H, K_W, H, W};
|
|
|
|
DDim col_shape = {c, k_h, k_w, h, w};
|
|
|
|
|
|
|
|
|
|
|
|
// use col_matrix_shape in the gemm calculation
|
|
|
|
// use col_matrix_shape in the gemm calculation
|
|
|
|
DDim col_matrix_shape = {C * K_H * K_W, H * W};
|
|
|
|
DDim col_matrix_shape = {c * k_h * k_w, h * w};
|
|
|
|
|
|
|
|
|
|
|
|
Tensor col;
|
|
|
|
Tensor col;
|
|
|
|
col.mutable_data<T>(col_shape, context.GetPlace());
|
|
|
|
col.mutable_data<T>(col_shape, context.GetPlace());
|
|
|
@ -94,10 +94,10 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
|
|
|
|
Tensor col_matrix = col;
|
|
|
|
Tensor col_matrix = col;
|
|
|
|
col_matrix.Resize(col_matrix_shape);
|
|
|
|
col_matrix.Resize(col_matrix_shape);
|
|
|
|
|
|
|
|
|
|
|
|
DDim output_shape = {C, O_H, O_W};
|
|
|
|
DDim output_shape = {c, o_h, o_w};
|
|
|
|
DDim input_matrix_shape = {M, H * W};
|
|
|
|
DDim input_matrix_shape = {m, h * w};
|
|
|
|
|
|
|
|
|
|
|
|
DDim filter_matrix_shape = {M, C * K_H * K_W};
|
|
|
|
DDim filter_matrix_shape = {m, c * k_h * k_w};
|
|
|
|
filter.Resize(filter_matrix_shape);
|
|
|
|
filter.Resize(filter_matrix_shape);
|
|
|
|
|
|
|
|
|
|
|
|
// deconvolution: gemm + col2im (similar to conv-backward on input)
|
|
|
|
// deconvolution: gemm + col2im (similar to conv-backward on input)
|
|
|
@ -106,16 +106,16 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
|
|
|
|
auto t = framework::EigenVector<T>::Flatten(*output);
|
|
|
|
auto t = framework::EigenVector<T>::Flatten(*output);
|
|
|
|
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
|
|
|
|
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < N; i++) {
|
|
|
|
for (int i = 0; i < batch_size; i++) {
|
|
|
|
// batch with size (M, H * W)
|
|
|
|
// batch with size (m, h * w)
|
|
|
|
Tensor input_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
|
|
|
|
Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
|
|
|
|
// filter size: (M, C * K_H * K_W)
|
|
|
|
// filter size: (m, c * k_h * k_w)
|
|
|
|
|
|
|
|
|
|
|
|
// output size: (C, O_H, O_W)
|
|
|
|
// output size: (c, o_h, o_w)
|
|
|
|
Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape);
|
|
|
|
Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
|
|
|
|
|
|
|
|
|
|
|
|
// col_matrix = filter * input_batch
|
|
|
|
// col_matrix = filter * input_batch
|
|
|
|
// of shape (C * K_H * K_W, H * W)
|
|
|
|
// of shape (c * k_h * k_w, h * w)
|
|
|
|
math::matmul<Place, T>(context.device_context(), filter, true,
|
|
|
|
math::matmul<Place, T>(context.device_context(), filter, true,
|
|
|
|
input_batch, false, T(1.0), &col_matrix, T(0.0));
|
|
|
|
input_batch, false, T(1.0), &col_matrix, T(0.0));
|
|
|
|
col2im(context.device_context(), output_batch, col, strides[0],
|
|
|
|
col2im(context.device_context(), output_batch, col, strides[0],
|
|
|
@ -125,7 +125,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
template <typename Place, typename T>
|
|
|
|
class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
const Tensor* input = context.Input<Tensor>("Input");
|
|
|
|
const Tensor* input = context.Input<Tensor>("Input");
|
|
|
@ -145,17 +145,17 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
// Actually, no paddings and groups allowed in deconv.
|
|
|
|
// Actually, no paddings and groups allowed in deconv.
|
|
|
|
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
|
|
|
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
|
|
|
|
|
|
|
|
|
|
|
int N = input->dims()[0];
|
|
|
|
const int batch_size = input->dims()[0];
|
|
|
|
int M = input->dims()[1];
|
|
|
|
const int m = input->dims()[1];
|
|
|
|
int H = input->dims()[2];
|
|
|
|
const int h = input->dims()[2];
|
|
|
|
int W = input->dims()[3];
|
|
|
|
const int w = input->dims()[3];
|
|
|
|
|
|
|
|
|
|
|
|
int K_H = filter.dims()[2];
|
|
|
|
const int k_h = filter.dims()[2];
|
|
|
|
int K_W = filter.dims()[3];
|
|
|
|
const int k_w = filter.dims()[3];
|
|
|
|
|
|
|
|
|
|
|
|
int C = output_grad->dims()[1]; // output channels
|
|
|
|
const int c = output_grad->dims()[1]; // output channels
|
|
|
|
int O_H = output_grad->dims()[2];
|
|
|
|
const int o_h = output_grad->dims()[2];
|
|
|
|
int O_W = output_grad->dims()[3];
|
|
|
|
const int o_w = output_grad->dims()[3];
|
|
|
|
|
|
|
|
|
|
|
|
// Only im2col functor required for bp to get to the right shape
|
|
|
|
// Only im2col functor required for bp to get to the right shape
|
|
|
|
paddle::operators::math::Im2ColFunctor<
|
|
|
|
paddle::operators::math::Im2ColFunctor<
|
|
|
@ -163,10 +163,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
im2col;
|
|
|
|
im2col;
|
|
|
|
|
|
|
|
|
|
|
|
// use col_shape in the im2col and col2im calculation
|
|
|
|
// use col_shape in the im2col and col2im calculation
|
|
|
|
DDim col_shape = {C, K_H, K_W, H, W};
|
|
|
|
DDim col_shape = {c, k_h, k_w, h, w};
|
|
|
|
|
|
|
|
|
|
|
|
// use col_matrix_shape in the gemm calculation
|
|
|
|
// use col_matrix_shape in the gemm calculation
|
|
|
|
DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
|
|
|
|
DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
|
|
|
|
|
|
|
|
|
|
|
|
Tensor col;
|
|
|
|
Tensor col;
|
|
|
|
col.mutable_data<T>(col_shape, context.GetPlace());
|
|
|
|
col.mutable_data<T>(col_shape, context.GetPlace());
|
|
|
@ -174,10 +174,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
// but will be reshaped into a two-dimensional matrix shape
|
|
|
|
// but will be reshaped into a two-dimensional matrix shape
|
|
|
|
// to call the matrix multiplication interface.
|
|
|
|
// to call the matrix multiplication interface.
|
|
|
|
|
|
|
|
|
|
|
|
DDim output_shape = {C, O_H, O_W};
|
|
|
|
DDim output_shape = {c, o_h, o_w};
|
|
|
|
DDim input_matrix_shape = {M, H * W};
|
|
|
|
DDim input_matrix_shape = {m, h * w};
|
|
|
|
|
|
|
|
|
|
|
|
DDim filter_matrix_shape = {M, C * K_H * K_W};
|
|
|
|
DDim filter_matrix_shape = {m, c * k_h * k_w};
|
|
|
|
filter.Resize(filter_matrix_shape);
|
|
|
|
filter.Resize(filter_matrix_shape);
|
|
|
|
|
|
|
|
|
|
|
|
// deconvolution grad on input:
|
|
|
|
// deconvolution grad on input:
|
|
|
@ -185,29 +185,29 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
// input need to compute gradient
|
|
|
|
// input need to compute gradient
|
|
|
|
if (input_grad) {
|
|
|
|
if (input_grad) {
|
|
|
|
Tensor col_matrix = col;
|
|
|
|
Tensor col_matrix = col;
|
|
|
|
DDim col_matrix_shape = {C * K_H * K_W, H * W};
|
|
|
|
DDim col_matrix_shape = {c * k_h * k_w, h * w};
|
|
|
|
col_matrix.Resize(col_matrix_shape);
|
|
|
|
col_matrix.Resize(col_matrix_shape);
|
|
|
|
|
|
|
|
|
|
|
|
input_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
input_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
auto t = framework::EigenVector<T>::Flatten(*input_grad);
|
|
|
|
auto t = framework::EigenVector<T>::Flatten(*input_grad);
|
|
|
|
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
|
|
|
|
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < N; i++) {
|
|
|
|
for (int i = 0; i < batch_size; i++) {
|
|
|
|
// batch with size (C, O_H * O_W)
|
|
|
|
// batch with size (c, o_h, o_w)
|
|
|
|
Tensor output_grad_batch =
|
|
|
|
Tensor output_grad_batch =
|
|
|
|
output_grad->Slice<T>(i, i + 1).Resize(output_shape);
|
|
|
|
output_grad->Slice(i, i + 1).Resize(output_shape);
|
|
|
|
// filter of size (M, C * K_H * K_W)
|
|
|
|
// filter of size (m, c * k_h * k_w)
|
|
|
|
|
|
|
|
|
|
|
|
// batch with size (M, H, W)
|
|
|
|
// batch with size (m, h * w)
|
|
|
|
Tensor input_grad_batch =
|
|
|
|
Tensor input_grad_batch =
|
|
|
|
input_grad->Slice<T>(i, i + 1).Resize(input_matrix_shape);
|
|
|
|
input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
|
|
|
|
|
|
|
|
|
|
|
|
// im2col: dy from (C, O_H, O_W) -> (C * K_H * K_W, H * W)
|
|
|
|
// im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
|
|
|
|
im2col(context.device_context(), output_grad_batch, col, strides[0],
|
|
|
|
im2col(context.device_context(), output_grad_batch, col, strides[0],
|
|
|
|
strides[1], paddings[0], paddings[1]);
|
|
|
|
strides[1], paddings[0], paddings[1]);
|
|
|
|
|
|
|
|
|
|
|
|
// gemm: dx = filter * dy
|
|
|
|
// gemm: dx = filter * dy
|
|
|
|
// (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, C, H)
|
|
|
|
// (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w)
|
|
|
|
math::matmul<Place, T>(context.device_context(), filter, false,
|
|
|
|
math::matmul<Place, T>(context.device_context(), filter, false,
|
|
|
|
col_matrix, false, T(1.0), &input_grad_batch,
|
|
|
|
col_matrix, false, T(1.0), &input_grad_batch,
|
|
|
|
T(0.0));
|
|
|
|
T(0.0));
|
|
|
@ -217,7 +217,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
// filter gradient required
|
|
|
|
// filter gradient required
|
|
|
|
if (filter_grad) {
|
|
|
|
if (filter_grad) {
|
|
|
|
Tensor col_matrix_f = col;
|
|
|
|
Tensor col_matrix_f = col;
|
|
|
|
DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
|
|
|
|
DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
|
|
|
|
col_matrix_f.Resize(col_matrix_shape_f);
|
|
|
|
col_matrix_f.Resize(col_matrix_shape_f);
|
|
|
|
|
|
|
|
|
|
|
|
filter_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
filter_grad->mutable_data<T>(context.GetPlace());
|
|
|
@ -226,19 +226,19 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
auto t = framework::EigenVector<T>::Flatten(filter_grad_);
|
|
|
|
auto t = framework::EigenVector<T>::Flatten(filter_grad_);
|
|
|
|
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
|
|
|
|
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < N; ++i) {
|
|
|
|
for (int i = 0; i < batch_size; ++i) {
|
|
|
|
// batch with size (C, O_H, O_W)
|
|
|
|
// batch with size (c, o_h, o_w)
|
|
|
|
Tensor output_grad_batch =
|
|
|
|
Tensor output_grad_batch =
|
|
|
|
output_grad->Slice<T>(i, i + 1).Resize(output_shape);
|
|
|
|
output_grad->Slice(i, i + 1).Resize(output_shape);
|
|
|
|
// input batch
|
|
|
|
// input batch
|
|
|
|
Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
|
|
|
|
Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
|
|
|
|
|
|
|
|
|
|
|
|
// im2col: (C * H * W, K_H * K_W)
|
|
|
|
// im2col: (c * h * w, k_h * k_w)
|
|
|
|
im2col(context.device_context(), output_grad_batch, col, strides[0],
|
|
|
|
im2col(context.device_context(), output_grad_batch, col, strides[0],
|
|
|
|
strides[1], paddings[0], paddings[1]);
|
|
|
|
strides[1], paddings[0], paddings[1]);
|
|
|
|
|
|
|
|
|
|
|
|
// gemm: d_filter = x * y_grad^T
|
|
|
|
// gemm: d_filter = x * y_grad^T
|
|
|
|
// (M, C * H * W) * (K_H * K_W, C * H * W) -> (M, C, H)
|
|
|
|
// (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
|
|
|
|
math::matmul<Place, T>(context.device_context(), in_batch, false,
|
|
|
|
math::matmul<Place, T>(context.device_context(), in_batch, false,
|
|
|
|
col_matrix_f, true, T(1.0), &filter_grad_,
|
|
|
|
col_matrix_f, true, T(1.0), &filter_grad_,
|
|
|
|
T(1.0));
|
|
|
|
T(1.0));
|
|
|
|