commit
ec236f4624
@ -0,0 +1,84 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// EigenDim converts paddle::platform::DDim into Eigen::DSizes.
|
||||
template <int D>
|
||||
struct EigenDim {
|
||||
using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
|
||||
|
||||
static Type From(const DDim& dims) {
|
||||
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
|
||||
Type ret;
|
||||
for (int d = 0; d < arity(dims); d++) {
|
||||
ret[d] = dims[d];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor.
|
||||
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
struct EigenTensor {
|
||||
// TODO(qijun) Now, default type in unaligned, and we will make a benchmark on
|
||||
// the speed of aligned and unaligned version in future.
|
||||
using Type = Eigen::TensorMap<Eigen::Tensor<T, D, MajorType, IndexType>>;
|
||||
|
||||
using ConstType =
|
||||
Eigen::TensorMap<Eigen::Tensor<const T, D, MajorType, IndexType>>;
|
||||
|
||||
static Type From(Tensor& tensor, DDim dims) {
|
||||
return Type(tensor.data<T>(), EigenDim<D>::From(dims));
|
||||
}
|
||||
|
||||
static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); }
|
||||
|
||||
static ConstType From(const Tensor& tensor, DDim dims) {
|
||||
return ConstType(tensor.data<T>(), EigenDim<D>::From(dims));
|
||||
}
|
||||
|
||||
static ConstType From(const Tensor& tensor) {
|
||||
return From(tensor, tensor.dims_);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
|
||||
// Flatten is to reshape a Tensor into a one dimension EigenVector
|
||||
static typename EigenTensor<T, 1>::Type Flatten(Tensor& tensor) {
|
||||
return EigenTensor<T, 1>::From(
|
||||
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
|
||||
}
|
||||
|
||||
static typename EigenTensor<T, 1>::ConstType Flatten(const Tensor& tensor) {
|
||||
return EigenTensor<T, 1>::From(
|
||||
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
using EigenMatrix = EigenTensor<T, 2, MajorType, IndexType>;
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,101 @@
|
||||
/*
|
||||
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
TEST(EigenDim, From) {
|
||||
EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3}));
|
||||
ASSERT_EQ(1, ed[0]);
|
||||
ASSERT_EQ(2, ed[1]);
|
||||
ASSERT_EQ(3, ed[2]);
|
||||
}
|
||||
|
||||
TEST(Eigen, Tensor) {
|
||||
Tensor t;
|
||||
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
|
||||
for (int i = 0; i < 1 * 2 * 3; i++) {
|
||||
p[i] = static_cast<float>(i);
|
||||
}
|
||||
|
||||
EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
|
||||
|
||||
ASSERT_EQ(1, et.dimension(0));
|
||||
ASSERT_EQ(2, et.dimension(1));
|
||||
ASSERT_EQ(3, et.dimension(2));
|
||||
|
||||
for (int i = 0; i < 1; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
for (int k = 0; k < 3; k++) {
|
||||
ASSERT_NEAR((i * 2 + j) * 3 + k, et(i, j, k), 1e-6f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Eigen, VectorFrom) {
|
||||
Tensor t;
|
||||
float* p = t.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
|
||||
for (int i = 0; i < 6; i++) {
|
||||
p[i] = static_cast<float>(i);
|
||||
}
|
||||
|
||||
EigenVector<float>::Type ev = EigenVector<float>::From(t);
|
||||
|
||||
ASSERT_EQ(6, ev.dimension(0));
|
||||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
ASSERT_NEAR(i, ev(i), 1e-6f);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Eigen, VectorFlatten) {
|
||||
Tensor t;
|
||||
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
|
||||
for (int i = 0; i < 1 * 2 * 3; i++) {
|
||||
p[i] = static_cast<float>(i);
|
||||
}
|
||||
|
||||
EigenVector<float>::Type ev = EigenVector<float>::Flatten(t);
|
||||
|
||||
ASSERT_EQ(1 * 2 * 3, ev.dimension(0));
|
||||
|
||||
for (int i = 0; i < 1 * 2 * 3; i++) {
|
||||
ASSERT_NEAR(i, ev(i), 1e-6f);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Eigen, Matrix) {
|
||||
Tensor t;
|
||||
float* p = t.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
|
||||
for (int i = 0; i < 2 * 3; i++) {
|
||||
p[i] = static_cast<float>(i);
|
||||
}
|
||||
|
||||
EigenMatrix<float>::Type em = EigenMatrix<float>::From(t);
|
||||
|
||||
ASSERT_EQ(2, em.dimension(0));
|
||||
ASSERT_EQ(3, em.dimension(1));
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
ASSERT_NEAR(i * 3 + j, em(i, j), 1e-6f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -1,15 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/enforce.h"
|
@ -1,75 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <glog/logging.h>
|
||||
#include <paddle/string/printf.h>
|
||||
#include <exception>
|
||||
#include <sstream>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
/**
|
||||
* @brief Enforce exception. Inherits std::exception
|
||||
*
|
||||
* All enforce condition not met, will throw an EnforceNotMet exception.
|
||||
*/
|
||||
class EnforceNotMet : public std::exception {
|
||||
public:
|
||||
EnforceNotMet(const std::string& msg, const char* file, int fileline) {
|
||||
std::ostringstream sout;
|
||||
sout << msg << " at [" << file << ":" << fileline << "];";
|
||||
all_msg_ = sout.str();
|
||||
}
|
||||
|
||||
const char* what() const noexcept override { return all_msg_.c_str(); }
|
||||
|
||||
private:
|
||||
std::string all_msg_;
|
||||
};
|
||||
|
||||
// From https://stackoverflow.com/questions/30130930/
|
||||
// __buildin_expect is in C++ 11 standard. Since the condition which enforced
|
||||
// should be true in most situation, it will make the compiler generate faster
|
||||
// code by adding `UNLIKELY` macro.
|
||||
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
|
||||
|
||||
/**
|
||||
* @brief Throw a EnforceNotMet exception, automatically filled __FILE__ &
|
||||
* __LINE__
|
||||
*
|
||||
* This macro take __VA_ARGS__, user can pass any type if that type can
|
||||
* serialize to std::ostream
|
||||
*/
|
||||
#define PADDLE_THROW(...) \
|
||||
do { \
|
||||
throw ::paddle::framework::EnforceNotMet( \
|
||||
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
|
||||
} while (0)
|
||||
|
||||
/**
|
||||
* @brief Enforce a condition, otherwise throw an EnforceNotMet
|
||||
*/
|
||||
#ifdef NDEBUG
|
||||
#define PADDLE_ENFORCE(condition, ...) \
|
||||
do { \
|
||||
if (UNLIKELY(!(condition))) { \
|
||||
PADDLE_THROW(__VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#else
|
||||
#define PADDLE_ENFORCE(condition, ...) \
|
||||
CHECK(condition) << ::paddle::string::Sprintf(__VA_ARGS__);
|
||||
#endif
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -1,67 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Helper to define Tensor types given that the scalar is of type T.
|
||||
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
|
||||
struct TTypes {
|
||||
// Rank-<NDIMS> tensor of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned>
|
||||
Tensor;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
||||
ConstTensor;
|
||||
|
||||
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned>
|
||||
Scalar;
|
||||
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
|
||||
Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned>
|
||||
ConstScalar;
|
||||
|
||||
// Rank-1 tensor (vector) of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned>
|
||||
Flat;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
||||
ConstFlat;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned>
|
||||
Vec;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
||||
ConstVec;
|
||||
|
||||
// Rank-2 tensor (matrix) of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned>
|
||||
Matrix;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
||||
ConstMatrix;
|
||||
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,177 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "CropOp.h"
|
||||
#include "paddle/function/TensorShape.h"
|
||||
#include "paddle/math/Vector.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
template <>
|
||||
void Crop<DEVICE_TYPE_CPU>(real* outputs,
|
||||
const real* inputs,
|
||||
const TensorShape inShape,
|
||||
const TensorShape outShape,
|
||||
const FuncConfig& conf) {
|
||||
std::vector<uint32_t> crop_corner =
|
||||
conf.get<std::vector<uint32_t>>("crop_corner");
|
||||
int cCrop = crop_corner[1];
|
||||
int hCrop = crop_corner[2];
|
||||
int wCrop = crop_corner[3];
|
||||
|
||||
int num = inShape[0];
|
||||
int inC = inShape[1];
|
||||
int inH = inShape[2];
|
||||
int inW = inShape[3];
|
||||
|
||||
int outC = outShape[1];
|
||||
int outH = outShape[2];
|
||||
int outW = outShape[3];
|
||||
|
||||
for (int n = 0; n < num; n++) {
|
||||
for (int c = 0; c < outC; c++) {
|
||||
for (int h = 0; h < outH; h++) {
|
||||
int outoff = ((n * outC + c) * outH + h) * outW;
|
||||
int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop;
|
||||
memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
|
||||
real* outGrad,
|
||||
const TensorShape inShape,
|
||||
const TensorShape outShape,
|
||||
const FuncConfig& conf) {
|
||||
std::vector<uint32_t> crop_corner =
|
||||
conf.get<std::vector<uint32_t>>("crop_corner");
|
||||
int cCrop = crop_corner[1];
|
||||
int hCrop = crop_corner[2];
|
||||
int wCrop = crop_corner[3];
|
||||
|
||||
int num = outShape[0];
|
||||
int outC = outShape[1];
|
||||
int outH = outShape[2];
|
||||
int outW = outShape[3];
|
||||
|
||||
int inC = inShape[1];
|
||||
int inH = inShape[2];
|
||||
int inW = inShape[3];
|
||||
|
||||
for (int n = 0; n < num; n++) {
|
||||
for (int c = 0; c < inC; c++) {
|
||||
for (int h = 0; h < inH; h++) {
|
||||
int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop;
|
||||
int inoff = ((n * inC + c) * inH + h) * inW;
|
||||
CpuVector inG = CpuVector(inW, const_cast<real*>(inGrad + inoff));
|
||||
CpuVector outG = CpuVector(inW, outGrad + outoff);
|
||||
outG += inG;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Crop input according to the specify corner and shape.
|
||||
* The input and output is a 4D tensor. In CropFunc, we only
|
||||
* crop the 2nd to 4th dimension.
|
||||
*
|
||||
* Argument in this Function:
|
||||
* \param pad_ A struct object contains the cropping corner and shape.
|
||||
* \param inputs A 4D tensor, only one input.
|
||||
* \param outputs A 4D tensor, the output value after cropping.
|
||||
*
|
||||
* For example,
|
||||
* Input(2,2,2,3) = [
|
||||
* [ [[1,2,3], [3,4,5]],
|
||||
* [[2,3,5], [1,6,7]] ],
|
||||
* [ [[4,3,1], [1,8,7]],
|
||||
* [[3,8,9], [2,3,5]] ]
|
||||
* ] # the input shape is (2,2,2,3)
|
||||
*
|
||||
* pad_: if corner = (0,1,1) and crop_shape = (2,1,2)
|
||||
* Output(2,2,1,2) = [
|
||||
* [ [[4,5]],
|
||||
* [[6,7]] ],
|
||||
* [ [[8,7]],
|
||||
* [[3,5]] ]
|
||||
* ] # the input shape is (2,2,2,3)
|
||||
*/
|
||||
template <DeviceType Device>
|
||||
class CropFunc : public FunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override { conf_ = config; }
|
||||
|
||||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
CHECK_EQ(1UL, inputs.size());
|
||||
CHECK_EQ(1UL, outputs.size());
|
||||
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
||||
|
||||
TensorShape inShape = inputs[0].shape();
|
||||
TensorShape outShape = outputs[0].shape();
|
||||
|
||||
Crop<Device>(outputs[0].data<real>(),
|
||||
inputs[0].data<real>(),
|
||||
inShape,
|
||||
outShape,
|
||||
conf_);
|
||||
}
|
||||
|
||||
private:
|
||||
FuncConfig conf_;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief The backward propagation of cropping Function.
|
||||
*
|
||||
* Argument in this Function:
|
||||
* \param crop_ The same meaning as it in CropFunc.
|
||||
* \param inputs The gradient with respect to the output value of CropFunc.
|
||||
* \param outputs The gradient with respect to the input value of CropFunc.
|
||||
*/
|
||||
|
||||
template <DeviceType Device>
|
||||
class CropGradFunc : public FunctionBase {
|
||||
public:
|
||||
void init(const FuncConfig& config) override { conf_ = config; }
|
||||
|
||||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
||||
CHECK_EQ(1UL, inputs.size());
|
||||
CHECK_EQ(1UL, outputs.size());
|
||||
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
|
||||
|
||||
TensorShape outShape = outputs[0].shape();
|
||||
TensorShape inShape = inputs[0].shape();
|
||||
|
||||
CropGrad<Device>(inputs[0].data<real>(),
|
||||
outputs[0].data<real>(),
|
||||
inShape,
|
||||
outShape,
|
||||
conf_);
|
||||
}
|
||||
|
||||
private:
|
||||
FuncConfig conf_;
|
||||
};
|
||||
|
||||
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
|
||||
REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
|
||||
#ifndef PADDLE_ONLY_CPU
|
||||
REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
|
||||
REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
|
||||
#endif
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,51 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Function.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/**
|
||||
* \brief This funtion crops inputs according to the specify start point and
|
||||
*shape.
|
||||
*
|
||||
* \param[out] outputs save results.
|
||||
* \param[in] inputs input data.
|
||||
* \param[in] inShape the shape of input tensor.
|
||||
* \param[in] conf the cropping config
|
||||
*/
|
||||
template <DeviceType Device>
|
||||
void Crop(real* outputs,
|
||||
const real* inputs,
|
||||
const TensorShape inShape,
|
||||
const TensorShape outShape,
|
||||
const FuncConfig& conf);
|
||||
|
||||
/**
|
||||
* \brief Cropping operation backward.
|
||||
*
|
||||
* \param[out] inGrad gradients of previous layer
|
||||
* \param[in] outGrad output gradient
|
||||
* \param[in] inShape the shape of input tensor.
|
||||
* \param[in] conf the cropping config
|
||||
*/
|
||||
template <DeviceType Device>
|
||||
void CropGrad(const real* inGrad,
|
||||
real* outGrad,
|
||||
const TensorShape inShape,
|
||||
const TensorShape outShape,
|
||||
const FuncConfig& conf);
|
||||
} // namespace paddle
|
@ -0,0 +1,113 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "hl_base.h"
|
||||
#include "CropOp.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
__global__ void KeCrop(real* outputs, const real* inputs,
|
||||
int inC, int inH, int inW,
|
||||
int cropC, int cropH, int cropW,
|
||||
int outC, int outH, int outW, int nthreads) {
|
||||
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (idx < nthreads) {
|
||||
const int w = idx % outW;
|
||||
const int h = (idx / outW) % outH;
|
||||
const int c = (idx / outW / outH) % outC;
|
||||
const int n = idx / outW / outH / outC;
|
||||
|
||||
const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w;
|
||||
outputs[idx] = inputs[off];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void Crop<DEVICE_TYPE_GPU>(real* outputs,
|
||||
const real* inputs,
|
||||
const TensorShape inShape,
|
||||
const TensorShape outShape,
|
||||
const FuncConfig& conf) {
|
||||
std::vector<uint32_t> crop_corner = conf.get<std::vector<uint32_t>>("crop_corner");
|
||||
int cropC = crop_corner[1];
|
||||
int cropH = crop_corner[2];
|
||||
int cropW = crop_corner[3];
|
||||
|
||||
int num = inShape[0];
|
||||
int inC = inShape[1];
|
||||
int inH = inShape[2];
|
||||
int inW = inShape[3];
|
||||
|
||||
int outC = outShape[1];
|
||||
int outH = outShape[2];
|
||||
int outW = outShape[3];
|
||||
|
||||
size_t nth = num * outC * outH * outW;
|
||||
int blockSize = 1024;
|
||||
int gridSize = (nth + blockSize - 1) / blockSize;
|
||||
|
||||
KeCrop<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
|
||||
(outputs, inputs, inC, inH, inW, cropC, cropH, cropW,
|
||||
outC, outH, outW, nth);
|
||||
CHECK_SYNC("Crop");
|
||||
}
|
||||
|
||||
__global__ void KeCropDiff(const real* inGrad, real* outGrad,
|
||||
int inC, int inH, int inW,
|
||||
int cropC, int cropH, int cropW,
|
||||
int outC, int outH, int outW, int nthreads) {
|
||||
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (idx < nthreads) {
|
||||
const int w = idx % inW;
|
||||
const int h = (idx / inW) % inH;
|
||||
const int c = (idx / inW / inH) % inC;
|
||||
const int n = idx / inW / inH / inC;
|
||||
|
||||
const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w;
|
||||
|
||||
outGrad[off] += inGrad[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void CropGrad<DEVICE_TYPE_GPU>(const real* inGrad,
|
||||
real* outGrad,
|
||||
const TensorShape inShape,
|
||||
const TensorShape outShape,
|
||||
const FuncConfig& conf) {
|
||||
std::vector<uint32_t> crop_corner = conf.get<std::vector<uint32_t>>("crop_corner");
|
||||
int cropC = crop_corner[1];
|
||||
int cropH = crop_corner[2];
|
||||
int cropW = crop_corner[3];
|
||||
|
||||
int num = outShape[0];
|
||||
int outC = outShape[1];
|
||||
int outH = outShape[2];
|
||||
int outW = outShape[3];
|
||||
|
||||
int inC = inShape[1];
|
||||
int inH = inShape[2];
|
||||
int inW = inShape[3];
|
||||
|
||||
size_t nth = num * inC * inH * inW;
|
||||
int blockSize = 1024;
|
||||
int gridSize = (nth + blockSize - 1) / blockSize;
|
||||
|
||||
KeCropDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
|
||||
(inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW,
|
||||
outC, outH, outW, nth);
|
||||
CHECK_SYNC("CropGrad");
|
||||
}
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,49 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "FunctionTest.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
TEST(Crop, real) {
|
||||
for (size_t numSamples : {5, 32}) {
|
||||
for (size_t channels : {5, 5, 32}) {
|
||||
for (size_t imgSizeH : {5, 33, 100}) {
|
||||
for (size_t imgSizeW : {5, 32, 96}) {
|
||||
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
|
||||
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
|
||||
for (bool test_grad : {false, true}) {
|
||||
CpuGpuFuncCompare compare(
|
||||
test_grad ? "CropGrad" : "Crop",
|
||||
FuncConfig()
|
||||
.set<std::vector<uint32_t>>("crop_corner", {0, 1, 1, 1})
|
||||
.set<std::vector<uint32_t>>("crop_shape", {0, 2, 3, 3}));
|
||||
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
|
||||
TensorShape outDims{numSamples, 2, 3, 3};
|
||||
compare.addInputs(
|
||||
BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
|
||||
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT,
|
||||
test_grad ? inDims : outDims,
|
||||
test_grad ? ADD_TO : ASSIGN_TO),
|
||||
test_grad ? ADD_TO : ASSIGN_TO);
|
||||
compare.run();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,146 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "CropLayer.h"
|
||||
#include "paddle/utils/Stat.h"
|
||||
namespace paddle {
|
||||
|
||||
REGISTER_LAYER(crop, CropLayer);
|
||||
|
||||
bool CropLayer::init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) {
|
||||
/* Initialize the basic parent class */
|
||||
Layer::init(layerMap, parameterMap);
|
||||
CHECK_LE(static_cast<int>(inputLayers_.size()), 2);
|
||||
CHECK_GE(static_cast<int>(inputLayers_.size()), 1);
|
||||
crop_axis_ = config_.axis();
|
||||
for (int i = 0; i < config_.offset_size(); i++) {
|
||||
crop_offsets_.push_back(config_.offset(i));
|
||||
}
|
||||
|
||||
// 1. get input_0 shape
|
||||
auto& input0_img_conf = config_.inputs(0).image_conf();
|
||||
inDims_ = TensorShape({0,
|
||||
input0_img_conf.channels(),
|
||||
input0_img_conf.has_img_size_y()
|
||||
? input0_img_conf.img_size_y()
|
||||
: input0_img_conf.img_size(),
|
||||
input0_img_conf.img_size()});
|
||||
// 2. get target dims from config
|
||||
if (config_.inputs_size() == 1) {
|
||||
targetDims_ = TensorShape({config_.shape(0),
|
||||
config_.shape(1),
|
||||
config_.shape(2),
|
||||
config_.shape(3)});
|
||||
} else {
|
||||
// 2. get input_1 shape
|
||||
auto& input1_img_conf = config_.inputs(1).image_conf();
|
||||
targetDims_ = TensorShape({0,
|
||||
input1_img_conf.channels(),
|
||||
input1_img_conf.has_img_size_y()
|
||||
? input1_img_conf.img_size_y()
|
||||
: input1_img_conf.img_size(),
|
||||
input1_img_conf.img_size()});
|
||||
}
|
||||
|
||||
// 3. get final crop corner
|
||||
int dimSize = 4;
|
||||
crop_corner_ = {0, 0, 0, 0};
|
||||
for (int i = 0; i < dimSize; i++) {
|
||||
if (i >= crop_axis_) {
|
||||
if (crop_offsets_.size() > 1) {
|
||||
crop_corner_[i] = crop_offsets_[i - crop_axis_];
|
||||
} else {
|
||||
crop_corner_[i] = crop_offsets_[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outDims_ = TensorShape(4);
|
||||
|
||||
createFunction(
|
||||
forward_, "Crop", FuncConfig().set("crop_corner", crop_corner_));
|
||||
createFunction(
|
||||
backward_, "CropGrad", FuncConfig().set("crop_corner", crop_corner_));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CropLayer::setOutDims() {
|
||||
MatrixPtr input = inputLayers_[1]->getOutputValue();
|
||||
size_t batchSize = input->getHeight();
|
||||
// get target dims from input_1
|
||||
if (config_.inputs_size() == 2) {
|
||||
targetDims_.setDim(0, batchSize);
|
||||
int ch = config_.inputs(0).image_conf().channels();
|
||||
if (ch != 0) targetDims_.setDim(1, ch);
|
||||
int h = inputLayers_[1]->getOutput().getFrameHeight();
|
||||
if (h != 0) targetDims_.setDim(2, h);
|
||||
int w = inputLayers_[1]->getOutput().getFrameWidth();
|
||||
if (w != 0) targetDims_.setDim(3, w);
|
||||
}
|
||||
// get final crop shape from target dims and crop axis
|
||||
std::vector<uint32_t> crop_shape;
|
||||
int dimSize = 4;
|
||||
for (int i = 0; i < dimSize; i++) {
|
||||
if (i >= crop_axis_) {
|
||||
crop_shape.push_back(targetDims_[i]);
|
||||
} else {
|
||||
crop_shape.push_back(inDims_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
outDims_.reshape(
|
||||
{crop_shape[0], crop_shape[1], crop_shape[2], crop_shape[3]});
|
||||
output_.setFrameHeight(crop_shape[2]);
|
||||
output_.setFrameWidth(crop_shape[3]);
|
||||
}
|
||||
|
||||
void CropLayer::setInDims() {
|
||||
MatrixPtr input = inputLayers_[0]->getOutputValue();
|
||||
size_t batchSize = input->getHeight();
|
||||
inDims_.setDim(0, batchSize);
|
||||
int h = inputLayers_[0]->getOutput().getFrameHeight();
|
||||
if (h != 0) inDims_.setDim(2, h);
|
||||
int w = inputLayers_[0]->getOutput().getFrameWidth();
|
||||
if (w != 0) inDims_.setDim(3, w);
|
||||
}
|
||||
|
||||
void CropLayer::forward(PassType passType) {
|
||||
Layer::forward(passType);
|
||||
setInDims();
|
||||
setOutDims();
|
||||
int size = outDims_[1] * outDims_[2] * outDims_[3];
|
||||
resetOutput(outDims_[0], size);
|
||||
MatrixPtr outV = getOutputValue();
|
||||
REGISTER_TIMER_INFO("CropForward", getName().c_str());
|
||||
|
||||
BufferArgs inputs;
|
||||
BufferArgs outputs;
|
||||
inputs.addArg(*getInputValue(0), inDims_);
|
||||
outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO);
|
||||
forward_[0]->calc(inputs, outputs);
|
||||
}
|
||||
|
||||
void CropLayer::backward(const UpdateCallback& callback) {
|
||||
(void)callback;
|
||||
REGISTER_TIMER_INFO("CropBackward", getName().c_str());
|
||||
|
||||
BufferArgs inputs;
|
||||
BufferArgs outputs;
|
||||
inputs.addArg(*getOutputGrad(), outDims_);
|
||||
outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
|
||||
backward_[0]->calc(inputs, outputs);
|
||||
}
|
||||
} // namespace paddle
|
@ -0,0 +1,52 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Layer.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/**
|
||||
* \brief This layer crop input according to the specify conf.
|
||||
* input_0: input to be cropped
|
||||
* input_1: optional reference input
|
||||
* axis: start dimension to be croped
|
||||
* offset: offset of cropping in each dimension
|
||||
* shape: if reference input layer was not setted,
|
||||
* crop input as this shape conf
|
||||
*/
|
||||
class CropLayer : public Layer {
|
||||
public:
|
||||
explicit CropLayer(const LayerConfig& config) : Layer(config) {}
|
||||
|
||||
~CropLayer() {}
|
||||
|
||||
bool init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) override;
|
||||
void forward(PassType passType) override;
|
||||
void backward(const UpdateCallback& callback = nullptr) override;
|
||||
|
||||
protected:
|
||||
void setOutDims();
|
||||
void setInDims();
|
||||
|
||||
int32_t crop_axis_;
|
||||
std::vector<uint32_t> crop_offsets_;
|
||||
std::vector<uint32_t> crop_corner_;
|
||||
TensorShape inDims_;
|
||||
TensorShape targetDims_;
|
||||
TensorShape outDims_;
|
||||
};
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue