parent
cfde85bc52
commit
07f3f07ff3
@ -0,0 +1,155 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "ScaleSubRegionOp.h"
|
||||||
|
#include "paddle/function/TensorShape.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void ScaleSubRegion<DEVICE_TYPE_CPU>(real* outputs,
|
||||||
|
const real* inputs,
|
||||||
|
const real* indices,
|
||||||
|
const TensorShape shape,
|
||||||
|
const FuncConfig& conf) {
|
||||||
|
real value = conf.get<real>("value");
|
||||||
|
|
||||||
|
int number = shape[0];
|
||||||
|
int channel = shape[1];
|
||||||
|
int height = shape[2];
|
||||||
|
int width = shape[3];
|
||||||
|
|
||||||
|
memcpy(outputs, inputs, number * channel * height * width * sizeof(real));
|
||||||
|
|
||||||
|
for (int n = 0; n < number; ++n) {
|
||||||
|
// indices start from 1
|
||||||
|
int offset = n * 6;
|
||||||
|
for (int c = indices[offset] - 1; c < indices[offset + 1]; ++c) {
|
||||||
|
for (int h = indices[offset + 2] - 1; h < indices[offset + 3]; ++h) {
|
||||||
|
for (int w = indices[offset + 4] - 1; w < indices[offset + 5]; ++w) {
|
||||||
|
int idx = ((n * channel + c) * height + h) * width + w;
|
||||||
|
outputs[idx] *= value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void ScaleSubRegionGrad<DEVICE_TYPE_CPU>(const real* inGrad,
|
||||||
|
real* outGrad,
|
||||||
|
const real* indices,
|
||||||
|
const TensorShape shape,
|
||||||
|
const FuncConfig& conf) {
|
||||||
|
real value = conf.get<real>("value");
|
||||||
|
|
||||||
|
int number = shape[0];
|
||||||
|
int channel = shape[1];
|
||||||
|
int height = shape[2];
|
||||||
|
int width = shape[3];
|
||||||
|
|
||||||
|
for (int n = 0; n < number; ++n) {
|
||||||
|
for (int c = 0; c < channel; ++c) {
|
||||||
|
for (int h = 0; h < height; ++h) {
|
||||||
|
for (int w = 0; w < width; ++w) {
|
||||||
|
int idx = ((n * channel + c) * height + h) * width + w;
|
||||||
|
int offset = n * 6;
|
||||||
|
if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
|
||||||
|
h >= (indices[offset + 2] - 1) &&
|
||||||
|
h <= (indices[offset + 3] - 1) &&
|
||||||
|
w >= (indices[offset + 4] - 1) &&
|
||||||
|
w <= (indices[offset + 5] - 1)) {
|
||||||
|
outGrad[idx] += inGrad[idx] * value;
|
||||||
|
} else {
|
||||||
|
outGrad[idx] += inGrad[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief For each instance, ScaleSubRegion can be used to multiply a value to
|
||||||
|
* a specified sub continuous region. By providing start index and end
|
||||||
|
* index for C/H/W, you can specify the location and shape of the region.
|
||||||
|
*
|
||||||
|
* Argument in this Function:
|
||||||
|
* \param inputs A 4-D tensor with shape [N, C, H, W], only one input.
|
||||||
|
* \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
|
||||||
|
* \param outputs A 4-D tensor with same shape as inputs, output value.
|
||||||
|
*/
|
||||||
|
template <DeviceType Device>
|
||||||
|
class ScaleSubRegionFunc : public FunctionBase {
|
||||||
|
public:
|
||||||
|
void init(const FuncConfig& config) override { conf_ = config; }
|
||||||
|
|
||||||
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||||
|
CHECK_EQ(2UL, inputs.size());
|
||||||
|
CHECK_EQ(1UL, outputs.size());
|
||||||
|
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
||||||
|
|
||||||
|
TensorShape shape = inputs[0].shape();
|
||||||
|
|
||||||
|
ScaleSubRegion<Device>(outputs[0].data<real>(),
|
||||||
|
inputs[0].data<real>(),
|
||||||
|
inputs[1].data<real>(),
|
||||||
|
shape,
|
||||||
|
conf_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FuncConfig conf_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief The backward propagation of ScaleSubRegion Function.
|
||||||
|
*
|
||||||
|
* Argument in this Function:
|
||||||
|
* \param inputs A 4-D tensor with shape [N, C, H, W], output gradient.
|
||||||
|
* \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
|
||||||
|
* \param outputs A 4-D tensor with shape [N, C, H, W], gradient of input value.
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <DeviceType Device>
|
||||||
|
class ScaleSubRegionGradFunc : public FunctionBase {
|
||||||
|
public:
|
||||||
|
void init(const FuncConfig& config) override { conf_ = config; }
|
||||||
|
|
||||||
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||||
|
CHECK_EQ(2UL, inputs.size());
|
||||||
|
CHECK_EQ(1UL, outputs.size());
|
||||||
|
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
|
||||||
|
|
||||||
|
TensorShape shape = inputs[0].shape();
|
||||||
|
|
||||||
|
ScaleSubRegionGrad<Device>(inputs[0].data<real>(),
|
||||||
|
outputs[0].data<real>(),
|
||||||
|
inputs[1].data<real>(),
|
||||||
|
shape,
|
||||||
|
conf_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FuncConfig conf_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_TYPED_FUNC(ScaleSubRegion, CPU, ScaleSubRegionFunc);
|
||||||
|
REGISTER_TYPED_FUNC(ScaleSubRegionGrad, CPU, ScaleSubRegionGradFunc);
|
||||||
|
#ifdef PADDLE_WITH_CUDA
|
||||||
|
REGISTER_TYPED_FUNC(ScaleSubRegion, GPU, ScaleSubRegionFunc);
|
||||||
|
REGISTER_TYPED_FUNC(ScaleSubRegionGrad, GPU, ScaleSubRegionGradFunc);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,55 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Function.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Function to multiply a value to values in specified sub continuous
|
||||||
|
* region. Indices must be provided to indcate the location and shape of
|
||||||
|
* the region and the multiplied value is passed by configure variable.
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* \param[out] outputs Output value.
|
||||||
|
* \param[in] inputs Input data which contains NCHW information.
|
||||||
|
* \param[in] indices Indices data to indcate the sub region.
|
||||||
|
* \param[in] shape Tensor shape of input value.
|
||||||
|
* \param[in] conf Configure variable which contains the multiplied value.
|
||||||
|
*/
|
||||||
|
template <DeviceType Device>
|
||||||
|
void ScaleSubRegion(real* outputs,
|
||||||
|
const real* inputs,
|
||||||
|
const real* indices,
|
||||||
|
const TensorShape shape,
|
||||||
|
const FuncConfig& conf);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Backward propagation function of ScaleSubRegion.
|
||||||
|
*
|
||||||
|
* \param[out] inGrad Gradients of previous layer.
|
||||||
|
* \param[in] outGrad Output gradient.
|
||||||
|
* \param[in] indices Indices data.
|
||||||
|
* \param[in] shape The Shape of input tensor.
|
||||||
|
* \param[in] conf Configure variable.
|
||||||
|
*/
|
||||||
|
template <DeviceType Device>
|
||||||
|
void ScaleSubRegionGrad(const real* inGrad,
|
||||||
|
real* outGrad,
|
||||||
|
const real* indices,
|
||||||
|
const TensorShape shape,
|
||||||
|
const FuncConfig& conf);
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,116 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "ScaleSubRegionOp.h"
|
||||||
|
#include "hl_base.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
__global__ void KeScaleSubRegion(real* outputs,
|
||||||
|
const real* inputs,
|
||||||
|
const real* indices,
|
||||||
|
real value,
|
||||||
|
int channel,
|
||||||
|
int height,
|
||||||
|
int width,
|
||||||
|
int nthreads) {
|
||||||
|
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
if (idx < nthreads) {
|
||||||
|
const int w = idx % width;
|
||||||
|
const int h = (idx / width) % height;
|
||||||
|
const int c = (idx / width / height) % channel;
|
||||||
|
const int n = idx / width / height / channel;
|
||||||
|
|
||||||
|
const int offset = n * 6;
|
||||||
|
if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
|
||||||
|
h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
|
||||||
|
w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
|
||||||
|
outputs[idx] = inputs[idx] * value;
|
||||||
|
} else {
|
||||||
|
outputs[idx] = inputs[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void ScaleSubRegion<DEVICE_TYPE_GPU>(real* outputs,
|
||||||
|
const real* inputs,
|
||||||
|
const real* indices,
|
||||||
|
const TensorShape shape,
|
||||||
|
const FuncConfig& conf) {
|
||||||
|
real value = conf.get<real>("value");
|
||||||
|
|
||||||
|
int number = shape[0];
|
||||||
|
int channel = shape[1];
|
||||||
|
int height = shape[2];
|
||||||
|
int width = shape[3];
|
||||||
|
|
||||||
|
size_t nth = number * channel * height * width;
|
||||||
|
int blockSize = 1024;
|
||||||
|
int gridSize = (nth + blockSize - 1) / blockSize;
|
||||||
|
|
||||||
|
KeScaleSubRegion<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
|
||||||
|
outputs, inputs, indices, value, channel, height, width, nth);
|
||||||
|
CHECK_SYNC("ScaleSubRegion");
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void KeScaleSubRegionDiff(const real* inGrad,
|
||||||
|
real* outGrad,
|
||||||
|
const real* indices,
|
||||||
|
real value,
|
||||||
|
int channel,
|
||||||
|
int height,
|
||||||
|
int width,
|
||||||
|
int nthreads) {
|
||||||
|
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
if (idx < nthreads) {
|
||||||
|
const int w = idx % width;
|
||||||
|
const int h = (idx / width) % height;
|
||||||
|
const int c = (idx / width / height) % channel;
|
||||||
|
const int n = idx / width / height / channel;
|
||||||
|
|
||||||
|
const int offset = n * 6;
|
||||||
|
if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
|
||||||
|
h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
|
||||||
|
w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
|
||||||
|
outGrad[idx] += inGrad[idx] * value;
|
||||||
|
} else {
|
||||||
|
outGrad[idx] += inGrad[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void ScaleSubRegionGrad<DEVICE_TYPE_GPU>(const real* inGrad,
|
||||||
|
real* outGrad,
|
||||||
|
const real* indices,
|
||||||
|
const TensorShape shape,
|
||||||
|
const FuncConfig& conf) {
|
||||||
|
real value = conf.get<real>("value");
|
||||||
|
|
||||||
|
int number = shape[0];
|
||||||
|
int channel = shape[1];
|
||||||
|
int height = shape[2];
|
||||||
|
int width = shape[3];
|
||||||
|
|
||||||
|
size_t nth = number * channel * height * width;
|
||||||
|
int blockSize = 1024;
|
||||||
|
int gridSize = (nth + blockSize - 1) / blockSize;
|
||||||
|
|
||||||
|
KeScaleSubRegionDiff<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
|
||||||
|
inGrad, outGrad, indices, value, channel, height, width, nth);
|
||||||
|
CHECK_SYNC("ScaleSubRegionGrad");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,72 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "FunctionTest.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
TEST(ScaleSubRegion, real) {
|
||||||
|
for (size_t numSamples : {5, 32}) {
|
||||||
|
for (size_t channels : {5, 5, 32}) {
|
||||||
|
for (size_t imgSizeH : {5, 33, 100}) {
|
||||||
|
for (size_t imgSizeW : {5, 32, 96}) {
|
||||||
|
for (real value : {-0.5, 0.0, 0.5}) {
|
||||||
|
for (bool firstHalf : {false, true}) {
|
||||||
|
VLOG(3) << " numSamples=" << numSamples
|
||||||
|
<< " channels=" << channels << " imgSizeH=" << imgSizeH
|
||||||
|
<< " imgSizeW=" << imgSizeW;
|
||||||
|
|
||||||
|
for (bool testGrad : {false, true}) {
|
||||||
|
CpuGpuFuncCompare compare(
|
||||||
|
testGrad ? "ScaleSubRegionGrad" : "ScaleSubRegion",
|
||||||
|
FuncConfig().set<real>("value", value));
|
||||||
|
|
||||||
|
TensorShape shape{numSamples, channels, imgSizeH, imgSizeW};
|
||||||
|
TensorShape indicesShape{numSamples, 6};
|
||||||
|
|
||||||
|
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape));
|
||||||
|
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, indicesShape));
|
||||||
|
|
||||||
|
compare.registerInitCallback([=](BufferArg& arg, size_t index) {
|
||||||
|
if (index == 1) {
|
||||||
|
real* data = (real*)arg.data();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < numSamples; ++i) {
|
||||||
|
size_t offset = i * 6;
|
||||||
|
data[offset] = firstHalf ? 1 : channels / 2;
|
||||||
|
data[offset + 1] = firstHalf ? channels / 2 : channels;
|
||||||
|
data[offset + 2] = firstHalf ? 1 : imgSizeH / 2;
|
||||||
|
data[offset + 3] = firstHalf ? imgSizeH / 2 : imgSizeH;
|
||||||
|
data[offset + 4] = firstHalf ? 1 : imgSizeW / 2;
|
||||||
|
data[offset + 5] = firstHalf ? imgSizeW / 2 : imgSizeW;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
compare.addOutputs(
|
||||||
|
BufferArg(
|
||||||
|
VALUE_TYPE_FLOAT, shape, testGrad ? ADD_TO : ASSIGN_TO),
|
||||||
|
testGrad ? ADD_TO : ASSIGN_TO);
|
||||||
|
compare.run();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,78 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "ScaleSubRegionLayer.h"
|
||||||
|
#include "paddle/utils/Stat.h"
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
REGISTER_LAYER(scale_sub_region, ScaleSubRegionLayer);
|
||||||
|
|
||||||
|
bool ScaleSubRegionLayer::init(const LayerMap& layerMap,
|
||||||
|
const ParameterMap& parameterMap) {
|
||||||
|
Layer::init(layerMap, parameterMap);
|
||||||
|
CHECK_EQ(static_cast<int>(inputLayers_.size()), 2);
|
||||||
|
auto& conf = config_.inputs(0).scale_sub_region_conf();
|
||||||
|
value_ = conf.value();
|
||||||
|
|
||||||
|
createFunction(forward_, "ScaleSubRegion", FuncConfig().set("value", value_));
|
||||||
|
createFunction(
|
||||||
|
backward_, "ScaleSubRegionGrad", FuncConfig().set("value", value_));
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ScaleSubRegionLayer::forward(PassType passType) {
|
||||||
|
Layer::forward(passType);
|
||||||
|
auto in0 = getInput(0);
|
||||||
|
imgH_ = in0.getFrameHeight();
|
||||||
|
imgW_ = in0.getFrameWidth();
|
||||||
|
if (imgH_ == 0 || imgW_ == 0) {
|
||||||
|
auto& conf = config_.inputs(0).scale_sub_region_conf();
|
||||||
|
imgH_ = conf.image_conf().img_size_y();
|
||||||
|
imgW_ = conf.image_conf().img_size();
|
||||||
|
}
|
||||||
|
MatrixPtr imgV = in0.value;
|
||||||
|
size_t batchSize = imgV->getHeight();
|
||||||
|
size_t spatialSize = imgH_ * imgW_;
|
||||||
|
channelsNum_ = imgV->getWidth() / spatialSize;
|
||||||
|
shape_ = TensorShape({batchSize, channelsNum_, imgH_, imgW_});
|
||||||
|
|
||||||
|
resetOutput(batchSize, imgV->getWidth());
|
||||||
|
auto out = getOutput();
|
||||||
|
out.setFrameHeight(imgH_);
|
||||||
|
out.setFrameWidth(imgW_);
|
||||||
|
|
||||||
|
MatrixPtr indicesV = getInputValue(1);
|
||||||
|
indicesShape_ = TensorShape({batchSize, 6});
|
||||||
|
|
||||||
|
REGISTER_TIMER_INFO("ScaleSubRegionForward", getName().c_str());
|
||||||
|
BufferArgs inArgs;
|
||||||
|
BufferArgs outArgs;
|
||||||
|
inArgs.addArg(*imgV, shape_);
|
||||||
|
inArgs.addArg(*indicesV, indicesShape_);
|
||||||
|
outArgs.addArg(*out.value, shape_, ASSIGN_TO);
|
||||||
|
forward_[0]->calc(inArgs, outArgs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ScaleSubRegionLayer::backward(const UpdateCallback& callback) {
|
||||||
|
REGISTER_TIMER_INFO("ScaleSubRegionBackward", getName().c_str());
|
||||||
|
BufferArgs inArgs;
|
||||||
|
BufferArgs outArgs;
|
||||||
|
inArgs.addArg(*getOutputGrad(), shape_);
|
||||||
|
inArgs.addArg(*getInputValue(1), indicesShape_);
|
||||||
|
outArgs.addArg(*getInputGrad(0), shape_, ADD_TO);
|
||||||
|
backward_[0]->calc(inArgs, outArgs);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,52 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Layer.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief For each instance, this layer can be used to multiply a value to a
|
||||||
|
* specified sub continuous region. By providing start index and end
|
||||||
|
* index for C/H/W, you can specify the location and shape of the
|
||||||
|
* region.
|
||||||
|
*
|
||||||
|
* input_0: Input value.
|
||||||
|
* input_1: Indices value to specify the location an shape of the
|
||||||
|
* region.
|
||||||
|
*/
|
||||||
|
class ScaleSubRegionLayer : public Layer {
|
||||||
|
public:
|
||||||
|
explicit ScaleSubRegionLayer(const LayerConfig& config) : Layer(config) {}
|
||||||
|
|
||||||
|
~ScaleSubRegionLayer() {}
|
||||||
|
|
||||||
|
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
|
||||||
|
|
||||||
|
void forward(PassType passType);
|
||||||
|
|
||||||
|
void backward(const UpdateCallback& callback = nullptr);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
TensorShape shape_;
|
||||||
|
TensorShape indicesShape_;
|
||||||
|
size_t imgH_;
|
||||||
|
size_t imgW_;
|
||||||
|
size_t channelsNum_;
|
||||||
|
real value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,51 @@
|
|||||||
|
type: "nn"
|
||||||
|
layers {
|
||||||
|
name: "data"
|
||||||
|
type: "data"
|
||||||
|
size: 2016
|
||||||
|
active_type: ""
|
||||||
|
height: 48
|
||||||
|
width: 42
|
||||||
|
}
|
||||||
|
layers {
|
||||||
|
name: "indices"
|
||||||
|
type: "data"
|
||||||
|
size: 6
|
||||||
|
active_type: ""
|
||||||
|
}
|
||||||
|
layers {
|
||||||
|
name: "__scale_sub_region_0__"
|
||||||
|
type: "scale_sub_region"
|
||||||
|
size: 2016
|
||||||
|
active_type: ""
|
||||||
|
inputs {
|
||||||
|
input_layer_name: "data"
|
||||||
|
scale_sub_region_conf {
|
||||||
|
image_conf {
|
||||||
|
channels: 1
|
||||||
|
img_size: 42
|
||||||
|
img_size_y: 48
|
||||||
|
}
|
||||||
|
value: 0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputs {
|
||||||
|
input_layer_name: "indices"
|
||||||
|
}
|
||||||
|
height: 48
|
||||||
|
width: 42
|
||||||
|
}
|
||||||
|
input_layer_names: "data"
|
||||||
|
input_layer_names: "indices"
|
||||||
|
output_layer_names: "__scale_sub_region_0__"
|
||||||
|
sub_models {
|
||||||
|
name: "root"
|
||||||
|
layer_names: "data"
|
||||||
|
layer_names: "indices"
|
||||||
|
layer_names: "__scale_sub_region_0__"
|
||||||
|
input_layer_names: "data"
|
||||||
|
input_layer_names: "indices"
|
||||||
|
output_layer_names: "__scale_sub_region_0__"
|
||||||
|
is_recurrent_layer_group: false
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,11 @@
|
|||||||
|
from paddle.trainer_config_helpers import *
|
||||||
|
|
||||||
|
settings(batch_size=1000, learning_rate=1e-5)
|
||||||
|
|
||||||
|
data = data_layer(name='data', size=2016, height=48, width=42)
|
||||||
|
indices = data_layer(name='indices', size=6)
|
||||||
|
|
||||||
|
scale_sub_region = scale_sub_region_layer(
|
||||||
|
input=data, indices=indices, value=0.0)
|
||||||
|
|
||||||
|
outputs(scale_sub_region)
|
Loading…
Reference in new issue