parent
371147003c
commit
3b65bc7a26
@ -0,0 +1,128 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "ConvFunc.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
|
||||
* The three arguments are stored in memory in row major order.
|
||||
* inputData = [batchSize, inputChannels, inputHeight, inputWidth]
|
||||
* filterData = [outputChannels, inputChannels, filterHeight, filterWidth]
|
||||
* outputData = [batchSize, outputChannels, outputHeight, outputWidth]
|
||||
*/
|
||||
template <class T>
|
||||
class NaiveConvFunctor {
|
||||
public:
|
||||
void operator()(const T* inputData,
|
||||
size_t batchSize,
|
||||
size_t inputChannels,
|
||||
size_t inputHeight,
|
||||
size_t inputWidth,
|
||||
const T* filterData,
|
||||
size_t filterHeight,
|
||||
size_t filterWidth,
|
||||
T* outputData,
|
||||
size_t outputChannels,
|
||||
size_t outputHeight,
|
||||
size_t outputWidth,
|
||||
size_t padding,
|
||||
size_t stride) {
|
||||
for (size_t batch = 0; batch < batchSize; batch++) {
|
||||
for (size_t outC = 0; outC < outputChannels; outC++) {
|
||||
for (size_t outH = 0; outH < outputHeight; outH++) {
|
||||
for (size_t outW = 0; outW < outputWidth; outW++) {
|
||||
const int inStartH = (outH * stride) - padding;
|
||||
const int inStartW = (outW * stride) - padding;
|
||||
T outValue = (T)0;
|
||||
for (size_t inC = 0; inC < inputChannels; inC++) {
|
||||
for (size_t fH = 0; fH < filterHeight; fH++) {
|
||||
for (size_t fW = 0; fW < filterWidth; fW++) {
|
||||
T inValue;
|
||||
const int inH = inStartH + fH;
|
||||
const int inW = inStartW + fW;
|
||||
if ((inH >= 0 && inH < inputHeight) &&
|
||||
(inW >= 0 && inW < inputWidth)) {
|
||||
size_t offsetInput =
|
||||
batch * inputChannels * inputHeight * inputWidth +
|
||||
inC * inputHeight * inputWidth + inH * inputWidth + inW;
|
||||
inValue = inputData[offsetInput];
|
||||
} else {
|
||||
inValue = (T)0;
|
||||
}
|
||||
size_t offsetFilter =
|
||||
outC * inputChannels * filterHeight * filterWidth +
|
||||
inC * filterHeight * filterWidth + fH * filterWidth + fW;
|
||||
T filterValue = filterData[offsetFilter];
|
||||
outValue += (inValue * filterValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t offset =
|
||||
batch * outputChannels * outputHeight * outputWidth +
|
||||
outC * outputHeight * outputWidth + outH * outputWidth + outW;
|
||||
outputData[offset] = outValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <DeviceType Device>
|
||||
class NaiveConvFunction : public ConvFunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {
|
||||
ConvFunctionBase::init(config);
|
||||
}
|
||||
|
||||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
check(inputs, outputs);
|
||||
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
||||
|
||||
size_t batchSize = inputs[0].shape()[0];
|
||||
size_t inputChannels = inputs[0].shape()[1];
|
||||
size_t inputHeight = inputs[0].shape()[2];
|
||||
size_t inputWidth = inputs[0].shape()[3];
|
||||
size_t filterHeight = inputs[1].shape()[2];
|
||||
size_t filterWidth = inputs[1].shape()[2];
|
||||
size_t outputChannels = outputs[0].shape()[1];
|
||||
size_t outputHeight = outputs[0].shape()[2];
|
||||
size_t outputWidth = outputs[0].shape()[3];
|
||||
|
||||
float* inputData = inputs[0].data<float>();
|
||||
float* filterData = inputs[1].data<float>();
|
||||
float* outputData = outputs[0].data<float>();
|
||||
NaiveConvFunctor<float> conv;
|
||||
conv(inputData,
|
||||
batchSize,
|
||||
inputChannels,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
filterData,
|
||||
filterHeight,
|
||||
filterWidth,
|
||||
outputData,
|
||||
outputChannels,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
padding_,
|
||||
stride_);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_TYPED_FUNC(NaiveConv, CPU, NaiveConvFunction);
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,67 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "Function.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/*
|
||||
* Function Arguments:
|
||||
*
|
||||
* \param inputs[0] Input image data, is NCHW format, where N is batch size,
|
||||
* C is the number of channels, H and W is the height and
|
||||
* width of input image.
|
||||
* \param inputs[1] Filter data, is MCHW, where M is the number of output
|
||||
* channels, C is the number of input channels, H and W
|
||||
* is height and width of filter.
|
||||
* \param outputs[0] Output image data, is NCHW format, where N is batch size,
|
||||
* C is the number of channels, H and W is the height and
|
||||
* width of output image.
|
||||
*
|
||||
* \note Implemented based on the ConvFunctionBase class only supports
|
||||
* input data in the NCHW format.
|
||||
*/
|
||||
class ConvFunctionBase : public FunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {
|
||||
// function arguments
|
||||
stride_ = config.get<size_t>("stride");
|
||||
padding_ = config.get<size_t>("padding");
|
||||
|
||||
// number of inputs and outputs
|
||||
numInputs_ = 2;
|
||||
numOutputs_ = 1;
|
||||
}
|
||||
|
||||
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
|
||||
|
||||
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
CHECK_EQ(numInputs_, inputs.size());
|
||||
CHECK_EQ(numOutputs_, outputs.size());
|
||||
|
||||
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
|
||||
CHECK_EQ(inputs[1].shape().ndims(), (size_t)4);
|
||||
CHECK_EQ(outputs[0].shape().ndims(), (size_t)4);
|
||||
|
||||
CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]);
|
||||
CHECK(inputs[0].shape()[1] == inputs[1].shape()[1]);
|
||||
CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]);
|
||||
}
|
||||
|
||||
protected:
|
||||
size_t padding_;
|
||||
size_t stride_;
|
||||
};
|
||||
|
||||
} // namespace paddle
|
Loading…
Reference in new issue