commit
0b4880ba91
@ -0,0 +1,202 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "Function.h"
|
||||
#include "Im2Col.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
|
||||
* \brief Converts the image data of four dimensions(NCHW) into
|
||||
* a sequence data of three dimensions(NST) in the forward calculation,
|
||||
* which is reversed in the backward calculation.
|
||||
* Where N is batch size, S is the length of the sequence after each
|
||||
* image is expanded, T is the size of each time step in the sequence.
|
||||
*
|
||||
* Arguments in forward function:
|
||||
* \param inputs[0] Image data of NCHW format.
|
||||
* \param outputs[0] Sequence data of NST format.
|
||||
*
|
||||
* Arguments in backward function:
|
||||
* \param inputs[0] Sequence data of NST format.
|
||||
* \param outputs[0] Image data of NCHW format.
|
||||
*/
|
||||
class BlockExpandFunction : public FunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {
|
||||
// function arguments
|
||||
strides_ = config.get<std::vector<size_t>>("strides");
|
||||
paddings_ = config.get<std::vector<size_t>>("paddings");
|
||||
blocks_ = config.get<std::vector<size_t>>("blocks");
|
||||
|
||||
// number of inputs and outputs
|
||||
numInputs_ = 1;
|
||||
numOutputs_ = 1;
|
||||
}
|
||||
|
||||
void checkShape(const TensorShape& image, const TensorShape& sequence) const {
|
||||
// image shape should be 4-dimensional.
|
||||
CHECK_EQ(image.ndims(), (size_t)4);
|
||||
// sequence shape should be 3-dimensional.
|
||||
CHECK_EQ(sequence.ndims(), (size_t)3);
|
||||
// The batchSize of the image needs to be equal to
|
||||
// the batchSize of the sequence.
|
||||
CHECK_EQ(image[0], sequence[0]);
|
||||
}
|
||||
|
||||
// Calculate the shape of colData based on the shape of the image
|
||||
// and the shape of the sequence.
|
||||
TensorShape getColShape(const TensorShape& image,
|
||||
const TensorShape& sequence) const {
|
||||
size_t inputChannels = image[1];
|
||||
size_t inputHeight = image[2];
|
||||
size_t inputWidth = image[3];
|
||||
size_t seqLength = sequence[1];
|
||||
size_t stepSize = sequence[2];
|
||||
size_t outputHeight =
|
||||
1 +
|
||||
(inputHeight + 2 * paddingH() - blockH() + strideH() - 1) / strideH();
|
||||
size_t outputWidth =
|
||||
1 +
|
||||
(inputWidth + 2 * paddingW() - blockW() + strideW() - 1) / strideW();
|
||||
CHECK_EQ(seqLength, outputHeight * outputWidth);
|
||||
CHECK_EQ(stepSize, inputChannels * blockH() * blockW());
|
||||
|
||||
// [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
|
||||
return TensorShape({outputHeight,
|
||||
outputWidth,
|
||||
inputChannels,
|
||||
(size_t)blockH(),
|
||||
(size_t)blockW()});
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<size_t> strides_;
|
||||
std::vector<size_t> paddings_;
|
||||
std::vector<size_t> blocks_;
|
||||
|
||||
inline int strideH() const { return strides_[0]; }
|
||||
|
||||
inline int strideW() const { return strides_[1]; }
|
||||
|
||||
inline int paddingH() const { return paddings_[0]; }
|
||||
|
||||
inline int paddingW() const { return paddings_[1]; }
|
||||
|
||||
inline int blockH() const { return blocks_[0]; }
|
||||
|
||||
inline int blockW() const { return blocks_[1]; }
|
||||
};
|
||||
|
||||
template <DeviceType Device>
|
||||
class BlockExpandForward : public BlockExpandFunction {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {
|
||||
BlockExpandFunction::init(config);
|
||||
}
|
||||
|
||||
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
const TensorShape& image = inputs[0].shape();
|
||||
const TensorShape& sequence = outputs[0].shape();
|
||||
checkShape(image, sequence);
|
||||
}
|
||||
|
||||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
CHECK_EQ(numInputs_, inputs.size());
|
||||
CHECK_EQ(numOutputs_, outputs.size());
|
||||
check(inputs, outputs);
|
||||
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
||||
const TensorShape& image = inputs[0].shape();
|
||||
const TensorShape& sequence = outputs[0].shape();
|
||||
|
||||
TensorShape imShape = TensorShape({image[1], image[2], image[3]});
|
||||
TensorShape colShape = getColShape(image, sequence);
|
||||
size_t batchSize = image[0];
|
||||
|
||||
real* imageData = inputs[0].data<real>();
|
||||
real* seqData = outputs[0].data<real>();
|
||||
Im2ColFunctor<kOCF, Device, real> im2col;
|
||||
for (size_t i = 0; i < batchSize; i++) {
|
||||
// The result of im2col is [outputHeight, outputWidth,
|
||||
// inputChannels, filterHeight, filterWidth], and it is easy to
|
||||
// reshape into [seqLength, stepSize], where seqLength is equal
|
||||
// output_height * output_width, stepSize is equal
|
||||
// input_channels * filter_height * filter_width
|
||||
im2col(imageData,
|
||||
imShape,
|
||||
seqData,
|
||||
colShape,
|
||||
strideH(),
|
||||
strideW(),
|
||||
paddingH(),
|
||||
paddingW());
|
||||
imageData += imShape.getElements();
|
||||
seqData += colShape.getElements();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <DeviceType Device>
|
||||
class BlockExpandBackward : public BlockExpandFunction {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {
|
||||
BlockExpandFunction::init(config);
|
||||
}
|
||||
|
||||
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
const TensorShape& image = outputs[0].shape();
|
||||
const TensorShape& sequence = inputs[0].shape();
|
||||
checkShape(image, sequence);
|
||||
}
|
||||
|
||||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
CHECK_EQ(numInputs_, inputs.size());
|
||||
CHECK_EQ(numOutputs_, outputs.size());
|
||||
check(inputs, outputs);
|
||||
// Since the implementation of Col2ImFunctor is ADD_TO,
|
||||
// this function only supports ADD_TO mode.
|
||||
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
|
||||
const TensorShape& image = outputs[0].shape();
|
||||
const TensorShape& sequence = inputs[0].shape();
|
||||
|
||||
TensorShape imShape = TensorShape({image[1], image[2], image[3]});
|
||||
TensorShape colShape = getColShape(image, sequence);
|
||||
size_t batchSize = image[0];
|
||||
|
||||
real* imageData = outputs[0].data<real>();
|
||||
real* seqData = inputs[0].data<real>();
|
||||
Col2ImFunctor<kOCF, Device, real> col2im;
|
||||
for (size_t i = 0; i < batchSize; i++) {
|
||||
col2im(imageData,
|
||||
imShape,
|
||||
seqData,
|
||||
colShape,
|
||||
strideH(),
|
||||
strideW(),
|
||||
paddingH(),
|
||||
paddingW());
|
||||
imageData += imShape.getElements();
|
||||
seqData += colShape.getElements();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// CPU variants of the forward and backward functions.
REGISTER_TYPED_FUNC(BlockExpand, CPU, BlockExpandForward);
REGISTER_TYPED_FUNC(BlockExpandGrad, CPU, BlockExpandBackward);

// GPU variants, left out when PADDLE_ONLY_CPU is defined.
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(BlockExpand, GPU, BlockExpandForward);
REGISTER_TYPED_FUNC(BlockExpandGrad, GPU, BlockExpandBackward);
#endif
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,107 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "FunctionTest.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
// Runs the "BlockExpand" function through CpuGpuFuncCompare for every
// combination of batch size, channel count, image size, block size, stride
// and padding listed below.
TEST(BlockExpandForward, real) {
  for (size_t batchSize : {5, 32}) {
    for (size_t channels : {1, 5, 32}) {
      for (size_t inputHeight : {5, 33, 100}) {
        for (size_t inputWidth : {5, 32, 96}) {
          for (size_t block : {1, 3, 5}) {
            for (size_t stride : {1, 2}) {
              for (size_t padding : {0, 1}) {
                // Square block/stride/padding: same value for both axes.
                std::vector<size_t> strides(2, stride);
                std::vector<size_t> paddings(2, padding);
                std::vector<size_t> blocks(2, block);
                CpuGpuFuncCompare test("BlockExpand",
                                       FuncConfig()
                                           .set("strides", strides)
                                           .set("paddings", paddings)
                                           .set("blocks", blocks));

                // Number of block positions along one image dimension.
                auto expandedSize = [&](size_t inputSize) -> size_t {
                  return 1 +
                         (inputSize + 2 * padding - block + stride - 1) /
                             stride;
                };
                size_t outputHeight = expandedSize(inputHeight);
                size_t outputWidth = expandedSize(inputWidth);

                TensorShape imageShape(
                    {batchSize, channels, inputHeight, inputWidth});
                TensorShape sequenceShape({batchSize,
                                           outputHeight * outputWidth,
                                           channels * block * block});
                test.addInputs(BufferArg(VALUE_TYPE_FLOAT, imageShape));
                test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, sequenceShape));
                // run Function
                test.run();
              }
            }
          }
        }
      }
    }
  }
}
|
||||
|
||||
// Runs the "BlockExpandGrad" function through CpuGpuFuncCompare for every
// combination of batch size, channel count, image size, block size, stride
// and padding listed below. The sequence is the input and the image is an
// ADD_TO output.
TEST(BlockExpandBackward, real) {
  for (size_t batchSize : {5, 32}) {
    for (size_t channels : {1, 5, 32}) {
      for (size_t inputHeight : {5, 33, 100}) {
        for (size_t inputWidth : {5, 32, 96}) {
          for (size_t block : {1, 3, 5}) {
            for (size_t stride : {1, 2}) {
              for (size_t padding : {0, 1}) {
                // Square block/stride/padding: same value for both axes.
                std::vector<size_t> strides(2, stride);
                std::vector<size_t> paddings(2, padding);
                std::vector<size_t> blocks(2, block);
                CpuGpuFuncCompare test("BlockExpandGrad",
                                       FuncConfig()
                                           .set("strides", strides)
                                           .set("paddings", paddings)
                                           .set("blocks", blocks));

                // Number of block positions along one image dimension.
                auto expandedSize = [&](size_t inputSize) -> size_t {
                  return 1 +
                         (inputSize + 2 * padding - block + stride - 1) /
                             stride;
                };
                size_t outputHeight = expandedSize(inputHeight);
                size_t outputWidth = expandedSize(inputWidth);

                TensorShape imageShape(
                    {batchSize, channels, inputHeight, inputWidth});
                TensorShape sequenceShape({batchSize,
                                           outputHeight * outputWidth,
                                           channels * block * block});
                test.addInputs(BufferArg(VALUE_TYPE_FLOAT, sequenceShape));
                test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, imageShape),
                                ADD_TO);
                // run Function
                test.run();
              }
            }
          }
        }
      }
    }
  }
}
|
||||
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -1,62 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ConvOp.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
|
||||
* imData = [input_channels, input_height, input_width]
|
||||
* colData = [input_channels, filter_height, filter_width,
|
||||
* output_height, output_width]
|
||||
*/
|
||||
template <DeviceType Device, class T>
|
||||
class Im2ColFunctor {
|
||||
public:
|
||||
void operator()(const T* imData,
|
||||
int inputChannels,
|
||||
int inputHeight,
|
||||
int inputWidth,
|
||||
int filterHeight,
|
||||
int filterWidth,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth,
|
||||
int outputHeight,
|
||||
int outputWidth,
|
||||
T* colData);
|
||||
};
|
||||
|
||||
template <DeviceType Device, class T>
|
||||
class Col2ImFunctor {
|
||||
public:
|
||||
void operator()(const T* colData,
|
||||
int inputChannels,
|
||||
int inputHeight,
|
||||
int inputWidth,
|
||||
int filterHeight,
|
||||
int filterWidth,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth,
|
||||
int outputHeight,
|
||||
int outputWidth,
|
||||
T* imData);
|
||||
};
|
||||
|
||||
} // namespace paddle
|
@ -1,186 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "ConvOp.h"
|
||||
#include "GemmConvOp.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
template<class T>
|
||||
__global__
|
||||
void im2col(const T* data_im, int numOuts, int height, int width,
|
||||
int blockH, int blockW,
|
||||
int strideH, int strideW,
|
||||
int paddingH, int paddingW,
|
||||
int height_col, int width_col,
|
||||
T* data_col) {
|
||||
int index =
|
||||
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
|
||||
if (index < numOuts) {
|
||||
int w_out = index % width_col;
|
||||
index /= width_col;
|
||||
int h_out = index % height_col;
|
||||
int channel_in = index / height_col;
|
||||
int channel_out = channel_in * blockH * blockW;
|
||||
int h_in = h_out * strideH;
|
||||
int w_in = w_out * strideW;
|
||||
|
||||
data_col += (channel_out * height_col + h_out) * width_col + w_out;
|
||||
for (int i = 0; i < blockH; ++i) {
|
||||
for (int j = 0; j < blockW; ++j) {
|
||||
int rIdx = int(h_in+i);
|
||||
int cIdx = int(w_in+j);
|
||||
if ((rIdx-(int)paddingH) >= (int)height ||
|
||||
(rIdx-(int)paddingH) < 0 ||
|
||||
(cIdx-(int)paddingW) >= (int)width ||
|
||||
(cIdx-(int)paddingW) < 0) {
|
||||
*data_col = 0;
|
||||
} else {
|
||||
rIdx = rIdx + channel_in*height - paddingH;
|
||||
cIdx = cIdx - paddingW;
|
||||
*data_col = data_im[rIdx* width + cIdx];
|
||||
}
|
||||
data_col += height_col * width_col;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
class Im2ColFunctor<DEVICE_TYPE_GPU, T> {
|
||||
public:
|
||||
void operator()(const T* imData,
|
||||
int inputChannels,
|
||||
int inputHeight,
|
||||
int inputWidth,
|
||||
int filterHeight,
|
||||
int filterWidth,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth,
|
||||
int outputHeight,
|
||||
int outputWidth,
|
||||
T* colData) {
|
||||
int numKernels = inputChannels * outputHeight * outputWidth;
|
||||
int blocks = (numKernels + 1024 -1) / 1024;
|
||||
int blockX = 512;
|
||||
int blockY = (blocks + 512 - 1) / 512;
|
||||
dim3 threads(1024, 1);
|
||||
dim3 grid(blockX, blockY);
|
||||
im2col<T><<< grid, threads, 0, STREAM_DEFAULT >>>
|
||||
(imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth,
|
||||
strideHeight, strideWidth, paddingHeight, paddingWidth,
|
||||
outputHeight, outputWidth, colData);
|
||||
CHECK_SYNC("Im2ColFunctor GPU failed");
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
__global__
|
||||
void col2im(size_t n, const T* data_col, size_t height,
|
||||
size_t width, size_t channels,
|
||||
size_t blockH, size_t blockW,
|
||||
size_t strideH, size_t strideW,
|
||||
size_t paddingH, size_t paddingW,
|
||||
size_t height_col, size_t width_col,
|
||||
T* data_im) {
|
||||
size_t index =
|
||||
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
|
||||
if (index < n) {
|
||||
T val = 0;
|
||||
int w = int(index % width);
|
||||
int h = int((index / width) % height);
|
||||
int c = int(index / (width * height));
|
||||
if ((w - (int)paddingW) >= 0 &&
|
||||
(w - (int)paddingW) < (width-2 * paddingW) &&
|
||||
(h - (int)paddingH) >= 0 &&
|
||||
(h - paddingH) < (height - 2 * paddingH)) {
|
||||
// compute the start and end of the output
|
||||
int w_col_start =
|
||||
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
|
||||
int w_col_end =
|
||||
min((int)(w / (int)strideW + 1), (int)(width_col));
|
||||
int h_col_start =
|
||||
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
|
||||
int h_col_end = min(int(h / strideH + 1), int(height_col));
|
||||
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
|
||||
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
|
||||
// the col location: [c * width * height + h_out, w_out]
|
||||
int c_col = int(c * blockH* blockW) + \
|
||||
(h - h_col * (int)strideH) * (int)blockW +
|
||||
(w - w_col * (int)strideW);
|
||||
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
|
||||
}
|
||||
}
|
||||
h -= paddingH;
|
||||
w -= paddingW;
|
||||
data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
|
||||
h*(width-2*paddingW) + w] += val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
class Col2ImFunctor<DEVICE_TYPE_GPU, T> {
|
||||
public:
|
||||
void operator()(const T* colData,
|
||||
int inputChannels,
|
||||
int inputHeight,
|
||||
int inputWidth,
|
||||
int filterHeight,
|
||||
int filterWidth,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth,
|
||||
int outputHeight,
|
||||
int outputWidth,
|
||||
T* imData) {
|
||||
size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight)
|
||||
* (inputWidth + 2*paddingWidth);
|
||||
|
||||
size_t blocks = (numKernels + 1024 -1) / 1024;
|
||||
size_t blockX = 512;
|
||||
size_t blockY = (blocks+512-1)/512;
|
||||
dim3 threads(1024, 1);
|
||||
dim3 grid(blockX, blockY);
|
||||
|
||||
// To avoid involving atomic operations, we will launch one kernel per
|
||||
// bottom dimension, and then in the kernel add up the top dimensions.
|
||||
col2im<T><<< grid, threads, 0, STREAM_DEFAULT >>>
|
||||
(numKernels,
|
||||
colData,
|
||||
inputHeight + 2*paddingHeight,
|
||||
inputWidth + 2*paddingWidth,
|
||||
inputChannels,
|
||||
filterHeight,
|
||||
filterWidth,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
paddingHeight,
|
||||
paddingWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
imData);
|
||||
CHECK_SYNC("Col2ImFunctor GPU failed");
|
||||
}
|
||||
};
|
||||
|
||||
template class Im2ColFunctor<DEVICE_TYPE_GPU, float>;
|
||||
template class Im2ColFunctor<DEVICE_TYPE_GPU, double>;
|
||||
template class Col2ImFunctor<DEVICE_TYPE_GPU, float>;
|
||||
template class Col2ImFunctor<DEVICE_TYPE_GPU, double>;
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,96 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "TensorShape.h"
|
||||
#include "TensorType.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
 * The storage format of the colData used by Im2ColFunctor and Col2ImFunctor.
 */
enum ColFormat {
  // [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
  kCFO = 0,
  // [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
  kOCF = 1
};
|
||||
|
||||
/*
|
||||
* \brief Converts the image data of three dimensions(CHW) into a colData of
|
||||
* five dimensions in the Im2ColFunctor calculation,
|
||||
* And in the Col2ImFunctor calculation, it is reversed.
|
||||
*
|
||||
* \param imData Image data.
|
||||
* \param imShape The shape of imData,
|
||||
* [inputChannels, inputHeight, inputWidth].
|
||||
* \param colData Column data.
|
||||
* \param colShape The shape of colData.
|
||||
*
|
||||
* If the template argument Format is kCFO, the shape of colData is:
|
||||
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
|
||||
* So, it is easy to reshape into a convolution matrix for convolution
|
||||
* calculation based on matrix multiplication.
|
||||
* The shape of convolution matrix is [height, width], where the height is equal
|
||||
* inputChannels * filterHeight * filterWidth, and the width is equal
|
||||
* outputHeight * outputWidth.
|
||||
*
|
||||
* Reshape:
|
||||
* shape of colData shape of convolution matrix
|
||||
* [inputChannels,
|
||||
* filterHeight,
|
||||
* filterWidth, ======> [height, width]
|
||||
* outputHeight,
|
||||
* outputWidth]
|
||||
*
|
||||
* If the template argument Format is kOCF, the shape of colData is:
|
||||
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
|
||||
* So, it is easy to reshape into a sequence matrix for rnn calculation.
|
||||
* The shape of sequence matrix is [seqLength, stepSize], where the seqLength
|
||||
* is equal outputHeight * outputWidth, and the stepSize is equal
|
||||
* inputChannels * filterHeight * filterWidth.
|
||||
*
|
||||
* Reshape:
|
||||
* shape of colData shape of sequence matrix
|
||||
* [outputHeight,
|
||||
* outputWidth,
|
||||
* inputChannels, ======> [seqLength, stepSize]
|
||||
* filterHeight,
|
||||
* filterWidth]
|
||||
*
|
||||
* \note The caller needs to ensure that imShape.inputChannels is equal to
|
||||
* colShape.inputChannels.
|
||||
*/
|
||||
template <ColFormat Format, DeviceType Device, class T>
|
||||
class Im2ColFunctor {
|
||||
public:
|
||||
void operator()(const T* imData,
|
||||
const TensorShape& imShape,
|
||||
T* colData,
|
||||
const TensorShape& colShape,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth);
|
||||
};
|
||||
|
||||
template <ColFormat Format, DeviceType Device, class T>
|
||||
class Col2ImFunctor {
|
||||
public:
|
||||
void operator()(T* imData,
|
||||
const TensorShape& imShape,
|
||||
const T* colData,
|
||||
const TensorShape& colShape,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth);
|
||||
};
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,235 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "Im2Col.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
|
||||
* imShape = [inputChannels, inputHeight, inputWidth]
|
||||
* colShape =
|
||||
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
|
||||
*/
|
||||
template <class T>
|
||||
class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> {
|
||||
public:
|
||||
void operator()(const T* imData,
|
||||
const TensorShape& imShape,
|
||||
T* colData,
|
||||
const TensorShape& colShape,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth) {
|
||||
int inputChannels = imShape[0];
|
||||
int inputHeight = imShape[1];
|
||||
int inputWidth = imShape[2];
|
||||
int filterHeight = colShape[1];
|
||||
int filterWidth = colShape[2];
|
||||
int outputHeight = colShape[3];
|
||||
int outputWidth = colShape[4];
|
||||
int channelsCol = inputChannels * filterHeight * filterWidth;
|
||||
|
||||
for (int c = 0; c < channelsCol; ++c) {
|
||||
int wOffset = c % filterWidth;
|
||||
int hOffset = (c / filterWidth) % filterHeight;
|
||||
int c_im = c / filterWidth / filterHeight;
|
||||
for (int h = 0; h < outputHeight; ++h) {
|
||||
for (int w = 0; w < outputWidth; ++w) {
|
||||
int imRowIdx = h * strideHeight + hOffset;
|
||||
int imColIdx = w * strideWidth + wOffset;
|
||||
if ((imRowIdx - paddingHeight) < 0 ||
|
||||
(imRowIdx - paddingHeight) >= inputHeight ||
|
||||
(imColIdx - paddingWidth) < 0 ||
|
||||
(imColIdx - paddingWidth) >= inputWidth) {
|
||||
colData[(c * outputHeight + h) * outputWidth + w] = T(0);
|
||||
} else {
|
||||
imRowIdx += c_im * inputHeight - paddingHeight;
|
||||
imColIdx -= paddingWidth;
|
||||
colData[(c * outputHeight + h) * outputWidth + w] =
|
||||
imData[imRowIdx * inputWidth + imColIdx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* imShape = [inputChannels, inputHeight, inputWidth]
|
||||
* colShape =
|
||||
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
|
||||
*/
|
||||
template <class T>
|
||||
class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> {
|
||||
public:
|
||||
void operator()(T* imData,
|
||||
const TensorShape& imShape,
|
||||
const T* colData,
|
||||
const TensorShape& colShape,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth) {
|
||||
int inputChannels = imShape[0];
|
||||
int inputHeight = imShape[1];
|
||||
int inputWidth = imShape[2];
|
||||
int filterHeight = colShape[1];
|
||||
int filterWidth = colShape[2];
|
||||
int outputHeight = colShape[3];
|
||||
int outputWidth = colShape[4];
|
||||
int channelsCol = inputChannels * filterHeight * filterWidth;
|
||||
|
||||
for (int c = 0; c < channelsCol; ++c) {
|
||||
int wOffset = c % filterWidth;
|
||||
int hOffset = (c / filterWidth) % filterHeight;
|
||||
int c_im = c / filterWidth / filterHeight;
|
||||
for (int h = 0; h < outputHeight; ++h) {
|
||||
for (int w = 0; w < outputWidth; ++w) {
|
||||
int imRowIdx = h * strideHeight + hOffset;
|
||||
int imColIdx = w * strideWidth + wOffset;
|
||||
if ((imRowIdx - paddingHeight) >= 0 &&
|
||||
(imRowIdx - paddingHeight) < inputHeight &&
|
||||
(imColIdx - paddingWidth) >= 0 &&
|
||||
(imColIdx - paddingWidth) < inputWidth) {
|
||||
imRowIdx += c_im * inputHeight - paddingHeight;
|
||||
imColIdx -= paddingWidth;
|
||||
imData[imRowIdx * inputWidth + imColIdx] +=
|
||||
colData[(c * outputHeight + h) * outputWidth + w];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, float>;
|
||||
template class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, double>;
|
||||
template class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, float>;
|
||||
template class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, double>;
|
||||
|
||||
/*
|
||||
* imShape = [inputChannels, inputHeight, inputWidth]
|
||||
* colShape =
|
||||
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
|
||||
*/
|
||||
template <class T>
|
||||
class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> {
|
||||
public:
|
||||
void operator()(const T* imData,
|
||||
const TensorShape& imShape,
|
||||
T* colData,
|
||||
const TensorShape& colShape,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth) {
|
||||
int inputChannels = imShape[0];
|
||||
int inputHeight = imShape[1];
|
||||
int inputWidth = imShape[2];
|
||||
int filterHeight = colShape[3];
|
||||
int filterWidth = colShape[4];
|
||||
int outputHeight = colShape[0];
|
||||
int outputWidth = colShape[1];
|
||||
for (int outputH = 0; outputH < outputHeight; ++outputH) {
|
||||
for (int outputW = 0; outputW < outputWidth; ++outputW) {
|
||||
for (int channel = 0; channel < inputChannels; ++channel) {
|
||||
for (int filterH = 0; filterH < filterHeight; ++filterH) {
|
||||
for (int filterW = 0; filterW < filterWidth; ++filterW) {
|
||||
int imRowOffset =
|
||||
outputH * strideHeight + filterH - paddingHeight;
|
||||
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
|
||||
int colDataOffset =
|
||||
(((outputH * outputWidth + outputW) * inputChannels +
|
||||
channel) *
|
||||
filterHeight +
|
||||
filterH) *
|
||||
filterWidth +
|
||||
filterW;
|
||||
if (imRowOffset < 0 || imRowOffset >= inputHeight ||
|
||||
imColOffset < 0 || imColOffset >= inputWidth) {
|
||||
colData[colDataOffset] = float(0);
|
||||
} else {
|
||||
int imDataOffset =
|
||||
(channel * inputHeight + imRowOffset) * inputWidth +
|
||||
imColOffset;
|
||||
colData[colDataOffset] = imData[imDataOffset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* imShape = [inputChannels, inputHeight, inputWidth]
|
||||
* colShape =
|
||||
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
|
||||
*/
|
||||
template <class T>
|
||||
class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> {
|
||||
public:
|
||||
void operator()(T* imData,
|
||||
const TensorShape& imShape,
|
||||
const T* colData,
|
||||
const TensorShape& colShape,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth) {
|
||||
int inputChannels = imShape[0];
|
||||
int inputHeight = imShape[1];
|
||||
int inputWidth = imShape[2];
|
||||
int filterHeight = colShape[3];
|
||||
int filterWidth = colShape[4];
|
||||
int outputHeight = colShape[0];
|
||||
int outputWidth = colShape[1];
|
||||
for (int outputH = 0; outputH < outputHeight; ++outputH) {
|
||||
for (int outputW = 0; outputW < outputWidth; ++outputW) {
|
||||
for (int channel = 0; channel < inputChannels; ++channel) {
|
||||
for (int filterH = 0; filterH < filterHeight; ++filterH) {
|
||||
for (int filterW = 0; filterW < filterWidth; ++filterW) {
|
||||
int imRowOffset =
|
||||
outputH * strideHeight + filterH - paddingHeight;
|
||||
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
|
||||
int colDataOffset =
|
||||
(((outputH * outputWidth + outputW) * inputChannels +
|
||||
channel) *
|
||||
filterHeight +
|
||||
filterH) *
|
||||
filterWidth +
|
||||
filterW;
|
||||
if (imRowOffset >= 0 && imRowOffset < inputHeight &&
|
||||
imColOffset >= 0 && imColOffset < inputWidth) {
|
||||
int imDataOffset =
|
||||
(channel * inputHeight + imRowOffset) * inputWidth +
|
||||
imColOffset;
|
||||
imData[imDataOffset] += colData[colDataOffset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, float>;
|
||||
template class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, double>;
|
||||
template class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, float>;
|
||||
template class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, double>;
|
||||
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,125 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "Im2Col.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include "Function.h"
|
||||
#include "paddle/math/Matrix.h"
|
||||
#include "paddle/math/tests/TensorCheck.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
template <DeviceType Device, class T>
|
||||
void TestIm2ColFunctor() {
|
||||
for (size_t channels : {1, 5, 32}) {
|
||||
for (size_t inputHeight : {5, 33, 100}) {
|
||||
for (size_t inputWidth : {5, 32, 96}) {
|
||||
for (size_t filterHeight : {1, 5}) {
|
||||
for (size_t filterWidth : {3, 7}) {
|
||||
for (size_t stride : {1, 2}) {
|
||||
for (size_t padding : {0, 1}) {
|
||||
if (inputHeight <= filterHeight || inputWidth <= filterWidth)
|
||||
break;
|
||||
if (padding >= filterHeight || padding >= filterWidth) break;
|
||||
size_t outputHeight =
|
||||
(inputHeight - filterHeight + 2 * padding + stride) /
|
||||
stride;
|
||||
size_t outputWidth =
|
||||
(inputWidth - filterWidth + 2 * padding + stride) / stride;
|
||||
|
||||
TensorShape imShape =
|
||||
TensorShape({channels, inputHeight, inputWidth});
|
||||
TensorShape colShape1 = TensorShape({channels,
|
||||
filterHeight,
|
||||
filterWidth,
|
||||
outputHeight,
|
||||
outputWidth});
|
||||
TensorShape colShape2 = TensorShape({outputHeight,
|
||||
outputWidth,
|
||||
channels,
|
||||
filterHeight,
|
||||
filterWidth});
|
||||
|
||||
size_t height = channels * filterHeight * filterWidth;
|
||||
size_t width = outputHeight * outputWidth;
|
||||
VectorPtr input1 = Vector::create(imShape.getElements(), false);
|
||||
VectorPtr input2 = Vector::create(imShape.getElements(), false);
|
||||
MatrixPtr output1 = Matrix::create(height, width, false, false);
|
||||
MatrixPtr output2 = Matrix::create(width, height, false, false);
|
||||
input1->uniform(0.001, 1);
|
||||
input2->copyFrom(*input1);
|
||||
|
||||
Im2ColFunctor<kCFO, Device, T> im2Col1;
|
||||
Im2ColFunctor<kOCF, Device, T> im2Col2;
|
||||
im2Col1(input1->getData(),
|
||||
imShape,
|
||||
output1->getData(),
|
||||
colShape1,
|
||||
stride,
|
||||
stride,
|
||||
padding,
|
||||
padding);
|
||||
im2Col2(input2->getData(),
|
||||
imShape,
|
||||
output2->getData(),
|
||||
colShape2,
|
||||
stride,
|
||||
stride,
|
||||
padding,
|
||||
padding);
|
||||
|
||||
// The transposition of the result of ColFormat == kCFO
|
||||
// is equal to the result of ColFormat == kOCF.
|
||||
MatrixPtr test;
|
||||
output2->transpose(test, true);
|
||||
autotest::TensorCheckErr(*output1, *test);
|
||||
|
||||
Col2ImFunctor<kCFO, Device, T> col2Im1;
|
||||
Col2ImFunctor<kOCF, Device, T> col2Im2;
|
||||
col2Im1(input1->getData(),
|
||||
imShape,
|
||||
output1->getData(),
|
||||
colShape1,
|
||||
stride,
|
||||
stride,
|
||||
padding,
|
||||
padding);
|
||||
col2Im2(input2->getData(),
|
||||
imShape,
|
||||
output2->getData(),
|
||||
colShape2,
|
||||
stride,
|
||||
stride,
|
||||
padding,
|
||||
padding);
|
||||
|
||||
autotest::TensorCheckErr(*input1, *input2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CPU run of the functor cross-check with float data.
TEST(Im2ColFunctor, CPU) {
  TestIm2ColFunctor<DEVICE_TYPE_CPU, float>();
}

#ifndef PADDLE_ONLY_CPU

// GPU run of the functor cross-check, left out of CPU-only builds.
TEST(Im2ColFunctor, GPU) {
  TestIm2ColFunctor<DEVICE_TYPE_GPU, float>();
}

#endif
|
||||
|
||||
} // namespace paddle
|
Loading…
Reference in new issue