|
|
|
@ -17,28 +17,27 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/function/TensorShape.h"
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
static inline CropConf castToCropConf(const FuncConfig& conf) {
|
|
|
|
|
return {conf.get<std::vector<uint32_t>>("crop_corner"),
|
|
|
|
|
conf.get<std::vector<uint32_t>>("crop_shape")};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void Crop<DEVICE_TYPE_CPU>(real* outputs,
|
|
|
|
|
const real* inputs,
|
|
|
|
|
const TensorShape inShape,
|
|
|
|
|
const CropConf& crop) {
|
|
|
|
|
int cCrop = crop.corner[0];
|
|
|
|
|
int hCrop = crop.corner[1];
|
|
|
|
|
int wCrop = crop.corner[2];
|
|
|
|
|
const FuncConfig& conf) {
|
|
|
|
|
std::vector<uint32_t> crop_corner =
|
|
|
|
|
conf.get<std::vector<uint32_t>>("crop_corner");
|
|
|
|
|
std::vector<uint32_t> crop_shape =
|
|
|
|
|
conf.get<std::vector<uint32_t>>("crop_shape");
|
|
|
|
|
int cCrop = crop_corner[1];
|
|
|
|
|
int hCrop = crop_corner[2];
|
|
|
|
|
int wCrop = crop_corner[3];
|
|
|
|
|
|
|
|
|
|
int num = inShape[0];
|
|
|
|
|
int inC = inShape[1];
|
|
|
|
|
int inH = inShape[2];
|
|
|
|
|
int inW = inShape[3];
|
|
|
|
|
|
|
|
|
|
int outC = crop.shape[0];
|
|
|
|
|
int outH = crop.shape[1];
|
|
|
|
|
int outW = crop.shape[2];
|
|
|
|
|
int outC = crop_shape[1];
|
|
|
|
|
int outH = crop_shape[2];
|
|
|
|
|
int outW = crop_shape[3];
|
|
|
|
|
|
|
|
|
|
for (int n = 0; n < num; n++) {
|
|
|
|
|
for (int c = 0; c < outC; c++) {
|
|
|
|
@ -55,19 +54,23 @@ template <>
|
|
|
|
|
void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
|
|
|
|
|
real* outGrad,
|
|
|
|
|
const TensorShape outShape,
|
|
|
|
|
const CropConf& crop) {
|
|
|
|
|
int cCrop = crop.corner[0];
|
|
|
|
|
int hCrop = crop.corner[1];
|
|
|
|
|
int wCrop = crop.corner[2];
|
|
|
|
|
const FuncConfig& conf) {
|
|
|
|
|
std::vector<uint32_t> crop_corner =
|
|
|
|
|
conf.get<std::vector<uint32_t>>("crop_corner");
|
|
|
|
|
std::vector<uint32_t> crop_shape =
|
|
|
|
|
conf.get<std::vector<uint32_t>>("crop_shape");
|
|
|
|
|
int cCrop = crop_corner[1];
|
|
|
|
|
int hCrop = crop_corner[2];
|
|
|
|
|
int wCrop = crop_corner[3];
|
|
|
|
|
|
|
|
|
|
int num = outShape[0];
|
|
|
|
|
int outC = outShape[1];
|
|
|
|
|
int outH = outShape[2];
|
|
|
|
|
int outW = outShape[3];
|
|
|
|
|
|
|
|
|
|
int inC = crop.shape[0];
|
|
|
|
|
int inH = crop.shape[1];
|
|
|
|
|
int inW = crop.shape[2];
|
|
|
|
|
int inC = crop_shape[1];
|
|
|
|
|
int inH = crop_shape[2];
|
|
|
|
|
int inW = crop_shape[3];
|
|
|
|
|
|
|
|
|
|
for (int n = 0; n < num; n++) {
|
|
|
|
|
for (int c = 0; c < inC; c++) {
|
|
|
|
@ -111,26 +114,21 @@ void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
|
|
|
|
|
template <DeviceType Device>
|
|
|
|
|
class CropFunc : public FunctionBase {
|
|
|
|
|
public:
|
|
|
|
|
void init(const FuncConfig& config) override {
|
|
|
|
|
crop_ = castToCropConf(config);
|
|
|
|
|
}
|
|
|
|
|
void init(const FuncConfig& config) override { conf_ = config; }
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ(1UL, inputs.size());
|
|
|
|
|
CHECK_EQ(1UL, outputs.size());
|
|
|
|
|
CHECK_EQ(outputs[0].shape()[1], crop_.shape[0]);
|
|
|
|
|
CHECK_EQ(outputs[0].shape()[2], crop_.shape[1]);
|
|
|
|
|
CHECK_EQ(outputs[0].shape()[3], crop_.shape[2]);
|
|
|
|
|
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
|
|
|
|
|
|
|
|
|
TensorShape inShape = inputs[0].shape();
|
|
|
|
|
|
|
|
|
|
Crop<Device>(
|
|
|
|
|
outputs[0].data<real>(), inputs[0].data<real>(), inShape, crop_);
|
|
|
|
|
outputs[0].data<real>(), inputs[0].data<real>(), inShape, conf_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
CropConf crop_;
|
|
|
|
|
FuncConfig conf_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -145,26 +143,21 @@ private:
|
|
|
|
|
template <DeviceType Device>
|
|
|
|
|
class CropGradFunc : public FunctionBase {
|
|
|
|
|
public:
|
|
|
|
|
void init(const FuncConfig& config) override {
|
|
|
|
|
crop_ = castToCropConf(config);
|
|
|
|
|
}
|
|
|
|
|
void init(const FuncConfig& config) override { conf_ = config; }
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ(1UL, inputs.size());
|
|
|
|
|
CHECK_EQ(1UL, outputs.size());
|
|
|
|
|
CHECK_EQ(inputs[0].shape()[1], crop_.shape[0]);
|
|
|
|
|
CHECK_EQ(inputs[0].shape()[2], crop_.shape[1]);
|
|
|
|
|
CHECK_EQ(inputs[0].shape()[3], crop_.shape[2]);
|
|
|
|
|
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
|
|
|
|
|
|
|
|
|
TensorShape outShape = outputs[0].shape();
|
|
|
|
|
|
|
|
|
|
CropGrad<Device>(
|
|
|
|
|
inputs[0].data<real>(), outputs[0].data<real>(), outShape, crop_);
|
|
|
|
|
inputs[0].data<real>(), outputs[0].data<real>(), outShape, conf_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
CropConf crop_;
|
|
|
|
|
FuncConfig conf_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
|
|
|
|
|