You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/operators/var_conv_2d_op.cc

432 lines
16 KiB

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/var_conv_2d_op.h"
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
// Declares the inputs, outputs and attributes of the var_conv_2d operator
// together with its operator-level documentation.
//
// The declaration order is kept as-is because it defines the layout of the
// generated operator proto.
void VarConv2dOpMaker::Make() {
  AddInput("X",
           "X (LoDTensor, default LoDTensor<float>) Input variable which "
           "should contain lod information.");
  AddInput("ROW", "(LoDTensor) the row variable provides lod information");
  AddInput("COLUMN",
           "(LoDTensor) the column variable provides lod information");
  AddInput("W", "W (Tensor), the filter.");
  AddAttr<int>("InputChannel", "the input channel num").SetDefault(1);
  AddAttr<int>("OutputChannel", "the output filter num").SetDefault(1);
  AddAttr<int>("StrideH", "the height of Stride").SetDefault(1);
  AddAttr<int>("StrideW", "the width of Stride").SetDefault(1);
  AddAttr<int>("KernelH", "the height of Kernel").SetDefault(1);
  AddAttr<int>("KernelW", "the width of Kernel").SetDefault(1);
  AddOutput("Out", "(LoDTensor, default LoDTensor<float>) Output variable");
  AddOutput("Col",
            "(LoDTensor, default LoDTensor<float>) the intermediate result "
            "variable");
  // The kernel only multiplies W with the im2col matrix of each sample; it
  // applies neither a bias nor an activation, so the doc must not claim one.
  AddComment(R"DOC(
    Var Size Conv Operator

    This operator calculates Out = W * Im2Col(X) for every sample in X.
    Each sample is a 2-D feature map with InputChannel channels whose height
    and width are given by the lod of ROW and COLUMN respectively.
    No bias or activation is applied.

    NOTE: only support 'float32' data type now.
    )DOC");
}
// Validates the inputs, outputs and attributes of var_conv_2d.
//
// - Always: X, W, ROW, COLUMN, Out and Col must exist; X and W must be 2-D;
//   W must have shape [OutputChannel, InputChannel * KernelH * KernelW];
//   both strides must be positive.
// - At runtime: X, ROW and COLUMN must carry lod, and the last offset of X's
//   first lod level must equal X's first dimension. Output shapes are not set
//   here; the CPU kernel in this file sizes Out and Col when it allocates them.
// - At compile time: Out and Col get the placeholder shape [-1, 1].
void VarConv2dOP::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"),
                 "X(Input) of VarConv2dOP should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("W"),
                 "W(Input) of VarConv2dOP should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("ROW"),
                 "Input(ROW) of VarConv2dOP should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("COLUMN"),
                 "Input(COLUMN) of VarConv2dOP should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Out(Output) of VarConv2dOP should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Col"),
                 "Col(Output) of VarConv2dOP should not be null.");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The rank of X(Input) should be 2.");

  auto w_dims = ctx->GetInputDim("W");
  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "W should be 2-D tensor");

  int output_channel = ctx->Attrs().Get<int>("OutputChannel");
  int input_channel = ctx->Attrs().Get<int>("InputChannel");
  int kernel_h = ctx->Attrs().Get<int>("KernelH");
  int kernel_w = ctx->Attrs().Get<int>("KernelW");
  int stride_h = ctx->Attrs().Get<int>("StrideH");
  int stride_w = ctx->Attrs().Get<int>("StrideW");

  // The kernels divide by the strides and use them as loop increments, so a
  // non-positive stride would cause a division by zero or an endless loop.
  PADDLE_ENFORCE_GE(stride_h, 1, "StrideH should be greater than 0");
  PADDLE_ENFORCE_GE(stride_w, 1, "StrideW should be greater than 0");

  PADDLE_ENFORCE_EQ(w_dims[0], output_channel,
                    "W dim[0] should be equal to OutputChannel");
  PADDLE_ENFORCE_EQ(
      w_dims[1], input_channel * kernel_h * kernel_w,
      "W dim[1] should be equal to InputChannel * KernelH * KernelW");

  if (ctx->IsRuntime()) {
    framework::Variable* x_var =
        boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
    const auto& x_lod = x_var->Get<LoDTensor>().lod();
    PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info.");
    // Guard the back() call below: the first lod level must hold offsets.
    PADDLE_ENFORCE(!x_lod[0].empty(), "The Input(X)'s lod info is corrupted.");
    PADDLE_ENFORCE_EQ(
        x_dims[0], static_cast<int64_t>(x_lod[0].back()),
        "The Input(X)'s lod info mismatches the actual tensor shape.");

    framework::Variable* row_var =
        boost::get<framework::Variable*>(ctx->GetInputVarPtrs("ROW")[0]);
    const auto& row_lod = row_var->Get<LoDTensor>().lod();
    PADDLE_ENFORCE(!row_lod.empty(), "The Input(ROW) must hold lod info.");

    framework::Variable* col_var =
        boost::get<framework::Variable*>(ctx->GetInputVarPtrs("COLUMN")[0]);
    const auto& col_lod = col_var->Get<LoDTensor>().lod();
    PADDLE_ENFORCE(!col_lod.empty(), "The Input(COLUMN) must hold lod info.");
  } else {
    // The real sizes depend on lod, which is unknown at compile time.
    std::vector<int64_t> out_dims_vec{-1};
    out_dims_vec.push_back(1);
    std::vector<int64_t> col_dims_vec{-1};
    col_dims_vec.push_back(1);
    ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
    ctx->SetOutputDim("Col", framework::make_ddim(col_dims_vec));
  }
}
// CPU forward kernel of var_conv_2d.
//
// Every sample b of X is treated as an image with InputChannel channels,
// height taken from the lod of ROW and width taken from the lod of COLUMN.
// The kernel first unfolds each sample into a column matrix (Im2Col) and then
// multiplies W with that matrix (GEMM) to produce the sample's slice of Out.
template <typename DeviceContext, typename T>
class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
 public:
  // Unfolds every sample of `input` into its column matrix and stores all of
  // them back to back in `col`.
  //
  // For sample b the column matrix has InputChannel * KernelH * KernelW rows
  // and top_im_y * top_im_x columns (row-major), where top_im_* is the number
  // of stride steps along each axis. The kernel window is centred on the
  // current pixel (offset by KernelH / 2 and KernelW / 2); positions falling
  // outside the image are filled with 0.
  //
  // `col` is resized to [total_size, 1] and receives a one-level lod holding
  // the start offset of each sample's column matrix.
  void Im2Col(const framework::ExecutionContext& ctx, const LoDTensor& input,
              LoDTensor* col) const {
    const int input_channel = ctx.Attr<int>("InputChannel");
    auto* in_row = ctx.Input<LoDTensor>("ROW");
    auto* in_col = ctx.Input<LoDTensor>("COLUMN");
    const int kernel_h = ctx.Attr<int>("KernelH");
    const int kernel_w = ctx.Attr<int>("KernelW");
    const int stride_h = ctx.Attr<int>("StrideH");
    const int stride_w = ctx.Attr<int>("StrideW");

    const auto& bottom_offset = input.lod()[0];
    // 2-D lod info: COLUMN gives the widths, ROW gives the heights.
    const auto& offset_x = in_col->lod()[0];
    const auto& offset_y = in_row->lod()[0];
    const int batch = bottom_offset.size() - 1;

    // offset_x / offset_y are indexed up to `batch` below (and again in
    // Compute), so both must hold at least as many offsets as X's lod.
    PADDLE_ENFORCE_GE(
        offset_x.size(), bottom_offset.size(),
        "The Input(COLUMN)'s lod must cover every sample of Input(X).");
    PADDLE_ENFORCE_GE(
        offset_y.size(), bottom_offset.size(),
        "The Input(ROW)'s lod must cover every sample of Input(X).");

    // top offset is the whole size of each data sample
    std::vector<size_t> top_offset;
    int top_size = 0;
    top_offset.push_back(top_size);
    const int top_y = input_channel * kernel_h * kernel_w;
    for (int b = 0; b < batch; ++b) {
      int width = offset_x[b + 1] - offset_x[b];
      int height = offset_y[b + 1] - offset_y[b];
      int top_x =
          OutputExtent(height, stride_h) * OutputExtent(width, stride_w);
      top_size += top_y * top_x;
      top_offset.push_back(top_size);
    }

    framework::LoD col_lod;
    col_lod.push_back(top_offset);
    col->set_lod(col_lod);

    std::vector<int64_t> col_dims_vec{top_size};
    col_dims_vec.push_back(1);
    auto* top_data = col->mutable_data<T>(framework::make_ddim(col_dims_vec),
                                          ctx.GetPlace());
    auto* bottom_data = input.data<T>();

    const int kernel_win_size = kernel_h * kernel_w;
    const int half_kernel_h = kernel_h / 2;
    const int half_kernel_w = kernel_w / 2;
    for (int b = 0; b < batch; ++b) {
      int t_offset = top_offset[b];
      int b_offset = bottom_offset[b];
      int width = offset_x[b + 1] - offset_x[b];
      int height = offset_y[b + 1] - offset_y[b];
      if (width == 0 || height == 0) {
        continue;
      }
      int top_im_x = (width - 1) / stride_w + 1;
      int top_im_y = (height - 1) / stride_h + 1;
      int top_x = top_im_y * top_im_x;
      for (int z = 0; z < input_channel; ++z) {
        int row_offset = kernel_win_size * z;
        int im_offset = z * width * height;
        for (int y = 0; y < height; y += stride_h) {
          for (int x = 0; x < width; x += stride_w) {
            // Column of the output matrix that belongs to pixel (y, x).
            int col_offset = x / stride_w + y / stride_h * top_im_x;
            for (int ky = 0; ky < kernel_h; ++ky) {
              for (int kx = 0; kx < kernel_w; ++kx) {
                int im_y = y + ky - half_kernel_h;
                int im_x = x + kx - half_kernel_w;
                int dst = t_offset +
                          (row_offset + ky * kernel_w + kx) * top_x +
                          col_offset;
                if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) {
                  top_data[dst] =
                      bottom_data[b_offset + im_offset + im_y * width + im_x];
                } else {
                  // Outside of the image: behaves like zero padding.
                  top_data[dst] = 0;
                }
              }
            }
          }
        }
      }
    }
  }

  // Runs the forward pass: fills Col via Im2Col, then for every non-empty
  // sample computes Out_b = W * Col_b, an
  // [OutputChannel, top_im_y * top_im_x] matrix. Out is resized to
  // [total_size, 1] and receives a one-level lod with the per-sample offsets.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* bottom = ctx.Input<LoDTensor>("X");
    auto* in_row = ctx.Input<LoDTensor>("ROW");
    auto* in_col = ctx.Input<LoDTensor>("COLUMN");
    auto* w = ctx.Input<Tensor>("W");
    auto* top = ctx.Output<LoDTensor>("Out");
    auto* col = ctx.Output<LoDTensor>("Col");

    const int output_channel = ctx.Attr<int>("OutputChannel");
    const int input_channel = ctx.Attr<int>("InputChannel");
    const int kernel_h = ctx.Attr<int>("KernelH");
    const int kernel_w = ctx.Attr<int>("KernelW");
    const int stride_h = ctx.Attr<int>("StrideH");
    const int stride_w = ctx.Attr<int>("StrideW");

    // Im2Col also checks that ROW/COLUMN lod cover every sample of X.
    Im2Col(ctx, *bottom, col);

    const int batch = bottom->lod()[0].size() - 1;
    const auto& col_offset = col->lod()[0];
    const auto& offset_x = in_col->lod()[0];
    const auto& offset_y = in_row->lod()[0];

    std::vector<size_t> top_offset;
    int top_size = 0;
    top_offset.push_back(top_size);
    for (int b = 0; b < batch; ++b) {
      int width = offset_x[b + 1] - offset_x[b];
      int height = offset_y[b + 1] - offset_y[b];
      int top_im_size =
          OutputExtent(height, stride_h) * OutputExtent(width, stride_w);
      top_size += output_channel * top_im_size;
      top_offset.push_back(top_size);
    }

    framework::LoD top_lod;
    top_lod.push_back(top_offset);
    top->set_lod(top_lod);

    std::vector<int64_t> top_dims_vec{top_size};
    top_dims_vec.push_back(1);
    auto* top_data = top->mutable_data<T>(framework::make_ddim(top_dims_vec),
                                          ctx.GetPlace());
    auto* w_data = w->data<T>();
    auto* col_data = col->data<T>();

    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
    for (int b = 0; b < batch; ++b) {
      int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
      if (top_im_size == 0) {
        continue;
      }
      blas.GEMM(CblasNoTrans, CblasNoTrans, output_channel, top_im_size,
                input_channel * kernel_h * kernel_w, 1.0, w_data,
                col_data + col_offset[b], 0.0, top_data + top_offset[b]);
    }
  }

 private:
  // Number of stride steps along an axis of the given length: 0 for an empty
  // axis, otherwise (length - 1) / stride + 1.
  static int OutputExtent(int length, int stride) {
    return length == 0 ? 0 : (length - 1) / stride + 1;
  }
};
// Shape inference for var_conv_2d_grad.
//
// Requires X, W and Out@GRAD. When X@GRAD is requested it takes the shape and
// lod of X; when W@GRAD is requested it takes the shape of W.
void VarConv2dOpGrad::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"),
                 "Input(X) of VarConv2dOpGrad should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("W"),
                 "Input(W) of VarConv2dOpGrad should not be null.");
  PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                 "Input(Out@GRAD) of VarConv2dOpGrad should not be null.");

  if (ctx->HasOutput(framework::GradVarName("X"))) {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
  }
  if (ctx->HasOutput(framework::GradVarName("W"))) {
    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
  }
}
// CPU backward kernel of var_conv_2d.
//
// Given Out@GRAD it computes, per sample b,
//   Col@GRAD_b = W^T * Out@GRAD_b   (then folded back into X@GRAD), and
//   W@GRAD    += Out@GRAD_b * Col_b^T.
// X@GRAD and W@GRAD are each computed only when the corresponding output is
// present, matching VarConv2dOpGrad::InferShape, which treats both as
// optional.
template <typename DeviceContext, typename T>
class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> {
 public:
  // Folds the gradient of the column matrices (`top_diff`, laid out like Col)
  // back into X@GRAD. X@GRAD is allocated and zeroed here, then every
  // in-image position accumulates the gradients of all column entries that
  // were copied from it in the forward pass. Does nothing when X@GRAD is not
  // an output of this op.
  void Im2ColGrad(const framework::ExecutionContext& ctx, T* top_diff) const {
    auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    if (dx == nullptr) {
      return;
    }
    auto* x = ctx.Input<LoDTensor>("X");
    auto* in_row = ctx.Input<LoDTensor>("ROW");
    auto* in_col = ctx.Input<LoDTensor>("COLUMN");
    auto* col = ctx.Input<LoDTensor>("Col");

    const int input_channel = ctx.Attr<int>("InputChannel");
    const int kernel_h = ctx.Attr<int>("KernelH");
    const int kernel_w = ctx.Attr<int>("KernelW");
    const int stride_h = ctx.Attr<int>("StrideH");
    const int stride_w = ctx.Attr<int>("StrideW");

    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    memset(dx_data, 0, x->dims()[0] * x->dims()[1] * sizeof(T));

    const auto& bottom_offset = x->lod()[0];
    const auto& offset_x = in_col->lod()[0];
    const auto& offset_y = in_row->lod()[0];
    const auto& top_offset = col->lod()[0];
    const int batch = x->lod()[0].size() - 1;

    const int kernel_win_size = kernel_h * kernel_w;
    const int half_kernel_h = kernel_h / 2;
    const int half_kernel_w = kernel_w / 2;
    for (int b = 0; b < batch; ++b) {
      int t_offset = top_offset[b];
      int b_offset = bottom_offset[b];
      int width = offset_x[b + 1] - offset_x[b];
      int height = offset_y[b + 1] - offset_y[b];
      if (width == 0 || height == 0) {
        continue;
      }
      int top_im_x = (width - 1) / stride_w + 1;
      int top_im_y = (height - 1) / stride_h + 1;
      int top_x = top_im_y * top_im_x;
      for (int z = 0; z < input_channel; ++z) {
        int row_offset = kernel_win_size * z;
        int im_offset = z * width * height;
        for (int py = 0; py < height; py += stride_h) {
          for (int px = 0; px < width; px += stride_w) {
            int col_offset = px / stride_w + py / stride_h * top_im_x;
            for (int ky = 0; ky < kernel_h; ++ky) {
              for (int kx = 0; kx < kernel_w; ++kx) {
                int im_y = py + ky - half_kernel_h;
                int im_x = px + kx - half_kernel_w;
                // Padding positions received no input, so they get no grad.
                if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) {
                  dx_data[b_offset + im_offset + im_y * width + im_x] +=
                      top_diff[t_offset +
                               (row_offset + ky * kernel_w + kx) * top_x +
                               col_offset];
                }
              }
            }
          }
        }
      }
    }
  }

  // Runs the backward pass. Reads X, W, Col, Out (for its lod) and Out@GRAD;
  // writes X@GRAD and/or W@GRAD, whichever is present.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<LoDTensor>("X");
    auto* w = ctx.Input<Tensor>("W");
    auto* col = ctx.Input<LoDTensor>("Col");
    auto* out = ctx.Input<LoDTensor>("Out");

    const int output_channel = ctx.Attr<int>("OutputChannel");
    const int input_channel = ctx.Attr<int>("InputChannel");
    const int kernel_h = ctx.Attr<int>("KernelH");
    const int kernel_w = ctx.Attr<int>("KernelW");
    const int col_rows = input_channel * kernel_h * kernel_w;

    auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    auto* d_w = ctx.Output<Tensor>(framework::GradVarName("W"));
    if (dx == nullptr && d_w == nullptr) {
      return;  // No gradient was requested.
    }

    // Gradient w.r.t. the column matrices; only needed to obtain X@GRAD.
    Tensor col_grad;
    T* col_diff = nullptr;
    if (dx != nullptr) {
      col_grad.Resize(col->dims());
      col_diff = col_grad.mutable_data<T>(ctx.GetPlace());
      memset(col_diff, 0, col->dims()[0] * col->dims()[1] * sizeof(T));
    }

    // W@GRAD accumulates over all samples, so it has to start from zero.
    T* w_diff = nullptr;
    if (d_w != nullptr) {
      w_diff = d_w->mutable_data<T>(ctx.GetPlace());
      memset(w_diff, 0, w->dims()[0] * w->dims()[1] * sizeof(T));
    }

    auto* top_diff = d_out->data<T>();
    auto* w_data = w->data<T>();
    auto* col_data = col->data<T>();

    const int batch = x->lod()[0].size() - 1;
    const auto& top_offset = out->lod()[0];
    const auto& col_offset = col->lod()[0];

    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
    for (int b = 0; b < batch; ++b) {
      int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
      if (top_im_size == 0) {
        continue;
      }
      if (col_diff != nullptr) {
        // Col@GRAD_b = W^T * Out@GRAD_b
        blas.GEMM(CblasTrans, CblasNoTrans, col_rows, top_im_size,
                  output_channel, 1.0, w_data, top_diff + top_offset[b], 1.0,
                  col_diff + col_offset[b]);
      }
      if (w_diff != nullptr) {
        // W@GRAD += Out@GRAD_b * Col_b^T
        blas.GEMM(CblasNoTrans, CblasTrans, output_channel, col_rows,
                  top_im_size, 1.0, top_diff + top_offset[b],
                  col_data + col_offset[b], 1.0, w_diff);
      }
    }

    if (dx != nullptr) {
      // Allocates and zeroes X@GRAD before accumulating into it.
      Im2ColGrad(ctx, col_diff);
    }
  }
};
} // namespace operators
} // namespace paddle
// Short namespace aliases used by the registration macros below.
namespace ops = paddle::operators;
namespace plt = paddle::platform;
namespace frm = paddle::framework;
// Register the forward operator with its proto maker and a grad-op desc maker.
// NOTE(review): DefaultGradOpDescMaker<true> presumably forwards the forward
// op's inputs/outputs to var_conv_2d_grad - confirm against the framework.
REGISTER_OPERATOR(var_conv_2d, ops::VarConv2dOP, ops::VarConv2dOpMaker,
frm::DefaultGradOpDescMaker<true>);
// Register the backward operator (shape inference only, no maker).
REGISTER_OPERATOR(var_conv_2d_grad, ops::VarConv2dOpGrad);
// CPU forward kernel: only the float instantiation is registered.
REGISTER_OP_CPU_KERNEL(var_conv_2d,
ops::CPUVarConv2dOPKernel<plt::CPUDeviceContext, float>);
// Disabled double instantiation of the forward kernel, kept for reference:
// ops::CPUVarConv2dOPKernel<plt::CPUDeviceContext,
// double>
// CPU backward kernel: only the float instantiation is registered.
REGISTER_OP_CPU_KERNEL(
var_conv_2d_grad,
ops::CPUVarConv2dOPGradKernel<plt::CPUDeviceContext, float>);
// Disabled double instantiation of the backward kernel, kept for reference:
// ops::CPUVarConv2dOPGradKernel<plt::CPUDeviceContext,
// double>