1. Add switch function for switching image dimensions order 2. Add CpuMatrix::backwardSoftmax function 3. Add pixel softmax layer, python wrapper and grad_testAdaptive_data_structure_for_SwitchOrderLayer
parent
9837896827
commit
29f25fbe03
@ -0,0 +1,132 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "SwitchOp.h"
|
||||
#include "paddle/math/Vector.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
template <>
|
||||
void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
|
||||
const real* inputs,
|
||||
const int num,
|
||||
const int inC,
|
||||
const int inH,
|
||||
const int inW) {
|
||||
for (int n = 0; n < num; ++n) {
|
||||
for (int c = 0; c < inC; ++c) {
|
||||
for (int h = 0; h < inH; ++h) {
|
||||
for (int w = 0; w < inW; ++w) {
|
||||
outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
|
||||
const real* inputs,
|
||||
const int num,
|
||||
const int inH,
|
||||
const int inW,
|
||||
const int inC) {
|
||||
for (int n = 0; n < num; ++n) {
|
||||
for (int h = 0; h < inH; ++h) {
|
||||
for (int w = 0; w < inW; ++w) {
|
||||
for (int c = 0; c < inC; ++c) {
|
||||
outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Padding zeros to input according to the specify dimension.
|
||||
* The struct pad_ contains the padding size in each dimension.
|
||||
* The input and output is a 4D tensor. In PadFunc, we only
|
||||
* pad zeros to the 2nd to 4th dimension.
|
||||
*
|
||||
* Argument in this Function:
|
||||
* \param pad_ A struct object contains the padding size in each dimension.
|
||||
* It has six integers. The channelStart and channelEnd indicate
|
||||
* how many zeros to add before and after the input in channel
|
||||
* dimension. And the heightStart and heightEnd indicate padding
|
||||
* in height dimension. The widthStart and widthEnd indicate the
|
||||
* padding in width dimension.
|
||||
* \param inputs A 4D tensor, only one input.
|
||||
* \param outputs A 4D tensor, the output value after padding.
|
||||
*
|
||||
*/
|
||||
|
||||
template <DeviceType Device>
|
||||
class NCHW2NHWCFunc : public FunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {}
|
||||
|
||||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
CHECK_EQ(1UL, inputs.size());
|
||||
CHECK_EQ(1UL, outputs.size());
|
||||
|
||||
size_t num = inputs[0].shape()[0];
|
||||
size_t inC = inputs[0].shape()[1];
|
||||
size_t inH = inputs[0].shape()[2];
|
||||
size_t inW = inputs[0].shape()[3];
|
||||
typename Tensor<real, Device>::Vector vec(outputs[0].shape().getElements(),
|
||||
outputs[0].data<real>());
|
||||
vec.zero();
|
||||
|
||||
NCHW2NHWC<Device>(
|
||||
outputs[0].data<real>(), inputs[0].data<real>(), num, inC, inH, inW);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief The backward propagation of padding Function. Remove the elements
|
||||
* in the padding positions of forward.
|
||||
*
|
||||
* Argument in this Function:
|
||||
* \param pad_ The same meaning as it in PadFunc.
|
||||
* \param inputs The gradient with respect to the output value of PadFunc.
|
||||
* \param outputs The gradient with respect to the input value of PadFunc.
|
||||
*/
|
||||
|
||||
template <DeviceType Device>
|
||||
class NHWC2NCHWFunc : public FunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override {}
|
||||
|
||||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
CHECK_EQ(1UL, inputs.size());
|
||||
CHECK_EQ(1UL, outputs.size());
|
||||
|
||||
size_t num = inputs[0].shape()[0];
|
||||
size_t inH = inputs[0].shape()[1];
|
||||
size_t inW = inputs[0].shape()[2];
|
||||
size_t inC = inputs[0].shape()[3];
|
||||
|
||||
NHWC2NCHW<Device>(
|
||||
outputs[0].data<real>(), inputs[0].data<real>(), num, inH, inW, inC);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
|
||||
REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
|
||||
#ifndef PADDLE_ONLY_CPU
|
||||
REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
|
||||
REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
|
||||
#endif
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,62 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Function.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/**
|
||||
* \brief This funtion switch dimension order of image input.
|
||||
* The input and output is a 4D tensor. Switch order 'batch_size,
|
||||
*channels, height, width' to
|
||||
* order 'batch_size, height, width, channels'.
|
||||
*
|
||||
* \param[out] outputs save results.
|
||||
* \param[in] inputs input data.
|
||||
* \param[in] num batch size of input data.
|
||||
* \param[in] inC channel number of input data.
|
||||
* \param[in] inH height of input data.
|
||||
* \param[in] inH with of input data.
|
||||
*/
|
||||
template <DeviceType Device>
|
||||
void NCHW2NHWC(real* outputs,
|
||||
const real* inputs,
|
||||
const int num,
|
||||
const int inC,
|
||||
const int inH,
|
||||
const int inW);
|
||||
|
||||
/**
|
||||
* \brief This funtion switch dimension order of image input.
|
||||
* The input and output is a 4D tensor. Switch order 'batch_size,
|
||||
*height, width, channels' to
|
||||
* order 'batch_size, channels, height, width'.
|
||||
*
|
||||
* \param[out] inGrad gradients of previous layer.
|
||||
* \param[in] outGrad output gradients.
|
||||
* \param[in] num batch size of input data.
|
||||
* \param[in] inH height of input data.
|
||||
* \param[in] inW with of input data.
|
||||
* \param[in] inC channel number of input data.
|
||||
*/
|
||||
template <DeviceType Device>
|
||||
void NHWC2NCHW(real* inGrad,
|
||||
const real* outGrad,
|
||||
const int num,
|
||||
const int inH,
|
||||
const int inW,
|
||||
const int inC);
|
||||
} // namespace paddle
|
@ -0,0 +1,80 @@
|
||||
/* Copyright (c) 2016 Paddle
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "hl_base.h"
|
||||
#include "SwitchOp.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
__global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
|
||||
int inC, int inH, int inW,
|
||||
int nthreads) {
|
||||
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (idx < nthreads) {
|
||||
const int w = idx % inW;
|
||||
const int h = (idx / inW) % inH;
|
||||
const int c = (idx / inW / inH) % inC;
|
||||
const int n = idx / inW / inH / inC;
|
||||
|
||||
const int off = ((n * inH + h) * inW + w) * inC +c;
|
||||
outputs[off] = inputs[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
|
||||
const real* inputs,
|
||||
const int num,
|
||||
const int inC,
|
||||
const int inH,
|
||||
const int inW) {
|
||||
size_t nth = num * inC * inH * inW;
|
||||
int blockSize = 1024;
|
||||
int gridSize = (nth + 1024 - 1) / 1024;
|
||||
KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
|
||||
(outputs, inputs, inC, inH, inW, nth);
|
||||
CHECK_SYNC("NCHW2NHWC");
|
||||
}
|
||||
|
||||
__global__ void KeNHWC2NCHW(real* outputs, const real* inputs,
|
||||
int inH, int inW, int inC,
|
||||
int nthreads) {
|
||||
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (idx < nthreads) {
|
||||
const int c = idx % inC;
|
||||
const int w = (idx / inC) % inW;
|
||||
const int h = (idx / inC / inW) % inH;
|
||||
const int n = idx / inW / inH / inC;
|
||||
|
||||
const int off = ((n * inC + c) * inH + h) * inW + w;
|
||||
outputs[off] = inputs[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
|
||||
const real* inputs,
|
||||
const int num,
|
||||
const int inH,
|
||||
const int inW,
|
||||
const int inC) {
|
||||
int nth = num * inC * inH * inW;
|
||||
int blockSize = 1024;
|
||||
int gridSize = (nth + 1024 - 1) / 1024;
|
||||
KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
|
||||
(outputs, inputs, inH, inW, inC, nth);
|
||||
CHECK_SYNC("NHWC2NCHW");
|
||||
}
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,44 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "FunctionTest.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
TEST(Pad, real) {
|
||||
for (size_t numSamples : {1, 4, 8, 16}) {
|
||||
for (size_t channels : {1, 4, 8, 16}) {
|
||||
for (size_t imgSizeH : {1, 4, 8, 16}) {
|
||||
for (size_t imgSizeW : {1, 4, 8, 16}) {
|
||||
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
|
||||
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
|
||||
for (bool test_grad : {true, false}) {
|
||||
CpuGpuFuncCompare compare(test_grad ? "NHWC2NCHW" : "NCHW2NHWC",
|
||||
FuncConfig());
|
||||
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
|
||||
TensorShape outDims{numSamples, imgSizeH, imgSizeW, channels};
|
||||
compare.addInputs(
|
||||
BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
|
||||
compare.addOutputs(BufferArg(
|
||||
VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
|
||||
compare.run();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,89 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "PixelSoftmaxLayer.h"
|
||||
#include "paddle/utils/Stat.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
REGISTER_LAYER(pixel_softmax, PixelSoftmaxLayer);
|
||||
|
||||
bool PixelSoftmaxLayer::init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) {
|
||||
/* Initialize the basic parent class */
|
||||
Layer::init(layerMap, parameterMap);
|
||||
auto& img_conf = config_.inputs(0).image_conf();
|
||||
inH_ =
|
||||
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
|
||||
inW_ = img_conf.img_size();
|
||||
inC_ = img_conf.channels();
|
||||
createFunction(forward_, "NCHW2NHWC", FuncConfig());
|
||||
createFunction(backward_, "NHWC2NCHW", FuncConfig());
|
||||
inDims_ = TensorShape({0, inH_, inW_, inC_});
|
||||
outDims_ = TensorShape({0, inC_, inH_, inW_});
|
||||
return true;
|
||||
}
|
||||
|
||||
void PixelSoftmaxLayer::forward(PassType passType) {
|
||||
Layer::forward(passType);
|
||||
MatrixPtr input = inputLayers_[0]->getOutputValue();
|
||||
size_t batchSize = input->getHeight();
|
||||
// cout<<"useGpu:"<<useGpu(deviceId_)<<endl;
|
||||
Matrix::resizeOrCreate(
|
||||
tmpInput_, batchSize * inH_ * inW_, inC_, false, useGpu_);
|
||||
Matrix::resizeOrCreate(
|
||||
tmpOutput_, batchSize * inH_ * inW_, inC_, false, useGpu_);
|
||||
tmpOutput_->zeroMem();
|
||||
resetOutput(batchSize, inH_ * inW_ * inC_);
|
||||
inDims_.setDim(0, batchSize);
|
||||
outDims_.setDim(0, batchSize);
|
||||
|
||||
// switch NCHW to NHWC
|
||||
BufferArgs inputs;
|
||||
BufferArgs outputs;
|
||||
inputs.addArg(*getInputValue(0), inDims_);
|
||||
outputs.addArg(*tmpInput_, outDims_);
|
||||
forward_[0]->calc(inputs, outputs);
|
||||
// softmax forward and save softmax result into tmpMatrix_
|
||||
tmpInput_->softmax(*tmpOutput_);
|
||||
|
||||
// switch NHWC to NCHW
|
||||
BufferArgs inputs_1;
|
||||
BufferArgs outputs_1;
|
||||
inputs_1.addArg(*tmpOutput_, outDims_);
|
||||
outputs_1.addArg(*getOutputValue(), inDims_);
|
||||
backward_[0]->calc(inputs_1, outputs_1);
|
||||
}
|
||||
|
||||
void PixelSoftmaxLayer::backward(const UpdateCallback& callback) {
|
||||
(void)callback;
|
||||
REGISTER_TIMER_INFO("PixelSoftmaxBackward", getName().c_str());
|
||||
|
||||
// switch NCHW to NHWC
|
||||
BufferArgs inputs;
|
||||
BufferArgs outputs;
|
||||
inputs.addArg(*getOutputGrad(), inDims_);
|
||||
outputs.addArg(*tmpInput_, outDims_);
|
||||
forward_[0]->calc(inputs, outputs);
|
||||
// softmax backward and save grad result into tmpOutput_
|
||||
tmpInput_->softmaxBackward(*tmpOutput_);
|
||||
|
||||
// switch NHWC to NCHW
|
||||
BufferArgs inputs_1;
|
||||
BufferArgs outputs_1;
|
||||
inputs_1.addArg(*tmpInput_, outDims_);
|
||||
outputs_1.addArg(*getInputGrad(0), inDims_);
|
||||
backward_[0]->calc(inputs_1, outputs_1);
|
||||
}
|
||||
} // namespace paddle
|
@ -0,0 +1,44 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Layer.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/**
|
||||
* \brief This layer calculate softmax in image channel dimension.
|
||||
*/
|
||||
class PixelSoftmaxLayer : public Layer {
|
||||
public:
|
||||
explicit PixelSoftmaxLayer(const LayerConfig& config) : Layer(config) {}
|
||||
|
||||
~PixelSoftmaxLayer() {}
|
||||
|
||||
bool init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) override;
|
||||
void forward(PassType passType) override;
|
||||
void backward(const UpdateCallback& callback = nullptr) override;
|
||||
|
||||
protected:
|
||||
uint32_t inC_;
|
||||
uint32_t inH_;
|
||||
uint32_t inW_;
|
||||
TensorShape inDims_;
|
||||
TensorShape outDims_;
|
||||
MatrixPtr tmpInput_;
|
||||
MatrixPtr tmpOutput_;
|
||||
};
|
||||
} // namespace paddle
|
Loading…
Reference in new issue