1. Add switch function for switching image dimensions order 2. Add CpuMatrix::backwardSoftmax function 3. Add pixel softmax layer, python wrapper and grad_testAdaptive_data_structure_for_SwitchOrderLayer
parent
9837896827
commit
29f25fbe03
@ -0,0 +1,132 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "SwitchOp.h"
|
||||||
|
#include "paddle/math/Vector.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
|
||||||
|
const real* inputs,
|
||||||
|
const int num,
|
||||||
|
const int inC,
|
||||||
|
const int inH,
|
||||||
|
const int inW) {
|
||||||
|
for (int n = 0; n < num; ++n) {
|
||||||
|
for (int c = 0; c < inC; ++c) {
|
||||||
|
for (int h = 0; h < inH; ++h) {
|
||||||
|
for (int w = 0; w < inW; ++w) {
|
||||||
|
outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
|
||||||
|
const real* inputs,
|
||||||
|
const int num,
|
||||||
|
const int inH,
|
||||||
|
const int inW,
|
||||||
|
const int inC) {
|
||||||
|
for (int n = 0; n < num; ++n) {
|
||||||
|
for (int h = 0; h < inH; ++h) {
|
||||||
|
for (int w = 0; w < inW; ++w) {
|
||||||
|
for (int c = 0; c < inC; ++c) {
|
||||||
|
outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Padding zeros to input according to the specify dimension.
|
||||||
|
* The struct pad_ contains the padding size in each dimension.
|
||||||
|
* The input and output is a 4D tensor. In PadFunc, we only
|
||||||
|
* pad zeros to the 2nd to 4th dimension.
|
||||||
|
*
|
||||||
|
* Argument in this Function:
|
||||||
|
* \param pad_ A struct object contains the padding size in each dimension.
|
||||||
|
* It has six integers. The channelStart and channelEnd indicate
|
||||||
|
* how many zeros to add before and after the input in channel
|
||||||
|
* dimension. And the heightStart and heightEnd indicate padding
|
||||||
|
* in height dimension. The widthStart and widthEnd indicate the
|
||||||
|
* padding in width dimension.
|
||||||
|
* \param inputs A 4D tensor, only one input.
|
||||||
|
* \param outputs A 4D tensor, the output value after padding.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <DeviceType Device>
|
||||||
|
class NCHW2NHWCFunc : public FunctionBase {
|
||||||
|
public:
|
||||||
|
void init(const FuncConfig& config) override {}
|
||||||
|
|
||||||
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||||
|
CHECK_EQ(1UL, inputs.size());
|
||||||
|
CHECK_EQ(1UL, outputs.size());
|
||||||
|
|
||||||
|
size_t num = inputs[0].shape()[0];
|
||||||
|
size_t inC = inputs[0].shape()[1];
|
||||||
|
size_t inH = inputs[0].shape()[2];
|
||||||
|
size_t inW = inputs[0].shape()[3];
|
||||||
|
typename Tensor<real, Device>::Vector vec(outputs[0].shape().getElements(),
|
||||||
|
outputs[0].data<real>());
|
||||||
|
vec.zero();
|
||||||
|
|
||||||
|
NCHW2NHWC<Device>(
|
||||||
|
outputs[0].data<real>(), inputs[0].data<real>(), num, inC, inH, inW);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief The backward propagation of padding Function. Remove the elements
|
||||||
|
* in the padding positions of forward.
|
||||||
|
*
|
||||||
|
* Argument in this Function:
|
||||||
|
* \param pad_ The same meaning as it in PadFunc.
|
||||||
|
* \param inputs The gradient with respect to the output value of PadFunc.
|
||||||
|
* \param outputs The gradient with respect to the input value of PadFunc.
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <DeviceType Device>
|
||||||
|
class NHWC2NCHWFunc : public FunctionBase {
|
||||||
|
public:
|
||||||
|
void init(const FuncConfig& config) override {}
|
||||||
|
|
||||||
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||||
|
CHECK_EQ(1UL, inputs.size());
|
||||||
|
CHECK_EQ(1UL, outputs.size());
|
||||||
|
|
||||||
|
size_t num = inputs[0].shape()[0];
|
||||||
|
size_t inH = inputs[0].shape()[1];
|
||||||
|
size_t inW = inputs[0].shape()[2];
|
||||||
|
size_t inC = inputs[0].shape()[3];
|
||||||
|
|
||||||
|
NHWC2NCHW<Device>(
|
||||||
|
outputs[0].data<real>(), inputs[0].data<real>(), num, inH, inW, inC);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register the CPU variants of both layout-switch functions.
REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);

// The GPU variants are registered only when PADDLE_ONLY_CPU is not defined.
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
#endif
|
||||||
|
|
||||||
|
} // namespace paddle
|
||||||
@ -0,0 +1,62 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Function.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief This funtion switch dimension order of image input.
|
||||||
|
* The input and output is a 4D tensor. Switch order 'batch_size,
|
||||||
|
*channels, height, width' to
|
||||||
|
* order 'batch_size, height, width, channels'.
|
||||||
|
*
|
||||||
|
* \param[out] outputs save results.
|
||||||
|
* \param[in] inputs input data.
|
||||||
|
* \param[in] num batch size of input data.
|
||||||
|
* \param[in] inC channel number of input data.
|
||||||
|
* \param[in] inH height of input data.
|
||||||
|
* \param[in] inH with of input data.
|
||||||
|
*/
|
||||||
|
template <DeviceType Device>
|
||||||
|
void NCHW2NHWC(real* outputs,
|
||||||
|
const real* inputs,
|
||||||
|
const int num,
|
||||||
|
const int inC,
|
||||||
|
const int inH,
|
||||||
|
const int inW);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief This funtion switch dimension order of image input.
|
||||||
|
* The input and output is a 4D tensor. Switch order 'batch_size,
|
||||||
|
*height, width, channels' to
|
||||||
|
* order 'batch_size, channels, height, width'.
|
||||||
|
*
|
||||||
|
* \param[out] inGrad gradients of previous layer.
|
||||||
|
* \param[in] outGrad output gradients.
|
||||||
|
* \param[in] num batch size of input data.
|
||||||
|
* \param[in] inH height of input data.
|
||||||
|
* \param[in] inW with of input data.
|
||||||
|
* \param[in] inC channel number of input data.
|
||||||
|
*/
|
||||||
|
template <DeviceType Device>
|
||||||
|
void NHWC2NCHW(real* inGrad,
|
||||||
|
const real* outGrad,
|
||||||
|
const int num,
|
||||||
|
const int inH,
|
||||||
|
const int inW,
|
||||||
|
const int inC);
|
||||||
|
} // namespace paddle
|
||||||
@ -0,0 +1,80 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "hl_base.h"
|
||||||
|
#include "SwitchOp.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
/**
 * \brief Kernel: each thread copies one element from NCHW order to NHWC
 *        order.
 *
 * \param[out] outputs   destination buffer.
 * \param[in]  inputs    source buffer.
 * \param[in]  inC       channel number.
 * \param[in]  inH       height.
 * \param[in]  inW       width.
 * \param[in]  nthreads  number of elements to copy; threads whose index is
 *                       not below this value do nothing.
 */
__global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
                            int inC, int inH, int inW,
                            int nthreads) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= nthreads) {
    return;
  }

  // Peel the (n, c, h, w) coordinates off the linear source index,
  // fastest-varying dimension first.
  int rest = idx;
  const int w = rest % inW;
  rest /= inW;
  const int h = rest % inH;
  rest /= inH;
  const int c = rest % inC;
  const int n = rest / inC;

  outputs[((n * inH + h) * inW + w) * inC + c] = inputs[idx];
}
|
||||||
|
|
||||||
|
/**
 * \brief GPU implementation of the NCHW -> NHWC reorder.
 *
 * Launches KeNCHW2NHWC with one thread per element on STREAM_DEFAULT.
 * An empty tensor returns without launching anything.
 */
template <>
void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inC,
                                const int inH,
                                const int inW) {
  // The kernel takes its element count and computes its indices as int, so
  // the count is kept as int here instead of being narrowed at the call.
  const int nth = num * inC * inH * inW;
  if (nth <= 0) {
    // Nothing to copy; a grid with zero blocks is not a valid launch.
    return;
  }

  const int blockSize = 1024;
  // Ceiling division written so that it cannot overflow for large nth.
  const int gridSize = (nth - 1) / blockSize + 1;
  KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
      (outputs, inputs, inC, inH, inW, nth);
  CHECK_SYNC("NCHW2NHWC");
}
|
||||||
|
|
||||||
|
/**
 * \brief Kernel: each thread copies one element from NHWC order to NCHW
 *        order.
 *
 * \param[out] outputs   destination buffer.
 * \param[in]  inputs    source buffer.
 * \param[in]  inH       height.
 * \param[in]  inW       width.
 * \param[in]  inC       channel number.
 * \param[in]  nthreads  number of elements to copy; threads whose index is
 *                       not below this value do nothing.
 */
__global__ void KeNHWC2NCHW(real* outputs, const real* inputs,
                            int inH, int inW, int inC,
                            int nthreads) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= nthreads) {
    return;
  }

  // Peel the (n, h, w, c) coordinates off the linear source index,
  // fastest-varying dimension first.
  int rest = idx;
  const int c = rest % inC;
  rest /= inC;
  const int w = rest % inW;
  rest /= inW;
  const int h = rest % inH;
  const int n = rest / inH;

  outputs[((n * inC + c) * inH + h) * inW + w] = inputs[idx];
}
|
||||||
|
|
||||||
|
/**
 * \brief GPU implementation of the NHWC -> NCHW reorder.
 *
 * Launches KeNHWC2NCHW with one thread per element on STREAM_DEFAULT.
 * An empty tensor returns without launching anything.
 */
template <>
void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inH,
                                const int inW,
                                const int inC) {
  const int nth = num * inC * inH * inW;
  if (nth <= 0) {
    // Nothing to copy; a grid with zero blocks is not a valid launch.
    return;
  }

  const int blockSize = 1024;
  // Ceiling division written so that it cannot overflow for large nth
  // (the former nth + 1024 - 1 could exceed the int range).
  const int gridSize = (nth - 1) / blockSize + 1;
  KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
      (outputs, inputs, inH, inW, inC, nth);
  CHECK_SYNC("NHWC2NCHW");
}
|
||||||
|
|
||||||
|
} // namespace paddle
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "FunctionTest.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
// Compares the CPU and GPU implementations of "NCHW2NHWC" and "NHWC2NCHW"
// over a grid of batch sizes, channel counts and image sizes.
// NOTE(review): the suite name "Pad" looks copied from the padding test;
// consider renaming once nothing filters on it.
TEST(Pad, real) {
  const size_t kSizes[] = {1, 4, 8, 16};
  const bool kGradModes[] = {true, false};

  for (size_t numSamples : kSizes) {
    for (size_t channels : kSizes) {
      for (size_t imgSizeH : kSizes) {
        for (size_t imgSizeW : kSizes) {
          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;

          for (bool testGrad : kGradModes) {
            TensorShape nchwDims{numSamples, channels, imgSizeH, imgSizeW};
            TensorShape nhwcDims{numSamples, imgSizeH, imgSizeW, channels};

            // testGrad selects the NHWC -> NCHW direction; otherwise the
            // NCHW -> NHWC direction is exercised.
            const char* funcName = testGrad ? "NHWC2NCHW" : "NCHW2NHWC";
            TensorShape& srcDims = testGrad ? nhwcDims : nchwDims;
            TensorShape& dstDims = testGrad ? nchwDims : nhwcDims;

            CpuGpuFuncCompare compare(funcName, FuncConfig());
            compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, srcDims));
            compare.addOutputs(
                BufferArg(VALUE_TYPE_FLOAT, dstDims, ASSIGN_TO));
            compare.run();
          }
        }
      }
    }
  }
}
|
||||||
|
|
||||||
|
} // namespace paddle
|
||||||
@ -0,0 +1,89 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "PixelSoftmaxLayer.h"
|
||||||
|
#include "paddle/utils/Stat.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
REGISTER_LAYER(pixel_softmax, PixelSoftmaxLayer);
|
||||||
|
|
||||||
|
bool PixelSoftmaxLayer::init(const LayerMap& layerMap,
|
||||||
|
const ParameterMap& parameterMap) {
|
||||||
|
/* Initialize the basic parent class */
|
||||||
|
Layer::init(layerMap, parameterMap);
|
||||||
|
auto& img_conf = config_.inputs(0).image_conf();
|
||||||
|
inH_ =
|
||||||
|
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
|
||||||
|
inW_ = img_conf.img_size();
|
||||||
|
inC_ = img_conf.channels();
|
||||||
|
createFunction(forward_, "NCHW2NHWC", FuncConfig());
|
||||||
|
createFunction(backward_, "NHWC2NCHW", FuncConfig());
|
||||||
|
inDims_ = TensorShape({0, inH_, inW_, inC_});
|
||||||
|
outDims_ = TensorShape({0, inC_, inH_, inW_});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * \brief Forward pass: reorder the input with forward_[0], apply softmax to
 *        the reordered matrix, and reorder the result back with backward_[0]
 *        into the layer output.
 */
void PixelSoftmaxLayer::forward(PassType passType) {
  Layer::forward(passType);

  const size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight();

  // Both scratch matrices have one row per pixel and one column per channel.
  const size_t pixelRows = batchSize * inH_ * inW_;
  Matrix::resizeOrCreate(tmpInput_, pixelRows, inC_, false, useGpu_);
  Matrix::resizeOrCreate(tmpOutput_, pixelRows, inC_, false, useGpu_);
  tmpOutput_->zeroMem();
  resetOutput(batchSize, inH_ * inW_ * inC_);

  // Only the batch dimension changes between calls.
  inDims_.setDim(0, batchSize);
  outDims_.setDim(0, batchSize);

  // Step 1: reorder the input value into tmpInput_.
  BufferArgs reorderIn;
  BufferArgs reorderOut;
  reorderIn.addArg(*getInputValue(0), inDims_);
  reorderOut.addArg(*tmpInput_, outDims_);
  forward_[0]->calc(reorderIn, reorderOut);

  // Step 2: softmax of tmpInput_, with the result stored in tmpOutput_.
  tmpInput_->softmax(*tmpOutput_);

  // Step 3: reorder tmpOutput_ back into the layer output.
  BufferArgs restoreIn;
  BufferArgs restoreOut;
  restoreIn.addArg(*tmpOutput_, outDims_);
  restoreOut.addArg(*getOutputValue(), inDims_);
  backward_[0]->calc(restoreIn, restoreOut);
}
|
||||||
|
|
||||||
|
void PixelSoftmaxLayer::backward(const UpdateCallback& callback) {
|
||||||
|
(void)callback;
|
||||||
|
REGISTER_TIMER_INFO("PixelSoftmaxBackward", getName().c_str());
|
||||||
|
|
||||||
|
// switch NCHW to NHWC
|
||||||
|
BufferArgs inputs;
|
||||||
|
BufferArgs outputs;
|
||||||
|
inputs.addArg(*getOutputGrad(), inDims_);
|
||||||
|
outputs.addArg(*tmpInput_, outDims_);
|
||||||
|
forward_[0]->calc(inputs, outputs);
|
||||||
|
// softmax backward and save grad result into tmpOutput_
|
||||||
|
tmpInput_->softmaxBackward(*tmpOutput_);
|
||||||
|
|
||||||
|
// switch NHWC to NCHW
|
||||||
|
BufferArgs inputs_1;
|
||||||
|
BufferArgs outputs_1;
|
||||||
|
inputs_1.addArg(*tmpInput_, outDims_);
|
||||||
|
outputs_1.addArg(*getInputGrad(0), inDims_);
|
||||||
|
backward_[0]->calc(inputs_1, outputs_1);
|
||||||
|
}
|
||||||
|
} // namespace paddle
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Layer.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief This layer calculate softmax in image channel dimension.
|
||||||
|
*/
|
||||||
|
class PixelSoftmaxLayer : public Layer {
|
||||||
|
public:
|
||||||
|
explicit PixelSoftmaxLayer(const LayerConfig& config) : Layer(config) {}
|
||||||
|
|
||||||
|
~PixelSoftmaxLayer() {}
|
||||||
|
|
||||||
|
bool init(const LayerMap& layerMap,
|
||||||
|
const ParameterMap& parameterMap) override;
|
||||||
|
void forward(PassType passType) override;
|
||||||
|
void backward(const UpdateCallback& callback = nullptr) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
uint32_t inC_;
|
||||||
|
uint32_t inH_;
|
||||||
|
uint32_t inW_;
|
||||||
|
TensorShape inDims_;
|
||||||
|
TensorShape outDims_;
|
||||||
|
MatrixPtr tmpInput_;
|
||||||
|
MatrixPtr tmpOutput_;
|
||||||
|
};
|
||||||
|
} // namespace paddle
|
||||||
Loading…
Reference in new issue