|
|
|
@ -23,20 +23,9 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
inline void get_blockexpand_output_shape(int img_height, int img_width,
|
|
|
|
|
int block_height, int block_width,
|
|
|
|
|
int stride_height, int stride_width,
|
|
|
|
|
int padding_height, int padding_width,
|
|
|
|
|
int& outputHeight, int& outputWidth) {
|
|
|
|
|
outputHeight =
|
|
|
|
|
1 +
|
|
|
|
|
(img_height + 2 * padding_height - block_height + stride_height - 1) /
|
|
|
|
|
stride_height;
|
|
|
|
|
|
|
|
|
|
outputWidth =
|
|
|
|
|
1 +
|
|
|
|
|
(img_width + 2 * padding_width - block_width + stride_width - 1) /
|
|
|
|
|
stride_width;
|
|
|
|
|
// Returns the number of block positions along one image axis.
//
// Evaluates 1 + (img_size + 2 * padding - block_size + stride - 1) / stride
// using int arithmetic. When the numerator is non-negative and stride is
// positive this equals 1 + ceil((img_size + 2 * padding - block_size) / stride).
//
// img_size   - extent of the image along the axis.
// block_size - extent of one block along the axis.
// stride     - step between consecutive blocks; must be non-zero (it is used
//              as a divisor and is not checked here).
// padding    - padding applied on each side of the axis (counted twice).
inline int get_output_size(int img_size, int block_size, int stride,
                           int padding) {
  const int padded_size = img_size + 2 * padding;
  return 1 + (padded_size - block_size + stride - 1) / stride;
}
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
@ -45,40 +34,54 @@ class BlockExpandKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Forward pass: expands every image of input "X" into blocks and writes them
// to output "Out", then attaches a one-level LoD to "Out".
//
// "X" is indexed as (batch, channels, height, width). The block, stride and
// padding extents are read from the int attributes "block_height",
// "block_width", "stride_height", "stride_width", "padding_height" and
// "padding_width".
void Compute(const framework::ExecutionContext& ctx) const override {
  using namespace framework;
  const Tensor* in = ctx.Input<Tensor>("X");
  LoDTensor* out = ctx.Output<LoDTensor>("Out");
  out->mutable_data<T>(ctx.GetPlace());

  auto in_dim = in->dims();
  int batch_size = in_dim[0];
  int img_channels = in_dim[1];
  int img_height = in_dim[2];
  int img_width = in_dim[3];

  int block_height = ctx.Attr<int>("block_height");
  int block_width = ctx.Attr<int>("block_width");
  int stride_height = ctx.Attr<int>("stride_height");
  int stride_width = ctx.Attr<int>("stride_width");
  int padding_height = ctx.Attr<int>("padding_height");
  int padding_width = ctx.Attr<int>("padding_width");

  // Number of block positions along each image axis.
  int output_height = get_output_size(img_height, block_height, stride_height,
                                      padding_height);
  int output_width =
      get_output_size(img_width, block_width, stride_width, padding_width);

  const std::vector<int> dilations({1, 1});
  // The height/width pair is listed twice for both strides and paddings.
  const std::vector<int> strides(
      {stride_height, stride_width, stride_height, stride_width});
  const std::vector<int> paddings(
      {padding_height, padding_width, padding_height, padding_width});

  // Temporarily view "Out" as (batch_size, elements-per-image) so that one
  // image can be sliced out per iteration; the saved dims are restored after
  // the loop.
  auto out_dims = out->dims();
  out->Resize({batch_size, out->numel() / batch_size});
  for (int i = 0; i < batch_size; i++) {
    const Tensor src =
        in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
    Tensor dst = out->Slice(i, i + 1).Resize({output_height, output_width,
                                              img_channels, block_height,
                                              block_width});

    math::Im2ColFunctor<math::ColFormat::kOCF, Place, T> f;
    f(ctx.device_context(), src, dilations, strides, paddings, &dst);
  }
  out->Resize(out_dims);

  // set lod information
  // TODO(wanghaoshuang): Move this to InferShape
  // Level 0 holds batch_size + 1 offsets spaced output_height * output_width
  // apart, i.e. one span per input image.
  framework::LoD lod(1);
  for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
    lod[0].push_back(offset);
    offset += output_height * output_width;
  }
  out->set_lod(lod);
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -88,7 +91,8 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
using namespace framework;
|
|
|
|
|
auto* in = ctx.Input<Tensor>("X");
|
|
|
|
|
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
Tensor* d_out =
|
|
|
|
|
const_cast<Tensor*>(ctx.Input<Tensor>(framework::GradVarName("Out")));
|
|
|
|
|
auto* d_x = ctx.Output<Tensor>(GradVarName("X"));
|
|
|
|
|
d_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
@ -96,36 +100,40 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
x_v.device(ctx.GetEigenDevice<Place>()) = x_v.constant(0.0);
|
|
|
|
|
|
|
|
|
|
auto in_dim = in->dims();
|
|
|
|
|
int N = in_dim[0];
|
|
|
|
|
int C = in_dim[1];
|
|
|
|
|
int batch_size = in_dim[0];
|
|
|
|
|
int img_channels = in_dim[1];
|
|
|
|
|
int img_height = in_dim[2];
|
|
|
|
|
int img_width = in_dim[3];
|
|
|
|
|
|
|
|
|
|
int block_height = ctx.Attr<int>("blockHeight");
|
|
|
|
|
int block_width = ctx.Attr<int>("blockWidth");
|
|
|
|
|
int stride_height = ctx.Attr<int>("strideHeight");
|
|
|
|
|
int stride_width = ctx.Attr<int>("strideWidth");
|
|
|
|
|
int padding_height = ctx.Attr<int>("paddingHeight");
|
|
|
|
|
int padding_width = ctx.Attr<int>("paddingWidth");
|
|
|
|
|
|
|
|
|
|
int outputHeight = 0;
|
|
|
|
|
int outputWidth = 0;
|
|
|
|
|
|
|
|
|
|
get_blockexpand_output_shape(
|
|
|
|
|
img_height, img_width, block_height, block_width, stride_height,
|
|
|
|
|
stride_width, padding_height, padding_width, outputHeight, outputWidth);
|
|
|
|
|
|
|
|
|
|
std::vector<int> stride({stride_height, stride_width});
|
|
|
|
|
std::vector<int> padding({padding_height, padding_width});
|
|
|
|
|
// std::vector<int> stride({stride_height, stride_width});
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < N; i++) {
|
|
|
|
|
Tensor dst = d_x->Slice(i, i + 1).Resize({C, img_height, img_width});
|
|
|
|
|
Tensor src = d_out->Slice(i, i + 1).Resize(
|
|
|
|
|
{outputHeight, outputWidth, C, block_height, block_width});
|
|
|
|
|
int block_height = ctx.Attr<int>("block_height");
|
|
|
|
|
int block_width = ctx.Attr<int>("block_width");
|
|
|
|
|
int stride_height = ctx.Attr<int>("stride_height");
|
|
|
|
|
int stride_width = ctx.Attr<int>("stride_width");
|
|
|
|
|
int padding_height = ctx.Attr<int>("padding_height");
|
|
|
|
|
int padding_width = ctx.Attr<int>("padding_width");
|
|
|
|
|
int output_height = get_output_size(img_height, block_height, stride_height,
|
|
|
|
|
padding_height);
|
|
|
|
|
int output_width =
|
|
|
|
|
get_output_size(img_width, block_width, stride_width, padding_width);
|
|
|
|
|
|
|
|
|
|
const std::vector<int> dilations({1, 1});
|
|
|
|
|
const std::vector<int> strides(
|
|
|
|
|
{stride_height, stride_width, stride_height, stride_width});
|
|
|
|
|
const std::vector<int> paddings(
|
|
|
|
|
{padding_height, padding_width, padding_height, padding_width});
|
|
|
|
|
|
|
|
|
|
auto d_out_dims = d_out->dims();
|
|
|
|
|
d_out->Resize({batch_size, d_out->numel() / batch_size});
|
|
|
|
|
for (int i = 0; i < batch_size; i++) {
|
|
|
|
|
Tensor dst =
|
|
|
|
|
d_x->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
|
|
|
|
|
const Tensor src = d_out->Slice(i, i + 1).Resize(
|
|
|
|
|
{output_height, output_width, img_channels, block_height,
|
|
|
|
|
block_width});
|
|
|
|
|
math::Col2ImFunctor<math::ColFormat::kOCF, Place, T> f;
|
|
|
|
|
f(ctx.device_context(), dst, stride, padding, &src);
|
|
|
|
|
f(ctx.device_context(), src, dilations, strides, paddings, &dst);
|
|
|
|
|
}
|
|
|
|
|
d_out->Resize(d_out_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|