refine test_recognize_digits_mlp and format codes (#5937)

Branch: release/0.11.0
Authored by QI JUN, committed by Yu Yang
Parent commit: d89ff5b614
This commit: b28b2f172b

@@ -55,7 +55,7 @@ paddle_error paddle_matrix_set_row(paddle_matrix mat,
}

PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
                                            paddle_real* value) {
  if (mat == nullptr || value == nullptr) return kPD_NULLPTR;
  auto ptr = cast(mat);
  if (ptr->mat == nullptr) return kPD_NULLPTR;
@@ -75,7 +75,7 @@ PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
}

PD_API paddle_error paddle_matrix_get_value(paddle_matrix mat,
                                            paddle_real* result) {
  if (mat == nullptr || result == nullptr) return kPD_NULLPTR;
  auto ptr = cast(mat);
  if (ptr->mat == nullptr) return kPD_NULLPTR;

@@ -79,7 +79,7 @@ PD_API paddle_error paddle_matrix_set_row(paddle_matrix mat,
 * @note value should contain enough element of data to init the mat
 */
PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
                                            paddle_real* value);

/**
 * @brief PDMatGetRow Get raw row buffer from matrix
@@ -93,14 +93,14 @@ PD_API paddle_error paddle_matrix_get_row(paddle_matrix mat,
                                          paddle_real** rawRowBuffer);

/**
 * @brief copy data from the matrix
 * @param [in] mat Target matrix
 * @param [out] result pointer to store the matrix data
 * @return paddle_error
 * @note the space of the result should allocated before invoke this API
 */
PD_API paddle_error paddle_matrix_get_value(paddle_matrix mat,
                                            paddle_real* result);
/**
 * @brief PDMatCreateNone Create None Matrix
 * @return

@@ -135,18 +135,17 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
  auto dst_ptr = static_cast<void*>(dst->data());

  if (platform::is_cpu_place(src.place())) {
    memory::Copy(dst_place, dst_ptr,
                 boost::get<platform::CPUPlace>(src.place()), src_ptr, size);
  }
#ifdef PADDLE_WITH_CUDA
  else if (platform::is_gpu_place(src.place())) {  // NOLINT
    memory::Copy(
        dst_place, dst_ptr, boost::get<platform::GPUPlace>(src.place()),
        src_ptr, size,
        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
  }
#endif
}

}  // namespace framework

@@ -23,8 +23,7 @@ template <typename T>
class MaxOutFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
@@ -37,34 +36,30 @@ class MaxOutFunctor<platform::CPUPlace, T> {
    T* output_data = output->mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; ++i) {
      int new_bindex = c_size * i;
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          T ele = static_cast<T>(-FLT_MAX);
          for (int ph = 0; ph < groups; ++ph) {
            T x = input_data[(new_bindex + new_cindex) * groups +
                             ph * fea_size + f];
            ele = ele > x ? ele : x;
          }
          output_data[(new_bindex + new_cindex + f)] = ele;
        }
      }
    }
  }
};

template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
@@ -84,11 +79,11 @@ public:
          bool continue_match = true;
          int output_idx = blen + clen + f;
          for (int g = 0; g < groups && continue_match; ++g) {
            int input_idx = input_idx0 + fea_size * g;
            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
              continue_match = false;
            }
          }
        }
      }
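As background for the two functors reformatted above: maxout over an NCHW tensor keeps, for each output channel c, the element-wise maximum of the `groups` consecutive input channels it covers, and the gradient is routed back only to the first input element that attained that maximum (the continue_match loop). A minimal NumPy sketch of that behaviour follows; it is illustrative only, and the function names are not Paddle APIs.

import numpy as np

def maxout_forward(x, groups):
    # x: (N, C, H, W); output channel c is the max over input
    # channels [c * groups, (c + 1) * groups).
    n, c, h, w = x.shape
    return x.reshape(n, c // groups, groups, h, w).max(axis=2)

def maxout_backward(x, out, dout, groups):
    # Route each output gradient to the first input element that produced
    # the maximum, mirroring the continue_match loop above.
    n, c, h, w = x.shape
    xg = x.reshape(n, c // groups, groups, h, w)
    winners = xg == out[:, :, None, :, :]
    first = winners & (np.cumsum(winners, axis=2) == 1)
    dx = first * dout[:, :, None, :, :]
    return dx.reshape(n, c, h, w)

x = np.random.randn(2, 6, 2, 2).astype("float32")
out = maxout_forward(x, groups=2)                       # shape (2, 3, 2, 2)
dx = maxout_backward(x, out, np.ones_like(out), groups=2)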

@@ -21,9 +21,9 @@ namespace math {

template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, int groups,
                             T* output_data) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -34,7 +34,7 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    T ele = static_cast<T>(-FLT_MAX);
    for (int g = 0; g < groups; ++g) {
      T x = input_data[data_idx + g * feat_len];
@@ -44,34 +44,35 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
  }
}

template <typename T>
__global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
                                 const T* output_data, const T* output_grad,
                                 T* input_grad, const int channels,
                                 const int input_height, const int input_width,
                                 int groups) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    int max_index = -1;
    bool continue_match = true;
    for (int g = 0; g < groups && continue_match; ++g) {
      if (input_data[data_idx + g * feat_len] == output_data[i]) {
        max_index = data_idx + g * feat_len;
        continue_match = false;
        break;
      }
    }
    if (max_index != -1) {
      input_grad[max_index] += output_grad[index];
    }
  }
}

/*
 * All tensors are in NCHW format.
@@ -80,7 +81,7 @@ template <typename T>
class MaxOutFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
@@ -92,7 +93,7 @@ class MaxOutFunctor<platform::GPUPlace, T> {
    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int nthreads = output->numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
@@ -101,8 +102,7 @@ class MaxOutFunctor<platform::GPUPlace, T> {
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, input_channels,
                              input_height, input_width, groups, output_data);
  }
};

/*
@@ -112,11 +112,9 @@ template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
@@ -129,7 +127,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int nthreads = output.numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
@@ -137,9 +135,9 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
    KernelMaxoutGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, output_data,
                              output_grad_data, input_grad_data, input_channels,
                              input_height, input_width, groups);
  }
};
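The GPU kernels above walk the flattened output with a grid-stride loop and rebuild the input offset from the flat index; the candidates for output element i live at data_idx + g * feat_len. A small NumPy check of that index arithmetic, written under the same NCHW layout assumption (illustrative, not Paddle code):

import numpy as np

n, c, h, w, groups = 2, 6, 3, 3, 2
x = np.random.randn(n, c, h, w).astype("float32")
x_flat = x.ravel()                        # NCHW, C-contiguous
out = x.reshape(n, c // groups, groups, h, w).max(axis=2)

feat_len = h * w
size = feat_len * c // groups             # per-sample output size
for i in range(out.size):
    batch_idx, batch_offset = divmod(i, size)
    channel_idx, feat_idx = divmod(batch_offset, feat_len)
    data_idx = (batch_idx * size + channel_idx * feat_len) * groups + feat_idx
    candidates = [x_flat[data_idx + g * feat_len] for g in range(groups)]
    assert max(candidates) == out.ravel()[i]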

@@ -21,15 +21,14 @@ namespace paddle {
namespace operators {
namespace math {

#define FLT_MAX __FLT_MAX__

template <typename Place, typename T>
class MaxOutFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups);
};

@@ -37,8 +36,7 @@ template <typename Place, class T>
class MaxOutGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups);
};

@@ -22,16 +22,17 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(Tensor) The input tensor of maxout operator. "
        "The format of input tensor is NCHW. Where N is batch size, C is the "
        "number of channels, H and W is the height and width of feature.");
    AddOutput("Out",
              "(Tensor) The output tensor of maxout operator."
              "The format of output tensor is also NCHW."
              "Where N is batch size, C is "
              "the number of channels, H and W is the height and "
              "width of feature.");
    AddAttr<int>(
        "groups",
        R"DOC("Specifies how many groups the input tensor will be split"
@@ -59,21 +60,19 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
  }
};

class MaxOutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of MaxoutOp"
                   "should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of MaxoutOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    int groups = ctx->Attrs().Get<int>("groups");
    // check groups > 1
    PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop");
    std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
    output_shape.push_back(in_x_dims[2]);
    output_shape.push_back(in_x_dims[3]);
@@ -87,18 +86,17 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Input(X@GRAD) should not be null.");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad,
            ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(maxout,
                       ops::MaxOutKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    maxout_grad, ops::MaxOutGradKernel<paddle::platform::CPUPlace, float>);

@@ -18,8 +18,6 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(maxout,
                       ops::MaxOutKernel<paddle::platform::GPUPlace, float>,
                       ops::MaxOutKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
    maxout_grad, ops::MaxOutGradKernel<paddle::platform::GPUPlace, float>,
    ops::MaxOutGradKernel<paddle::platform::GPUPlace, double>);

@@ -53,7 +53,7 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
      zero(device_ctx, in_x_grad, static_cast<T>(0.0));
      math::MaxOutGradFunctor<Place, T> maxout_backward;
      maxout_backward(context.device_context(), *in_x, in_x_grad, *out,
                      *out_grad, groups);
    }
  }
};

@@ -43,8 +43,8 @@ class ROIPoolOp : public framework::OperatorWithKernel {
                   "ROIs should be a 2-D tensor of shape (num_rois, 5)"
                   "given as [[batch_id, x1, y1, x2, y2], …].");
    PADDLE_ENFORCE(rois_dims[1] == kROISize,
                   "ROIs should be a 2-D tensor of shape (num_rois, 5)"
                   "given as [[batch_id, x1, y1, x2, y2], …].");

    int pooled_height = ctx->Attrs().Get<int>("pooled_height");
    int pooled_width = ctx->Attrs().Get<int>("pooled_width");
@@ -65,7 +65,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
    ctx->SetOutputDim("Out", out_dims);
    ctx->SetOutputDim("Argmax", out_dims);
  }

 protected:
  framework::OpKernelType GetKernelType(
@@ -100,7 +100,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ROIPoolOpMaker(framework::OpProto* proto,
                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor), "
@@ -125,21 +125,22 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor), "
              "Argmaxes corresponding to indices in X used "
              "for gradient computation. Only output "
              "if arg “is_test” is false.")
        .AsIntermediate();
    AddAttr<float>("spatial_scale",
                   "(float, default 1.0), "
                   "Multiplicative spatial scale factor "
                   "to translate ROI coords from their input scale "
                   "to the scale used when pooling.")
        .SetDefault(1.0);
    AddAttr<int>("pooled_height",
                 "(int, default 1), "
                 "The pooled output height.")
        .SetDefault(1);
    AddAttr<int>("pooled_width",
                 "(int, default 1), "
                 "The pooled output width.")
        .SetDefault(1);

    AddComment(R"DOC(
ROIPool operator
@@ -153,11 +154,10 @@ https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, roi_pool_grad,
            ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
    roi_pool, ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
    ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
    roi_pool_grad,

(File diff suppressed because it is too large.)

@@ -136,8 +136,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");

@@ -45,7 +45,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
    // Initialize the output's dims to maximum,
    // and re-set to real dims by the value of Offset and Length at kernel
    ctx->SetOutputDim("Out", input_dims);
  }

 protected:
  framework::OpKernelType GetKernelType(
@@ -93,8 +93,7 @@ class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
             "(Tensor), "
             "a vector<int> to describe the length of every input sequence for "
             "sub sequence item.");
    AddOutput("Out", "(LoDTensor), the output of SequenceSliceOp.");
    AddComment(R"DOC(
Sequence slice operator

@@ -38,6 +38,7 @@ UCI_TEST_DATA = None
URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fit_a_line.tar'
MD5_MODEL = '52fc3da8ef3937822fcdd87ee05c0c9b'


def feature_range(maximums, minimums):
    import matplotlib
    matplotlib.use('Agg')
@@ -114,7 +115,8 @@ def test():

def model():
    tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar',
                                                 MD5_MODEL)
    with open(tar_file, 'r') as f:
        parameters = Parameters.from_tar(f)
    return parameters

@@ -35,6 +35,13 @@ opts = optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)

+inference_program = fluid.default_main_program().clone()
+test_accuracy = fluid.evaluator.Accuracy(
+    input=predict, label=label, main_program=inference_program)
+test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
+inference_program = fluid.io.get_inference_program(
+    test_target, main_program=inference_program)
+
train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.mnist.train(), buf_size=8192),
@@ -69,11 +76,6 @@ for pass_id in range(PASS_NUM):
        acc = np.array(outs[1])
        pass_acc = accuracy.eval(exe)

-        test_accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
-        test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
-        inference_program = fluid.io.get_inference_program(test_target)
        test_accuracy.reset(exe)
        for data in test_reader():
            x_data = np.array(map(lambda x: x[0], data)).astype("float32")
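These two hunks are the test_recognize_digits_mlp refinement named in the commit title: the evaluation program is now cloned and pruned once before the pass loop, and the per-iteration construction of test_accuracy and inference_program is removed from the loop body. A rough sketch of the resulting flow, assuming a standard train-then-test structure (only the construction order is taken from the hunks above; the elided parts are assumptions):

inference_program = fluid.default_main_program().clone()
test_accuracy = fluid.evaluator.Accuracy(
    input=predict, label=label, main_program=inference_program)
test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
inference_program = fluid.io.get_inference_program(
    test_target, main_program=inference_program)

for pass_id in range(PASS_NUM):
    # ... train over train_reader(), tracking acc / pass_acc ...
    test_accuracy.reset(exe)
    for data in test_reader():
        # feed the batch and run inference_program with the executor
        ...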

@@ -30,9 +30,7 @@ class TestMaxOutOp(OpTest):
    def init_test_case(self):
        self.MaxOut_forward_naive = maxout_forward_naive
        self.shape = [100, 6, 2, 2]
        self.groups = 2


if __name__ == '__main__':

@@ -4,24 +4,22 @@ import math
import sys
from op_test import OpTest


class TestROIPoolOp(OpTest):
    def set_data(self):
        self.init_test_case()
        self.make_rois()
        self.calc_roi_pool()

        self.inputs = {'X': self.x, 'ROIs': self.rois}

        self.attrs = {
            'spatial_scale': self.spatial_scale,
            'pooled_height': self.pooled_height,
            'pooled_width': self.pooled_width
        }

        self.outputs = {'Out': self.outs, 'Argmax': self.argmaxes}

    def init_test_case(self):
        self.batch_size = 5
@@ -30,10 +28,9 @@ class TestROIPoolOp(OpTest):
        self.width = 4

        # n, c, h, w
        self.x_dim = (self.batch_size, self.channels, self.height, self.width)

        self.spatial_scale = 1.0 / 4.0
        self.pooled_height = 2
        self.pooled_width = 2
        self.rois_num = 2
@@ -41,13 +38,11 @@ class TestROIPoolOp(OpTest):
        self.x = np.random.random(self.x_dim).astype('float32')

    def calc_roi_pool(self):
        out_data = np.zeros((self.rois_num, self.channels, self.pooled_height,
                             self.pooled_width))
        argmax_data = np.zeros((self.rois_num, self.channels,
                                self.pooled_height, self.pooled_width))

        for i in range(self.rois_num):
            roi = self.rois[i]
            roi_batch_id = roi[0]
@@ -56,8 +51,8 @@ class TestROIPoolOp(OpTest):
            roi_end_w = int(round(roi[3] * self.spatial_scale))
            roi_end_h = int(round(roi[4] * self.spatial_scale))

            roi_height = int(max(roi_end_h - roi_start_h + 1, 1))
            roi_width = int(max(roi_end_w - roi_start_w + 1, 1))

            x_i = self.x[roi_batch_id]
@@ -84,7 +79,7 @@ class TestROIPoolOp(OpTest):
                        out_data[i, c, ph, pw] = -sys.float_info.max
                        argmax_data[i, c, ph, pw] = -1

                        for h in range(hstart, hend):
                            for w in range(wstart, wend):
                                if x_i[c, h, w] > out_data[i, c, ph, pw]:
@@ -104,11 +99,11 @@ class TestROIPoolOp(OpTest):
            y1 = np.random.random_integers(
                0, self.height / self.spatial_scale - self.pooled_height)

            x2 = np.random.random_integers(x1 + self.pooled_width,
                                           self.width / self.spatial_scale)
            y2 = np.random.random_integers(y1 + self.pooled_height,
                                           self.height / self.spatial_scale)

            roi = [batch_ids[i], x1, y1, x2, y2]
            rois.append(roi)
        self.rois = np.array(rois).astype("int64")
@@ -123,5 +118,6 @@ class TestROIPoolOp(OpTest):
    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


if __name__ == '__main__':
    unittest.main()
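For context on what calc_roi_pool() above computes: ROI max pooling scales each ROI into feature-map coordinates, splits it into a pooled_height x pooled_width grid of bins, and takes the per-channel maximum inside each bin. A condensed NumPy sketch for a single ROI follows; it is illustrative only, and the argmax bookkeeping used by the test for the gradient check is omitted.

import numpy as np

def roi_pool_single(x_i, roi, spatial_scale, pooled_h, pooled_w):
    # x_i: (C, H, W) feature map of the ROI's batch item;
    # roi: [batch_id, x1, y1, x2, y2] in input-image coordinates.
    _, x1, y1, x2, y2 = roi
    roi_start_w = int(round(x1 * spatial_scale))
    roi_start_h = int(round(y1 * spatial_scale))
    roi_end_w = int(round(x2 * spatial_scale))
    roi_end_h = int(round(y2 * spatial_scale))
    roi_h = max(roi_end_h - roi_start_h + 1, 1)
    roi_w = max(roi_end_w - roi_start_w + 1, 1)
    bin_h, bin_w = roi_h / float(pooled_h), roi_w / float(pooled_w)

    c, height, width = x_i.shape
    out = np.full((c, pooled_h, pooled_w), -np.inf, dtype=x_i.dtype)
    for ph in range(pooled_h):
        for pw in range(pooled_w):
            # Clip each bin to the feature map, then take the per-channel max.
            hstart = min(max(int(np.floor(ph * bin_h)) + roi_start_h, 0), height)
            hend = min(max(int(np.ceil((ph + 1) * bin_h)) + roi_start_h, 0), height)
            wstart = min(max(int(np.floor(pw * bin_w)) + roi_start_w, 0), width)
            wend = min(max(int(np.ceil((pw + 1) * bin_w)) + roi_start_w, 0), width)
            if hend > hstart and wend > wstart:
                out[:, ph, pw] = x_i[:, hstart:hend, wstart:wend].max(axis=(1, 2))
    return out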
