commit
17fe832209
@ -0,0 +1,146 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Function.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
|
||||
* \brief Based on the ConvFunctionBase class, the forward calculation,
|
||||
* backward input calculation and backward filter calculation
|
||||
* of convolution operations can be implemented.
|
||||
*
|
||||
* Arguments of forward and backward calculation:
|
||||
* 1. Forward calculation of convolution.
|
||||
* inputs = {INPUT, FILTER}, outputs = {OUTPUT}
|
||||
* The first and second input arguments are input image and filter data.
|
||||
* The output argument is output image.
|
||||
*
|
||||
* 2. Backward input calculation of convolution.
|
||||
* inputs = {OUTPUT_GRAD, FILTER}, outputs = {INPUT_GRAD}
|
||||
* The first and second input arguments are output grad image
|
||||
* and filter data.
|
||||
* The output argument is input grad image.
|
||||
*
|
||||
* 3. Backward filter calculation of convolution.
|
||||
* inputs = {OUTPUT_GRAD, INPUT}, outputs = {FILTER_GRAD}
|
||||
* The first and second input arguments are output grad image
|
||||
* and input image.
|
||||
* The output argument is filter grad.
|
||||
*
|
||||
* Arguments format of input, filter and output:
|
||||
* 1. Input image, output image, input image gradient, output image gradient
|
||||
* are all NCHW format. Where N is batch size, C is the number of channels,
|
||||
* H and W is the height and width of image or image gradient.
|
||||
*
|
||||
* 2. The format of the filter data is MCHW, where M is the number of output
|
||||
* image channels, C is the number of input image channels,
|
||||
* H and W is height and width of filter.
|
||||
*
|
||||
* If `groups` is greater than 1, the filter's data format should be GMCHW,
|
||||
* where G is the `groups`, and G * M is the number of output image
|
||||
* channels, G * C is the number of input image channels,
|
||||
* H and W is height and width of filter.
|
||||
*/
|
||||
class ConvFunctionBase : public FunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {
|
||||
// function arguments
|
||||
strides_ = config.get<std::vector<size_t>>("strides");
|
||||
paddings_ = config.get<std::vector<size_t>>("paddings");
|
||||
groups_ = config.get<size_t>("groups");
|
||||
|
||||
// number of inputs and outputs
|
||||
numInputs_ = 2;
|
||||
numOutputs_ = 1;
|
||||
}
|
||||
|
||||
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
|
||||
|
||||
// input can be INPUT and INPUT_GRAD
|
||||
// filter can be FILTER and FILTER_GRAD
|
||||
// output can be OUTPUT and OUTPUT_GRAD
|
||||
void check(const TensorShape& input,
|
||||
const TensorShape& filter,
|
||||
const TensorShape& output) {
|
||||
// inputs and outputs arguments should be 4-dimensional.
|
||||
CHECK_EQ(input.ndims(), (size_t)4);
|
||||
CHECK_EQ(output.ndims(), (size_t)4);
|
||||
// The batchSize of the input needs to be equal to
|
||||
// the batchSize of the output.
|
||||
CHECK_EQ(input[0], output[0]);
|
||||
|
||||
if (filter.ndims() == (size_t)4) {
|
||||
// If the filter's dimension is 4, groups convolution is not supported.
|
||||
CHECK_EQ(groups_, (size_t)1);
|
||||
// The input and output channel dimensions are the second and first
|
||||
// dimensions of the filter shape.
|
||||
CHECK_EQ(input[1], filter[1]);
|
||||
CHECK_EQ(output[1], filter[0]);
|
||||
} else {
|
||||
// filter argument should be 5-dimensional.
|
||||
CHECK_EQ(filter.ndims(), (size_t)5);
|
||||
// The first dimension of the filter is the size of the group
|
||||
CHECK_EQ(filter[0], groups_);
|
||||
// The input and output channel dimensions are the third and second
|
||||
// dimensions of the filter shape.
|
||||
CHECK_EQ(input[1], filter[2] * groups_);
|
||||
CHECK_EQ(output[1], filter[1] * groups_);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
size_t getFilterHeight(const TensorShape& filter) const {
|
||||
return filter[filter.ndims() - 2];
|
||||
}
|
||||
|
||||
size_t getFilterWidth(const TensorShape& filter) const {
|
||||
return filter[filter.ndims() - 1];
|
||||
}
|
||||
|
||||
std::vector<size_t> strides_;
|
||||
std::vector<size_t> paddings_;
|
||||
|
||||
/// Group size, refer to grouped convolution in
|
||||
/// Alex Krizhevsky's paper: when group=2, the first half of the
|
||||
/// filters are only connected to the first half of the input channels,
|
||||
/// and the second half only connected to the second half.
|
||||
size_t groups_;
|
||||
|
||||
inline int strideH() const { return strides_[0]; }
|
||||
|
||||
inline int strideW() const { return strides_[1]; }
|
||||
|
||||
inline int paddingH() const { return paddings_[0]; }
|
||||
|
||||
inline int paddingW() const { return paddings_[1]; }
|
||||
|
||||
// A temporary memory in convolution calculation.
|
||||
MemoryHandlePtr memory_;
|
||||
|
||||
template <DeviceType Device>
|
||||
void resizeBuffer(size_t newSize) {
|
||||
if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
|
||||
if (Device == DEVICE_TYPE_CPU) {
|
||||
memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
|
||||
} else {
|
||||
memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,210 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
#include "Function.h"
|
||||
#include "FunctionTest.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
enum TestType {
|
||||
kForwardTest = 0,
|
||||
kBackwardInputTest = 1,
|
||||
kBackwardFilterTest = 2,
|
||||
};
|
||||
|
||||
template <DeviceType DType1, DeviceType DType2>
|
||||
class ConvolutionTest {
|
||||
public:
|
||||
ConvolutionTest(const std::string& conv1,
|
||||
const std::string& conv2,
|
||||
TestType type,
|
||||
std::string algo = "auto") {
|
||||
for (size_t batchSize : {1, 32}) {
|
||||
for (size_t inputSize : {7, 14, 54}) {
|
||||
for (size_t filterSize : {1, 3, 5}) {
|
||||
for (size_t inputChannels : {3, 64}) {
|
||||
for (size_t outputChannels : {3, 64, 128}) {
|
||||
if (inputChannels < outputChannels) break;
|
||||
for (size_t stride : {1, 2}) {
|
||||
for (size_t padding : {0, 1}) {
|
||||
if (padding >= filterSize) break;
|
||||
size_t outputSize =
|
||||
(inputSize - filterSize + 2 * padding + stride) / stride;
|
||||
VLOG(3) << " batchSize=" << batchSize
|
||||
<< " inputChannels=" << inputChannels
|
||||
<< " inputHeight=" << inputSize
|
||||
<< " inputWidth=" << inputSize
|
||||
<< " outputChannels=" << outputChannels
|
||||
<< " filterHeight=" << filterSize
|
||||
<< " filterWidth=" << filterSize
|
||||
<< " outputHeight=" << outputSize
|
||||
<< " outputWidth=" << outputSize
|
||||
<< " stride=" << stride << " padding=" << padding;
|
||||
|
||||
std::vector<size_t> paddings = {padding, padding};
|
||||
std::vector<size_t> strides = {stride, stride};
|
||||
Compare2Function<DType1, DType2> test(
|
||||
conv1,
|
||||
conv2,
|
||||
FuncConfig()
|
||||
.set("paddings", paddings)
|
||||
.set("strides", strides)
|
||||
.set("groups", (size_t)1)
|
||||
.set("algo", algo));
|
||||
|
||||
TensorShape input{
|
||||
batchSize, inputChannels, inputSize, inputSize};
|
||||
TensorShape filter{
|
||||
outputChannels, inputChannels, filterSize, filterSize};
|
||||
TensorShape output{
|
||||
batchSize, outputChannels, outputSize, outputSize};
|
||||
|
||||
if (type == kForwardTest) {
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
|
||||
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
|
||||
test.run();
|
||||
} else if (type == kBackwardInputTest) {
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
|
||||
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
|
||||
test.run();
|
||||
} else if (type == kBackwardFilterTest) {
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
|
||||
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
|
||||
test.run();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Mainly used to test cases where the height and width (input, filter)
|
||||
// are not equal.
|
||||
template <DeviceType DType1, DeviceType DType2>
|
||||
class ConvolutionTest2 {
|
||||
public:
|
||||
ConvolutionTest2(const std::string& conv1,
|
||||
const std::string& conv2,
|
||||
TestType type,
|
||||
std::string algo = "auto") {
|
||||
for (size_t batchSize : {16}) {
|
||||
for (size_t inputHeight : {7, 31}) {
|
||||
for (size_t inputWidth : {10, 54}) {
|
||||
for (size_t filterHeight : {1, 5}) {
|
||||
for (size_t filterWidth : {3, 7}) {
|
||||
for (size_t inputChannels : {7}) {
|
||||
for (size_t outputChannels : {32}) {
|
||||
size_t stride = 1;
|
||||
size_t padding = 0;
|
||||
size_t outputHeight =
|
||||
(inputHeight - filterHeight + 2 * padding + stride) /
|
||||
stride;
|
||||
size_t outputWidth =
|
||||
(inputWidth - filterWidth + 2 * padding + stride) /
|
||||
stride;
|
||||
VLOG(3) << " batchSize=" << batchSize
|
||||
<< " inputChannels=" << inputChannels
|
||||
<< " inputHeight=" << inputHeight
|
||||
<< " inputWidth=" << inputWidth
|
||||
<< " outputChannels=" << outputChannels
|
||||
<< " filterHeight=" << filterHeight
|
||||
<< " filterWidth=" << filterWidth
|
||||
<< " outputHeight=" << outputHeight
|
||||
<< " outputWidth=" << outputWidth
|
||||
<< " stride=" << stride << " padding=" << padding;
|
||||
|
||||
std::vector<size_t> paddings = {padding, padding};
|
||||
std::vector<size_t> strides = {stride, stride};
|
||||
Compare2Function<DType1, DType2> test(
|
||||
conv1,
|
||||
conv2,
|
||||
FuncConfig()
|
||||
.set("paddings", paddings)
|
||||
.set("strides", strides)
|
||||
.set("groups", (size_t)1)
|
||||
.set("algo", algo));
|
||||
|
||||
TensorShape input{
|
||||
batchSize, inputChannels, inputHeight, inputWidth};
|
||||
TensorShape filter{
|
||||
outputChannels, inputChannels, filterHeight, filterWidth};
|
||||
TensorShape output{
|
||||
batchSize, outputChannels, outputHeight, outputWidth};
|
||||
|
||||
if (type == kForwardTest) {
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
|
||||
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
|
||||
test.run();
|
||||
} else if (type == kBackwardInputTest) {
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
|
||||
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
|
||||
test.run();
|
||||
} else if (type == kBackwardFilterTest) {
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
|
||||
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
|
||||
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
|
||||
test.run();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST(Forward, GEMM) {
|
||||
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
|
||||
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
|
||||
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test2(
|
||||
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
|
||||
}
|
||||
|
||||
#ifndef PADDLE_ONLY_CPU
|
||||
TEST(Forward, GEMM2) {
|
||||
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
|
||||
"GemmConv-CPU", "GemmConv-GPU", kForwardTest);
|
||||
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
|
||||
"GemmConv-CPU", "GemmConv-GPU", kForwardTest);
|
||||
}
|
||||
|
||||
TEST(BackwardInput, GEMM) {
|
||||
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
|
||||
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
|
||||
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
|
||||
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
|
||||
}
|
||||
|
||||
TEST(BackwardFilter, GEMM) {
|
||||
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
|
||||
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
|
||||
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
|
||||
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,62 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ConvOp.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
|
||||
* imData = [input_channels, input_height, input_width]
|
||||
* colData = [input_channels, filter_height, filter_width,
|
||||
* output_height, output_width]
|
||||
*/
|
||||
template <DeviceType Device, class T>
|
||||
class Im2ColFunctor {
|
||||
public:
|
||||
void operator()(const T* imData,
|
||||
int inputChannels,
|
||||
int inputHeight,
|
||||
int inputWidth,
|
||||
int filterHeight,
|
||||
int filterWidth,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth,
|
||||
int outputHeight,
|
||||
int outputWidth,
|
||||
T* colData);
|
||||
};
|
||||
|
||||
template <DeviceType Device, class T>
|
||||
class Col2ImFunctor {
|
||||
public:
|
||||
void operator()(const T* colData,
|
||||
int inputChannels,
|
||||
int inputHeight,
|
||||
int inputWidth,
|
||||
int filterHeight,
|
||||
int filterWidth,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth,
|
||||
int outputHeight,
|
||||
int outputWidth,
|
||||
T* imData);
|
||||
};
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,186 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "ConvOp.h"
|
||||
#include "GemmConvOp.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
template<class T>
|
||||
__global__
|
||||
void im2col(const T* data_im, int numOuts, int height, int width,
|
||||
int blockH, int blockW,
|
||||
int strideH, int strideW,
|
||||
int paddingH, int paddingW,
|
||||
int height_col, int width_col,
|
||||
T* data_col) {
|
||||
int index =
|
||||
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
|
||||
if (index < numOuts) {
|
||||
int w_out = index % width_col;
|
||||
index /= width_col;
|
||||
int h_out = index % height_col;
|
||||
int channel_in = index / height_col;
|
||||
int channel_out = channel_in * blockH * blockW;
|
||||
int h_in = h_out * strideH;
|
||||
int w_in = w_out * strideW;
|
||||
|
||||
data_col += (channel_out * height_col + h_out) * width_col + w_out;
|
||||
for (int i = 0; i < blockH; ++i) {
|
||||
for (int j = 0; j < blockW; ++j) {
|
||||
int rIdx = int(h_in+i);
|
||||
int cIdx = int(w_in+j);
|
||||
if ((rIdx-(int)paddingH) >= (int)height ||
|
||||
(rIdx-(int)paddingH) < 0 ||
|
||||
(cIdx-(int)paddingW) >= (int)width ||
|
||||
(cIdx-(int)paddingW) < 0) {
|
||||
*data_col = 0;
|
||||
} else {
|
||||
rIdx = rIdx + channel_in*height - paddingH;
|
||||
cIdx = cIdx - paddingW;
|
||||
*data_col = data_im[rIdx* width + cIdx];
|
||||
}
|
||||
data_col += height_col * width_col;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
class Im2ColFunctor<DEVICE_TYPE_GPU, T> {
|
||||
public:
|
||||
void operator()(const T* imData,
|
||||
int inputChannels,
|
||||
int inputHeight,
|
||||
int inputWidth,
|
||||
int filterHeight,
|
||||
int filterWidth,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth,
|
||||
int outputHeight,
|
||||
int outputWidth,
|
||||
T* colData) {
|
||||
int numKernels = inputChannels * outputHeight * outputWidth;
|
||||
int blocks = (numKernels + 1024 -1) / 1024;
|
||||
int blockX = 512;
|
||||
int blockY = (blocks + 512 - 1) / 512;
|
||||
dim3 threads(1024, 1);
|
||||
dim3 grid(blockX, blockY);
|
||||
im2col<T><<< grid, threads, 0, STREAM_DEFAULT >>>
|
||||
(imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth,
|
||||
strideHeight, strideWidth, paddingHeight, paddingWidth,
|
||||
outputHeight, outputWidth, colData);
|
||||
CHECK_SYNC("Im2ColFunctor GPU failed");
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
__global__
|
||||
void col2im(size_t n, const T* data_col, size_t height,
|
||||
size_t width, size_t channels,
|
||||
size_t blockH, size_t blockW,
|
||||
size_t strideH, size_t strideW,
|
||||
size_t paddingH, size_t paddingW,
|
||||
size_t height_col, size_t width_col,
|
||||
T* data_im) {
|
||||
size_t index =
|
||||
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
|
||||
if (index < n) {
|
||||
T val = 0;
|
||||
int w = int(index % width);
|
||||
int h = int((index / width) % height);
|
||||
int c = int(index / (width * height));
|
||||
if ((w - (int)paddingW) >= 0 &&
|
||||
(w - (int)paddingW) < (width-2 * paddingW) &&
|
||||
(h - (int)paddingH) >= 0 &&
|
||||
(h - paddingH) < (height - 2 * paddingH)) {
|
||||
// compute the start and end of the output
|
||||
int w_col_start =
|
||||
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
|
||||
int w_col_end =
|
||||
min((int)(w / (int)strideW + 1), (int)(width_col));
|
||||
int h_col_start =
|
||||
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
|
||||
int h_col_end = min(int(h / strideH + 1), int(height_col));
|
||||
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
|
||||
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
|
||||
// the col location: [c * width * height + h_out, w_out]
|
||||
int c_col = int(c * blockH* blockW) + \
|
||||
(h - h_col * (int)strideH) * (int)blockW +
|
||||
(w - w_col * (int)strideW);
|
||||
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
|
||||
}
|
||||
}
|
||||
h -= paddingH;
|
||||
w -= paddingW;
|
||||
data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
|
||||
h*(width-2*paddingW) + w] += val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
class Col2ImFunctor<DEVICE_TYPE_GPU, T> {
|
||||
public:
|
||||
void operator()(const T* colData,
|
||||
int inputChannels,
|
||||
int inputHeight,
|
||||
int inputWidth,
|
||||
int filterHeight,
|
||||
int filterWidth,
|
||||
int strideHeight,
|
||||
int strideWidth,
|
||||
int paddingHeight,
|
||||
int paddingWidth,
|
||||
int outputHeight,
|
||||
int outputWidth,
|
||||
T* imData) {
|
||||
size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight)
|
||||
* (inputWidth + 2*paddingWidth);
|
||||
|
||||
size_t blocks = (numKernels + 1024 -1) / 1024;
|
||||
size_t blockX = 512;
|
||||
size_t blockY = (blocks+512-1)/512;
|
||||
dim3 threads(1024, 1);
|
||||
dim3 grid(blockX, blockY);
|
||||
|
||||
// To avoid involving atomic operations, we will launch one kernel per
|
||||
// bottom dimension, and then in the kernel add up the top dimensions.
|
||||
col2im<T><<< grid, threads, 0, STREAM_DEFAULT >>>
|
||||
(numKernels,
|
||||
colData,
|
||||
inputHeight + 2*paddingHeight,
|
||||
inputWidth + 2*paddingWidth,
|
||||
inputChannels,
|
||||
filterHeight,
|
||||
filterWidth,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
paddingHeight,
|
||||
paddingWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
imData);
|
||||
CHECK_SYNC("Col2ImFunctor GPU failed");
|
||||
}
|
||||
};
|
||||
|
||||
template class Im2ColFunctor<DEVICE_TYPE_GPU, float>;
|
||||
template class Im2ColFunctor<DEVICE_TYPE_GPU, double>;
|
||||
template class Col2ImFunctor<DEVICE_TYPE_GPU, float>;
|
||||
template class Col2ImFunctor<DEVICE_TYPE_GPU, double>;
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,96 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/math/MathFunctions.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
// TODO(hedaoyuan): Since the hl_matrix_mul interface does not conform to the
|
||||
// cblas_dgemm interface's parameter format, it is necessary to introduce
|
||||
// GemmFunctor as a new interface. Later, when considering the implementation
|
||||
// of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul
|
||||
// interface.
|
||||
template <DeviceType Device, class T>
|
||||
class GemmFunctor {
|
||||
public:
|
||||
void operator()(const CBLAS_TRANSPOSE transA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const T alpha,
|
||||
const T* A,
|
||||
const int lda,
|
||||
const T* B,
|
||||
const int ldb,
|
||||
const T beta,
|
||||
T* C,
|
||||
const int ldc);
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class GemmFunctor<DEVICE_TYPE_CPU, T> {
|
||||
public:
|
||||
void operator()(const CBLAS_TRANSPOSE transA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const T alpha,
|
||||
const T* A,
|
||||
const int lda,
|
||||
const T* B,
|
||||
const int ldb,
|
||||
const T beta,
|
||||
T* C,
|
||||
const int ldc) {
|
||||
gemm<T>(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class GemmFunctor<DEVICE_TYPE_GPU, T> {
|
||||
public:
|
||||
void operator()(const CBLAS_TRANSPOSE transA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const T alpha,
|
||||
const T* A,
|
||||
const int lda,
|
||||
const T* B,
|
||||
const int ldb,
|
||||
const T beta,
|
||||
T* C,
|
||||
const int ldc) {
|
||||
hl_matrix_mul((T*)A,
|
||||
transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
|
||||
(T*)B,
|
||||
TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
|
||||
C,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
beta,
|
||||
lda,
|
||||
ldb,
|
||||
ldc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,137 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "ConvOp.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
|
||||
* The three arguments are stored in memory in row major order.
|
||||
* inputData = [batchSize, inputChannels, inputHeight, inputWidth]
|
||||
* filterData = [outputChannels, inputChannels, filterHeight, filterWidth]
|
||||
* outputData = [batchSize, outputChannels, outputHeight, outputWidth]
|
||||
*/
|
||||
template <class T>
|
||||
class NaiveConvFunctor {
|
||||
public:
|
||||
void operator()(const T* inputData,
|
||||
size_t batchSize,
|
||||
size_t inputChannels,
|
||||
size_t inputHeight,
|
||||
size_t inputWidth,
|
||||
const T* filterData,
|
||||
size_t filterHeight,
|
||||
size_t filterWidth,
|
||||
T* outputData,
|
||||
size_t outputChannels,
|
||||
size_t outputHeight,
|
||||
size_t outputWidth,
|
||||
size_t paddingH,
|
||||
size_t paddingW,
|
||||
size_t strideH,
|
||||
size_t strideW) {
|
||||
for (size_t batch = 0; batch < batchSize; batch++) {
|
||||
for (size_t outC = 0; outC < outputChannels; outC++) {
|
||||
for (size_t outH = 0; outH < outputHeight; outH++) {
|
||||
for (size_t outW = 0; outW < outputWidth; outW++) {
|
||||
const int inStartH = (outH * strideH) - paddingH;
|
||||
const int inStartW = (outW * strideW) - paddingW;
|
||||
T outValue = (T)0;
|
||||
for (size_t inC = 0; inC < inputChannels; inC++) {
|
||||
for (size_t fH = 0; fH < filterHeight; fH++) {
|
||||
for (size_t fW = 0; fW < filterWidth; fW++) {
|
||||
T inValue;
|
||||
const int inH = inStartH + fH;
|
||||
const int inW = inStartW + fW;
|
||||
if ((inH >= 0 && inH < inputHeight) &&
|
||||
(inW >= 0 && inW < inputWidth)) {
|
||||
size_t offsetInput =
|
||||
batch * inputChannels * inputHeight * inputWidth +
|
||||
inC * inputHeight * inputWidth + inH * inputWidth + inW;
|
||||
inValue = inputData[offsetInput];
|
||||
} else {
|
||||
inValue = (T)0;
|
||||
}
|
||||
size_t offsetFilter =
|
||||
outC * inputChannels * filterHeight * filterWidth +
|
||||
inC * filterHeight * filterWidth + fH * filterWidth + fW;
|
||||
T filterValue = filterData[offsetFilter];
|
||||
outValue += (inValue * filterValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t offset =
|
||||
batch * outputChannels * outputHeight * outputWidth +
|
||||
outC * outputHeight * outputWidth + outH * outputWidth + outW;
|
||||
outputData[offset] = outValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <DeviceType Device>
|
||||
class NaiveConvFunction : public ConvFunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {
|
||||
ConvFunctionBase::init(config);
|
||||
}
|
||||
|
||||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
CHECK_EQ(numInputs_, inputs.size());
|
||||
CHECK_EQ(numOutputs_, outputs.size());
|
||||
const TensorShape& input = inputs[0].shape();
|
||||
const TensorShape& filter = inputs[1].shape();
|
||||
const TensorShape& output = outputs[0].shape();
|
||||
check(input, filter, output);
|
||||
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
||||
|
||||
size_t batchSize = inputs[0].shape()[0];
|
||||
size_t inputChannels = inputs[0].shape()[1];
|
||||
size_t inputHeight = inputs[0].shape()[2];
|
||||
size_t inputWidth = inputs[0].shape()[3];
|
||||
size_t filterHeight = inputs[1].shape()[2];
|
||||
size_t filterWidth = inputs[1].shape()[3];
|
||||
size_t outputChannels = outputs[0].shape()[1];
|
||||
size_t outputHeight = outputs[0].shape()[2];
|
||||
size_t outputWidth = outputs[0].shape()[3];
|
||||
|
||||
real* inputData = inputs[0].data<real>();
|
||||
real* filterData = inputs[1].data<real>();
|
||||
real* outputData = outputs[0].data<real>();
|
||||
NaiveConvFunctor<real> conv;
|
||||
conv(inputData,
|
||||
batchSize,
|
||||
inputChannels,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
filterData,
|
||||
filterHeight,
|
||||
filterWidth,
|
||||
outputData,
|
||||
outputChannels,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
paddingH(),
|
||||
paddingW(),
|
||||
strideH(),
|
||||
strideW());
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_TYPED_FUNC(NaiveConv, CPU, NaiveConvFunction);
|
||||
|
||||
} // namespace paddle
|
@ -1,90 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "ExpandConvTransLayer.h"
|
||||
#include "paddle/utils/Logging.h"
|
||||
#include "paddle/utils/Stat.h"
|
||||
|
||||
/* The implementation of the convTransLayer is basically a swap of forward and
|
||||
* backward of the original convLayer.
|
||||
* The variable naming follows the convention of the convLayer.
|
||||
* */
|
||||
|
||||
namespace paddle {
|
||||
|
||||
REGISTER_LAYER(exconvt, ExpandConvTransLayer);
|
||||
|
||||
bool ExpandConvTransLayer::init(const LayerMap &layerMap,
|
||||
const ParameterMap ¶meterMap) {
|
||||
/* Initialize the basic convolutional parent class */
|
||||
ExpandConvBaseLayer::init(layerMap, parameterMap);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ExpandConvTransLayer::forward(PassType passType) {
|
||||
Layer::forward(passType);
|
||||
|
||||
/* malloc memory for the output_ if necessary */
|
||||
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
|
||||
resetOutput(batchSize, getOutputSize());
|
||||
|
||||
MatrixPtr output = nullptr;
|
||||
for (size_t i = 0; i < inputLayers_.size(); ++i) {
|
||||
LayerPtr prevLayer = getPrev(i);
|
||||
output = prevLayer->getOutputValue();
|
||||
REGISTER_TIMER_INFO("shrinkFwd", getName().c_str());
|
||||
bpropActs(output, getOutputValue(), i);
|
||||
}
|
||||
|
||||
/* add the bias-vector */
|
||||
if (biases_.get()) {
|
||||
if (sharedBiases_) {
|
||||
addSharedBias();
|
||||
} else {
|
||||
addUnsharedBias();
|
||||
}
|
||||
}
|
||||
|
||||
/* activation */
|
||||
forwardActivation();
|
||||
}
|
||||
|
||||
void ExpandConvTransLayer::backward(const UpdateCallback &callback) {
|
||||
backwardActivation();
|
||||
|
||||
MatrixPtr imageGrad = getOutputGrad();
|
||||
if (biases_ && biases_->getWGrad()) {
|
||||
bpropBiases(imageGrad);
|
||||
/* Increasing the number of gradient */
|
||||
biases_->getParameterPtr()->incUpdate(callback);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < inputLayers_.size(); ++i) {
|
||||
/* First, calculate the input layers error */
|
||||
for (size_t off = 0; off < imageGrad->getHeight(); off++) {
|
||||
if (getPrev(i)->getOutputGrad()) {
|
||||
expandFwdOnce(imageGrad, getPrev(i)->getOutputGrad(), i, off);
|
||||
}
|
||||
}
|
||||
if (weights_[i]->getWGrad()) {
|
||||
/* Then, calculate the W-gradient for the current layer */
|
||||
bpropWeights(imageGrad, getPrev(i)->getOutputValue(), i);
|
||||
/* Increasing the number of gradient */
|
||||
weights_[i]->getParameterPtr()->incUpdate(callback);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace paddle
|
@ -1,44 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "ExpandConvBaseLayer.h"
|
||||
#include "paddle/math/Matrix.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/**
|
||||
* @brief A subclass of convolution layer.
|
||||
* This layer expands input and use matrix multiplication to
|
||||
* calculate convolution transpose (deconv) operation.
|
||||
*
|
||||
* The config file api is img_conv_layer with flag trans=True.
|
||||
*/
|
||||
class ExpandConvTransLayer : public ExpandConvBaseLayer {
|
||||
public:
|
||||
explicit ExpandConvTransLayer(const LayerConfig& config)
|
||||
: ExpandConvBaseLayer(config) {}
|
||||
|
||||
~ExpandConvTransLayer() {}
|
||||
|
||||
bool init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) override;
|
||||
|
||||
void forward(PassType passType) override;
|
||||
void backward(const UpdateCallback& callback) override;
|
||||
};
|
||||
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue