Merge remote-tracking branch 'ups/develop' into refine/op/lstm

update-install-command
tensor-tang 7 years ago
commit 3db1e41e12

@ -305,9 +305,9 @@ paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'neg
paddle.fluid.layers.detection_output ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0))
paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None))
paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'anchor_var', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3))
paddle.fluid.layers.rpn_target_assign ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True))
paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'gt_boxes', 'im_scales', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None))
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True))
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)

@ -120,13 +120,20 @@ void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map,
outputs.insert(node);
}
// update the dst and src node's inlinks and outlinks.
// update the dst and src node's inlinks and outlinks.
#ifdef __clang__
src_node->inlinks = std::vector<BriefNode *>(inputs.begin(), inputs.end());
src_node->outlinks = std::vector<BriefNode *>(outputs.begin(), outputs.end());
dst_node->inlinks.clear();
dst_node->outlinks.clear();
#else
src_node->inlinks =
std::move(std::vector<BriefNode *>(inputs.begin(), inputs.end()));
src_node->outlinks =
std::move(std::vector<BriefNode *>(outputs.begin(), outputs.end()));
dst_node->inlinks.clear();
dst_node->outlinks.clear();
#endif
auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
for (auto *&n : nodes) {

@ -77,6 +77,9 @@ bool AnalysisPredictor::Init(
OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_._use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get());

@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
@ -64,13 +64,15 @@ PaddleBuf& PaddleBuf::operator=(PaddleBuf&& other) {
void PaddleBuf::Resize(size_t length) {
// Only the owned memory can be reset, the external memory can't be changed.
if (length_ == length) return;
if (length_ >= length) return;
if (memory_owned_) {
Free();
data_ = malloc(length);
length_ = length;
memory_owned_ = true;
} else {
PADDLE_THROW("The memory is allocated externally, can not Resized");
}
data_ = new char[length];
length_ = length;
memory_owned_ = true;
}
void PaddleBuf::Reset(void* data, size_t length) {
@ -82,8 +84,8 @@ void PaddleBuf::Reset(void* data, size_t length) {
void PaddleBuf::Free() {
if (memory_owned_ && data_) {
assert(length_ > 0);
delete[] static_cast<char*>(data_);
PADDLE_ENFORCE_GT(length_, 0);
free(static_cast<char*>(data_));
data_ = nullptr;
length_ = 0;
}

@ -106,6 +106,9 @@ bool NativePaddlePredictor::Init(
}
ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_._use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0);

@ -45,7 +45,7 @@ class PaddleBuf {
PaddleBuf(void* data, size_t length)
: data_(data), length_(length), memory_owned_{false} {}
// Own memory.
PaddleBuf(size_t length)
explicit PaddleBuf(size_t length)
: data_(new char[length]), length_(length), memory_owned_(true) {}
// Resize to `length` bytes.
void Resize(size_t length);
@ -121,6 +121,8 @@ struct NativeConfig : public PaddlePredictor::Config {
bool use_gpu{false};
int device{0};
float fraction_of_gpu_memory{-1.f}; // Negative to notify initialization.
// NOTE: NOT use it, just for the internal test, will discard later
bool _use_mkldnn{false};
// Specify the variable's name of each input.
bool specify_input_name{false};

@ -53,5 +53,21 @@ set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classifi
download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz")
inference_analysis_test(test_analyzer_text_classification SRCS analyzer_text_classification_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta
ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/model
--infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt)
# ocr
set(OCR_MODEL_URL "http://paddlemodels.cdn.bcebos.com/inference-vis-demos%2Focr.tar.gz")
set(OCR_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR} AND WITH_INFERENCE)
get_filename_component(filename ${OCR_MODEL_URL} NAME)
message(STATUS "Download inference test stuff ${filename} from ${OCR_MODEL_URL}")
execute_process(COMMAND bash -c "mkdir -p ${OCR_INSTALL_DIR}")
execute_process(COMMAND bash -c "cd ${OCR_INSTALL_DIR} && wget -q ${OCR_MODEL_URL}")
execute_process(COMMAND bash -c "cd ${OCR_INSTALL_DIR} && tar xzf ${filename}")
message(STATUS "finish downloading ${filename}")
endif()
inference_analysis_test(test_analyzer_ocr SRCS analyzer_vis_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${OCR_INSTALL_DIR}/model
--infer_data=${OCR_INSTALL_DIR}/data.txt)

@ -110,8 +110,7 @@ const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25,
void TestLACPrediction(const std::string &model_path,
const std::string &data_file, const int batch_size,
const int repeat, bool test_all_data,
bool use_analysis = false) {
const int repeat, bool use_analysis = false) {
AnalysisConfig cfg;
cfg.model_dir = model_path;
cfg.use_gpu = false;
@ -199,13 +198,13 @@ void TestLACPrediction(const std::string &model_path,
TEST(Analyzer_LAC, native) {
LOG(INFO) << "LAC with native";
TestLACPrediction(FLAGS_infer_model, FLAGS_infer_data, FLAGS_batch_size,
FLAGS_repeat, FLAGS_test_all_data);
FLAGS_repeat);
}
TEST(Analyzer_LAC, analysis) {
LOG(INFO) << "LAC with analysis";
TestLACPrediction(FLAGS_infer_model, FLAGS_infer_data, FLAGS_batch_size,
FLAGS_repeat, FLAGS_test_all_data, true);
FLAGS_repeat, true);
}
} // namespace analysis

@ -0,0 +1,133 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include <iostream>
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
struct Record {
std::vector<float> data;
std::vector<int32_t> shape;
};
Record ProcessALine(const std::string &line) {
VLOG(3) << "process a line";
std::vector<std::string> columns;
split(line, '\t', &columns);
CHECK_EQ(columns.size(), 2UL)
<< "data format error, should be <data>\t<shape>";
Record record;
std::vector<std::string> data_strs;
split(columns[0], ' ', &data_strs);
for (auto &d : data_strs) {
record.data.push_back(std::stof(d));
}
std::vector<std::string> shape_strs;
split(columns[1], ' ', &shape_strs);
for (auto &s : shape_strs) {
record.shape.push_back(std::stoi(s));
}
VLOG(3) << "data size " << record.data.size();
VLOG(3) << "data shape size " << record.shape.size();
return record;
}
/*
* Use the native and analysis fluid engine to inference the demo.
* ocr, mobilenet and se_resnext50
*/
void TestVisualPrediction(bool use_mkldnn) {
std::unique_ptr<PaddlePredictor> predictor;
AnalysisConfig cfg;
cfg.param_file = FLAGS_infer_model + "/__params__";
cfg.prog_file = FLAGS_infer_model + "/__model__";
cfg.use_gpu = false;
cfg._use_mkldnn = use_mkldnn;
cfg.device = 0;
cfg.enable_ir_optim = true;
// TODO(TJ): fix fusion gru
cfg.ir_passes.push_back("fc_gru_fuse_pass");
#ifdef PADDLE_WITH_MKLDNN
// disable mkldnn fuse since it should have some bugs
cfg.ir_passes.push_back("conv_relu_mkldnn_fuse_pass");
#endif
predictor =
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(cfg);
// Only have single batch of data.
std::string line;
std::ifstream file(FLAGS_infer_data);
std::getline(file, line);
auto record = ProcessALine(line);
file.close();
// Inference.
PaddleTensor input;
input.shape = record.shape;
input.data =
PaddleBuf(record.data.data(), record.data.size() * sizeof(float));
input.dtype = PaddleDType::FLOAT32;
std::vector<PaddleTensor> outputs_slots;
Timer timer;
timer.tic();
for (int i = 0; i < FLAGS_repeat; i++) {
predictor->Run({input}, &outputs_slots);
}
PrintTime(/*batch size*/ 1, FLAGS_repeat, /*num threads*/ 1, /*thread id*/ 0,
timer.toc() / FLAGS_repeat);
VLOG(3) << "output.size " << outputs_slots.size();
// run native as reference
auto ref_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(cfg);
std::vector<PaddleTensor> ref_outputs_slots;
ref_predictor->Run({input}, &ref_outputs_slots);
CompareResult(outputs_slots, ref_outputs_slots);
// print what are fused
AnalysisPredictor *analysis_predictor =
dynamic_cast<AnalysisPredictor *>(predictor.get());
auto &fuse_statis = analysis_predictor->analysis_argument()
.Get<std::unordered_map<std::string, int>>(
framework::ir::kFuseStatisAttr);
for (auto &item : fuse_statis) {
LOG(INFO) << "fused " << item.first << " " << item.second;
}
int num_ops = 0;
for (auto &node :
analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
if (node->IsFunction()) {
++num_ops;
}
}
LOG(INFO) << "has num ops: " << num_ops;
}
TEST(Analyzer_vis, analysis) { TestVisualPrediction(/*use_mkldnn*/ false); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_vis, analysis_mkldnn) {
TestVisualPrediction(/*use_mkldnn*/ true);
}
#endif
} // namespace analysis
} // namespace inference
} // namespace paddle

@ -37,22 +37,37 @@ namespace paddle {
namespace inference {
void CompareResult(const std::vector<PaddleTensor> &outputs,
const std::vector<PaddleTensor> &base_outputs) {
PADDLE_ENFORCE_GT(outputs.size(), 0);
PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
const std::vector<PaddleTensor> &ref_outputs) {
EXPECT_GT(outputs.size(), 0);
EXPECT_EQ(outputs.size(), ref_outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
auto &out = outputs[i];
auto &base_out = base_outputs[i];
auto &ref_out = ref_outputs[i];
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
[](int a, int b) { return a * b; });
size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
1, [](int a, int b) { return a * b; });
PADDLE_ENFORCE_EQ(size, size1);
PADDLE_ENFORCE_GT(size, 0);
float *data = static_cast<float *>(out.data.data());
float *base_data = static_cast<float *>(base_out.data.data());
for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(data[i], base_data[i], 1e-3);
size_t ref_size =
std::accumulate(ref_out.shape.begin(), ref_out.shape.end(), 1,
[](int a, int b) { return a * b; });
EXPECT_GT(size, 0);
EXPECT_EQ(size, ref_size);
EXPECT_EQ(out.dtype, ref_out.dtype);
switch (out.dtype) {
case PaddleDType::INT64: {
int64_t *pdata = static_cast<int64_t *>(out.data.data());
int64_t *pdata_ref = static_cast<int64_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::FLOAT32: {
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = static_cast<float *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_NEAR(pdata_ref[j], pdata[j], 1e-3);
}
break;
}
}
}
}

@ -300,6 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
int groups = ctx.Attr<int>("groups");
// TODO: add support for dilation
@ -366,12 +367,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_eltwise);
} else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu);
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_eltwise);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
@ -421,16 +423,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
private:
mkldnn::primitive_attr AddRelu() const {
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
bool fuse_eltwise) const {
mkldnn::primitive_attr conv_attr;
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
mkldnn::post_ops post_operations;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_eltwise is true, the
// Output tensor contains the data coming from residual connection. The
// result of this post_op is: Output = scale * Output + Conv_Out.
if (fuse_eltwise) {
post_operations.append_sum(1.0f);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
@ -439,8 +451,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine,
const bool fuse_relu) const {
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_eltwise) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
@ -449,10 +461,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr;
if (fuse_relu) {
conv_attr = AddRelu();
}
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
@ -466,8 +475,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine,
const bool fuse_relu) const {
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_eltwise) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
@ -476,10 +485,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr;
if (fuse_relu) {
conv_attr = AddRelu();
}
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);

@ -164,6 +164,11 @@ void Conv2DOpMaker::Make() {
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_eltwise",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is connected via skip connection "
"to a previous layer.")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "

@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
@ -21,7 +22,7 @@ namespace operators {
*/
template <typename T>
inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes,
const framework::Tensor& gt_boxes, const T* weights,
const framework::Tensor& gt_boxes, const float* weights,
const bool normalized, framework::Tensor* box_delta) {
auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
@ -62,5 +63,35 @@ void Gather(const T* in, const int in_stride, const int* index, const int num,
}
}
template <typename T>
void BboxOverlaps(const framework::Tensor& r_boxes,
const framework::Tensor& c_boxes,
framework::Tensor* overlaps) {
auto r_boxes_et = framework::EigenTensor<T, 2>::From(r_boxes);
auto c_boxes_et = framework::EigenTensor<T, 2>::From(c_boxes);
auto overlaps_et = framework::EigenTensor<T, 2>::From(*overlaps);
int r_num = r_boxes.dims()[0];
int c_num = c_boxes.dims()[0];
auto zero = static_cast<T>(0.0);
T r_box_area, c_box_area, x_min, y_min, x_max, y_max, inter_w, inter_h,
inter_area;
for (int i = 0; i < r_num; ++i) {
r_box_area = (r_boxes_et(i, 2) - r_boxes_et(i, 0) + 1) *
(r_boxes_et(i, 3) - r_boxes_et(i, 1) + 1);
for (int j = 0; j < c_num; ++j) {
c_box_area = (c_boxes_et(j, 2) - c_boxes_et(j, 0) + 1) *
(c_boxes_et(j, 3) - c_boxes_et(j, 1) + 1);
x_min = std::max(r_boxes_et(i, 0), c_boxes_et(j, 0));
y_min = std::max(r_boxes_et(i, 1), c_boxes_et(j, 1));
x_max = std::min(r_boxes_et(i, 2), c_boxes_et(j, 2));
y_max = std::min(r_boxes_et(i, 3), c_boxes_et(j, 3));
inter_w = std::max(x_max - x_min + 1, zero);
inter_h = std::max(y_max - y_min + 1, zero);
inter_area = inter_w * inter_h;
overlaps_et(i, j) = inter_area / (r_box_area + c_box_area - inter_area);
}
}
}
} // namespace operators
} // namespace paddle

@ -89,12 +89,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
}
for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len];
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1];
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
T anchor_center_x = (anchor_data[i * len + 2] + anchor_data[i * len]) / 2;
T anchor_center_y =
(anchor_data[i * len + 3] + anchor_data[i * len + 1]) / 2;
T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0;
@ -106,25 +105,31 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2]) *
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
anchor_width;
bbox_height = std::exp(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3]) *
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
anchor_height;
} else {
bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width;
bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
anchor_height;
}
proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
}
// return proposals;
}
@ -156,18 +161,23 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
float min_size, const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
min_size *= im_info_data[2];
T im_scale = im_info_data[2];
keep->Resize({boxes->dims()[0], 1});
min_size = std::max(min_size, 1.0f);
int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T ws_origin_scale =
(boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_scale + 1;
T hs_origin_scale =
(boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_scale + 1;
T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] &&
y_ctr <= im_info_data[0]) {
if (ws_origin_scale >= min_size && hs_origin_scale >= min_size &&
x_ctr <= im_info_data[1] && y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i;
}
}
@ -218,8 +228,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin;
const T inter_h = inter_ymax - inter_ymin;
const T inter_w = std::max(0.0f, inter_xmax - inter_xmin + 1);
const T inter_h = std::max(0.0f, inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);

File diff suppressed because it is too large Load Diff

@ -82,8 +82,10 @@ class ProtoEncodeHelper {
: base_(buf), p_(buf), limit_(base_ + max_size) {}
~ProtoEncodeHelper() {
#define REPLACE_ENFORCE_GLOG 1
// Make sure callers didn't do operations that went over max_size promised
PADDLE_ENFORCE_LE(p_, limit_);
paddle::platform::throw_on_error(p_ <= limit_);
#undef REPLACE_ENFORCE_GLOG
}
const char* data() const { return base_; }

@ -59,17 +59,16 @@ static void ParallelExecuteBlocks(
framework::ProgramDesc *program, framework::Scope *scope) {
std::vector<std::future<void>> fs;
for (size_t idx : parallel_blkids) {
fs.push_back(
framework::Async([&executor, &prepared, &program, &scope, idx]() {
int run_block = idx; // thread local
try {
VLOG(3) << "running server block: " << run_block
<< "pointer: " << prepared[run_block].get();
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
fs.push_back(framework::Async([&executor, &prepared, &scope, idx]() {
int run_block = idx; // thread local
try {
VLOG(3) << "running server block: " << run_block
<< "pointer: " << prepared[run_block].get();
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}

@ -26,10 +26,13 @@ class PReluOp : public framework::OperatorWithKernel {
std::string mode = ctx->Attrs().Get<std::string>("mode");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of PreluOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("Alpha"),
"Input(Alpha) of PreluOp should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PreluOp should not be null");
if (mode == "all") {
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
"For mode 'all', size of weight Alpha must be one.");

@ -55,15 +55,19 @@ for _OP in set(__auto__):
globals()[_OP] = generate_layer_fn(_OP)
def rpn_target_assign(loc,
scores,
def rpn_target_assign(bbox_pred,
cls_logits,
anchor_box,
anchor_var,
gt_box,
gt_boxes,
is_crowd,
im_info,
rpn_batch_size_per_im=256,
fg_fraction=0.25,
rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3):
rpn_negative_overlap=0.3,
use_random=True):
"""
** Target Assign Layer for region proposal network (RPN) in Faster-RCNN detection. **
@ -83,14 +87,13 @@ def rpn_target_assign(loc,
the positive anchors.
Args:
loc(Variable): A 3-D Tensor with shape [N, M, 4] represents the
bbox_pred(Variable): A 3-D Tensor with shape [N, M, 4] represents the
predicted locations of M bounding bboxes. N is the batch size,
and each bounding box has four coordinate values and the layout
is [xmin, ymin, xmax, ymax].
scores(Variable): A 3-D Tensor with shape [N, M, C] represents the
predicted confidence predictions. N is the batch size, C is the
class number, M is number of bounding boxes. For each category
there are total M scores which corresponding M bounding boxes.
cls_logits(Variable): A 3-D Tensor with shape [N, M, 1] represents the
predicted confidence predictions. N is the batch size, 1 is the
frontground and background sigmoid, M is number of bounding boxes.
anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
each box is represented as [xmin, ymin, xmax, ymax],
[xmin, ymin] is the left top coordinate of the anchor box,
@ -99,11 +102,16 @@ def rpn_target_assign(loc,
coordinate of the anchor box.
anchor_var(Variable): A 2-D Tensor with shape [M,4] holds expanded
variances of anchors.
gt_box (Variable): The ground-truth boudding boxes (bboxes) are a 2D
gt_boxes (Variable): The ground-truth boudding boxes (bboxes) are a 2D
LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth
bboxes of mini-batch input.
is_crowd (Variable): A 1-D LoDTensor which indicates groud-truth is crowd.
im_info (Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch size,
3 is the height, width and scale.
rpn_batch_size_per_im(int): Total number of RPN examples per image.
fg_fraction(float): Target fraction of RoI minibatch that is labeled
rpn_straddle_thresh(float): Remove RPN anchors that go outside the image
by straddle_thresh pixels.
rpn_fg_fraction(float): Target fraction of RoI minibatch that is labeled
foreground (i.e. class > 0), 0-th class is background.
rpn_positive_overlap(float): Minimum overlap required between an anchor
and ground-truth box for the (anchor, gt box) pair to be a positive
@ -129,45 +137,48 @@ def rpn_target_assign(loc,
Examples:
.. code-block:: python
loc = layers.data(name='location', shape=[2, 80],
bbox_pred = layers.data(name='bbox_pred', shape=[100, 4],
append_batch_size=False, dtype='float32')
scores = layers.data(name='scores', shape=[2, 40],
cls_logits = layers.data(name='cls_logits', shape=[100, 1],
append_batch_size=False, dtype='float32')
anchor_box = layers.data(name='anchor_box', shape=[20, 4],
append_batch_size=False, dtype='float32')
gt_box = layers.data(name='gt_box', shape=[10, 4],
gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
append_batch_size=False, dtype='float32')
loc_pred, score_pred, loc_target, score_target =
fluid.layers.detection_output(loc=location,
scores=scores,
fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
cls_logits=cls_logits,
anchor_box=anchor_box,
gt_box=gt_box)
gt_boxes=gt_boxes)
"""
helper = LayerHelper('rpn_target_assign', **locals())
# Compute overlaps between the prior boxes and the gt boxes overlaps
iou = iou_similarity(x=gt_box, y=anchor_box)
# Assign target label to anchors
loc_index = helper.create_tmp_variable(dtype='int32')
score_index = helper.create_tmp_variable(dtype='int32')
target_label = helper.create_tmp_variable(dtype='int64')
target_label = helper.create_tmp_variable(dtype='int32')
target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype)
helper.append_op(
type="rpn_target_assign",
inputs={'Anchor': anchor_box,
'GtBox': gt_box,
'DistMat': iou},
inputs={
'Anchor': anchor_box,
'GtBoxes': gt_boxes,
'IsCrowd': is_crowd,
'ImInfo': im_info
},
outputs={
'LocationIndex': loc_index,
'ScoreIndex': score_index,
'TargetLabel': target_label,
'TargetBBox': target_bbox,
'TargetBBox': target_bbox
},
attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im,
'rpn_straddle_thresh': rpn_straddle_thresh,
'rpn_positive_overlap': rpn_positive_overlap,
'rpn_negative_overlap': rpn_negative_overlap,
'fg_fraction': fg_fraction
'rpn_fg_fraction': rpn_fg_fraction,
'use_random': use_random
})
loc_index.stop_gradient = True
@ -175,12 +186,12 @@ def rpn_target_assign(loc,
target_label.stop_gradient = True
target_bbox.stop_gradient = True
scores = nn.reshape(x=scores, shape=(-1, 1))
loc = nn.reshape(x=loc, shape=(-1, 4))
predicted_scores = nn.gather(scores, score_index)
predicted_location = nn.gather(loc, loc_index)
cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1))
bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
predicted_cls_logits = nn.gather(cls_logits, score_index)
predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
return predicted_scores, predicted_location, target_label, target_bbox
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox
def detection_output(loc,
@ -1258,15 +1269,17 @@ def anchor_generator(input,
def generate_proposal_labels(rpn_rois,
gt_classes,
is_crowd,
gt_boxes,
im_scales,
im_info,
batch_size_per_im=256,
fg_fraction=0.25,
fg_thresh=0.25,
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=None):
class_nums=None,
use_random=True):
"""
** Generate proposal labels Faster-RCNN **
TODO(buxingyuan): Add Document
@ -1285,8 +1298,9 @@ def generate_proposal_labels(rpn_rois,
inputs={
'RpnRois': rpn_rois,
'GtClasses': gt_classes,
'IsCrowd': is_crowd,
'GtBoxes': gt_boxes,
'ImScales': im_scales
'ImInfo': im_info
},
outputs={
'Rois': rois,
@ -1302,7 +1316,8 @@ def generate_proposal_labels(rpn_rois,
'bg_thresh_hi': bg_thresh_hi,
'bg_thresh_lo': bg_thresh_lo,
'bbox_reg_weights': bbox_reg_weights,
'class_nums': class_nums
'class_nums': class_nums,
'use_random': use_random
})
rois.stop_gradient = True

@ -148,51 +148,60 @@ class TestAnchorGenerator(unittest.TestCase):
class TestGenerateProposalLabels(unittest.TestCase):
def test_generate_proposal_labels(self):
rpn_rois = layers.data(
name='rpn_rois',
shape=[4, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
gt_classes = layers.data(
name='gt_classes',
shape=[6],
dtype='int32',
lod_level=1,
append_batch_size=False)
gt_boxes = layers.data(
name='gt_boxes',
shape=[6, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
im_scales = layers.data(
name='im_scales',
shape=[1],
dtype='float32',
lod_level=1,
append_batch_size=False)
class_nums = 5
rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
gt_classes=gt_classes,
gt_boxes=gt_boxes,
im_scales=im_scales,
batch_size_per_im=2,
fg_fraction=0.5,
fg_thresh=0.5,
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums)
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
assert rois.shape[0] == bbox_inside_weights.shape[0]
assert rois.shape[0] == bbox_outside_weights.shape[0]
assert bbox_targets.shape[1] == 4 * class_nums
assert bbox_inside_weights.shape[1] == 4 * class_nums
assert bbox_outside_weights.shape[1] == 4 * class_nums
program = Program()
with program_guard(program):
rpn_rois = layers.data(
name='rpn_rois',
shape=[4, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
gt_classes = layers.data(
name='gt_classes',
shape=[6],
dtype='int32',
lod_level=1,
append_batch_size=False)
is_crowd = layers.data(
name='is_crowd',
shape=[6],
dtype='int32',
lod_level=1,
append_batch_size=False)
gt_boxes = layers.data(
name='gt_boxes',
shape=[6, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
im_info = layers.data(
name='im_info',
shape=[1, 3],
dtype='float32',
lod_level=1,
append_batch_size=False)
class_nums = 5
rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
gt_classes=gt_classes,
is_crowd=is_crowd,
gt_boxes=gt_boxes,
im_info=im_info,
batch_size_per_im=2,
fg_fraction=0.5,
fg_thresh=0.5,
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums)
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
assert rois.shape[0] == bbox_inside_weights.shape[0]
assert rois.shape[0] == bbox_outside_weights.shape[0]
assert bbox_targets.shape[1] == 4 * class_nums
assert bbox_inside_weights.shape[1] == 4 * class_nums
assert bbox_outside_weights.shape[1] == 4 * class_nums
class TestMultiBoxHead(unittest.TestCase):
@ -254,18 +263,18 @@ class TestRpnTargetAssign(unittest.TestCase):
def test_rpn_target_assign(self):
program = Program()
with program_guard(program):
loc_shape = [10, 50, 4]
score_shape = [10, 50, 2]
bbox_pred_shape = [10, 50, 4]
cls_logits_shape = [10, 50, 2]
anchor_shape = [50, 4]
loc = layers.data(
name='loc',
shape=loc_shape,
bbox_pred = layers.data(
name='bbox_pred',
shape=bbox_pred_shape,
append_batch_size=False,
dtype='float32')
scores = layers.data(
name='scores',
shape=score_shape,
cls_logits = layers.data(
name='cls_logits',
shape=cls_logits_shape,
append_batch_size=False,
dtype='float32')
anchor_box = layers.data(
@ -278,17 +287,31 @@ class TestRpnTargetAssign(unittest.TestCase):
shape=anchor_shape,
append_batch_size=False,
dtype='float32')
gt_box = layers.data(
name='gt_box', shape=[4], lod_level=1, dtype='float32')
gt_boxes = layers.data(
name='gt_boxes', shape=[4], lod_level=1, dtype='float32')
is_crowd = layers.data(
name='is_crowd',
shape=[10],
dtype='int32',
lod_level=1,
append_batch_size=False)
im_info = layers.data(
name='im_info',
shape=[1, 3],
dtype='float32',
lod_level=1,
append_batch_size=False)
pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
loc=loc,
scores=scores,
bbox_pred=bbox_pred,
cls_logits=cls_logits,
anchor_box=anchor_box,
anchor_var=anchor_var,
gt_box=gt_box,
gt_boxes=gt_boxes,
is_crowd=is_crowd,
im_info=im_info,
rpn_batch_size_per_im=256,
fg_fraction=0.25,
rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3)

@ -20,10 +20,10 @@ import paddle.fluid as fluid
from op_test import OpTest
def generate_proposal_labels_in_python(
rpn_rois, gt_classes, gt_boxes, im_scales, batch_size_per_im,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
class_nums):
def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
im_info, batch_size_per_im, fg_fraction,
fg_thresh, bg_thresh_hi, bg_thresh_lo,
bbox_reg_weights, class_nums):
rois = []
labels_int32 = []
bbox_targets = []
@ -31,13 +31,13 @@ def generate_proposal_labels_in_python(
bbox_outside_weights = []
lod = []
assert len(rpn_rois) == len(
im_scales), 'batch size of rpn_rois and ground_truth is not matched'
im_info), 'batch size of rpn_rois and ground_truth is not matched'
for im_i in range(len(im_scales)):
for im_i in range(len(im_info)):
frcn_blobs = _sample_rois(
rpn_rois[im_i], gt_classes[im_i], gt_boxes[im_i], im_scales[im_i],
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bg_thresh_lo, bbox_reg_weights, class_nums)
rpn_rois[im_i], gt_classes[im_i], is_crowd[im_i], gt_boxes[im_i],
im_info[im_i], batch_size_per_im, fg_fraction, fg_thresh,
bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums)
lod.append(frcn_blobs['rois'].shape[0])
@ -50,13 +50,14 @@ def generate_proposal_labels_in_python(
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, lod
def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo,
bbox_reg_weights, class_nums):
def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bg_thresh_lo, bbox_reg_weights, class_nums):
rois_per_image = int(batch_size_per_im)
fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
# Roidb
im_scale = im_info[2]
inv_im_scale = 1. / im_scale
rpn_rois = rpn_rois * inv_im_scale
@ -78,6 +79,9 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[
overlapped_boxes_ind]
crowd_ind = np.where(is_crowd)[0]
gt_overlaps[crowd_ind] = -1
max_overlaps = gt_overlaps.max(axis=1)
max_classes = gt_overlaps.argmax(axis=1)
@ -85,9 +89,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
# Sample foreground if there are too many
if fg_inds.shape[0] > fg_rois_per_this_image:
fg_inds = np.random.choice(
fg_inds, size=fg_rois_per_this_image, replace=False)
# if fg_inds.shape[0] > fg_rois_per_this_image:
# fg_inds = np.random.choice(
# fg_inds, size=fg_rois_per_this_image, replace=False)
fg_inds = fg_inds[:fg_rois_per_this_image]
# Background
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
@ -96,9 +101,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.shape[0])
# Sample background if there are too many
if bg_inds.shape[0] > bg_rois_per_this_image:
bg_inds = np.random.choice(
bg_inds, size=bg_rois_per_this_image, replace=False)
# if bg_inds.shape[0] > bg_rois_per_this_image:
# bg_inds = np.random.choice(
# bg_inds, size=bg_rois_per_this_image, replace=False)
bg_inds = bg_inds[:bg_rois_per_this_image]
keep_inds = np.append(fg_inds, bg_inds)
sampled_labels = max_classes[keep_inds]
@ -208,8 +214,9 @@ class TestGenerateProposalLabelsOp(OpTest):
self.inputs = {
'RpnRois': (self.rpn_rois[0], self.rpn_rois_lod),
'GtClasses': (self.gt_classes[0], self.gts_lod),
'IsCrowd': (self.is_crowd[0], self.gts_lod),
'GtBoxes': (self.gt_boxes[0], self.gts_lod),
'ImScales': self.im_scales[0]
'ImInfo': self.im_info
}
self.attrs = {
'batch_size_per_im': self.batch_size_per_im,
@ -218,14 +225,15 @@ class TestGenerateProposalLabelsOp(OpTest):
'bg_thresh_hi': self.bg_thresh_hi,
'bg_thresh_lo': self.bg_thresh_lo,
'bbox_reg_weights': self.bbox_reg_weights,
'class_nums': self.class_nums
'class_nums': self.class_nums,
'use_random': False
}
self.outputs = {
'Rois': (self.rois[0], [self.lod]),
'LabelsInt32': (self.labels_int32[0], [self.lod]),
'BboxTargets': (self.bbox_targets[0], [self.lod]),
'BboxInsideWeights': (self.bbox_inside_weights[0], [self.lod]),
'BboxOutsideWeights': (self.bbox_outside_weights[0], [self.lod]),
'Rois': (self.rois, [self.lod]),
'LabelsInt32': (self.labels_int32, [self.lod]),
'BboxTargets': (self.bbox_targets, [self.lod]),
'BboxInsideWeights': (self.bbox_inside_weights, [self.lod]),
'BboxOutsideWeights': (self.bbox_outside_weights, [self.lod]),
}
def test_check_output(self):
@ -236,8 +244,8 @@ class TestGenerateProposalLabelsOp(OpTest):
self.set_data()
def init_test_params(self):
self.batch_size_per_im = 10
self.fg_fraction = 1.0
self.batch_size_per_im = 512
self.fg_fraction = 0.25
self.fg_thresh = 0.5
self.bg_thresh_hi = 0.5
self.bg_thresh_lo = 0.0
@ -246,14 +254,14 @@ class TestGenerateProposalLabelsOp(OpTest):
def init_test_input(self):
np.random.seed(0)
image_nums = 1
gt_nums = 6 # Keep same with batch_size_per_im for unittest
proposal_nums = self.batch_size_per_im - gt_nums
images_shape = []
self.im_scales = []
for i in range(image_nums):
images_shape.append(np.random.randint(200, size=2))
self.im_scales.append(np.ones((1)).astype(np.float32))
proposal_nums = 2000 #self.batch_size_per_im - gt_nums
images_shape = [[64, 64]]
self.im_info = np.ones((len(images_shape), 3)).astype(np.float32)
for i in range(len(images_shape)):
self.im_info[i, 0] = images_shape[i][0]
self.im_info[i, 1] = images_shape[i][1]
self.im_info[i, 2] = 0.8 #scale
self.rpn_rois, self.rpn_rois_lod = _generate_proposals(images_shape,
proposal_nums)
@ -261,16 +269,23 @@ class TestGenerateProposalLabelsOp(OpTest):
images_shape, self.class_nums, gt_nums)
self.gt_classes = [gt['gt_classes'] for gt in ground_truth]
self.gt_boxes = [gt['boxes'] for gt in ground_truth]
self.is_crowd = [gt['is_crowd'] for gt in ground_truth]
def init_test_output(self):
self.rois, self.labels_int32, self.bbox_targets, \
self.bbox_inside_weights, self.bbox_outside_weights, \
self.lod = generate_proposal_labels_in_python(
self.rpn_rois, self.gt_classes, self.gt_boxes, self.im_scales,
self.rpn_rois, self.gt_classes, self.is_crowd, self.gt_boxes, self.im_info,
self.batch_size_per_im, self.fg_fraction,
self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo,
self.bbox_reg_weights, self.class_nums
)
self.rois = np.vstack(self.rois)
self.labels_int32 = np.hstack(self.labels_int32)
self.labels_int32 = self.labels_int32[:, np.newaxis]
self.bbox_targets = np.vstack(self.bbox_targets)
self.bbox_inside_weights = np.vstack(self.bbox_inside_weights)
self.bbox_outside_weights = np.vstack(self.bbox_outside_weights)
def _generate_proposals(images_shape, proposal_nums):
@ -280,7 +295,7 @@ def _generate_proposals(images_shape, proposal_nums):
for i, image_shape in enumerate(images_shape):
proposals = _generate_boxes(image_shape, proposal_nums)
rpn_rois.append(proposals)
num_proposals += len(proposals)
num_proposals = len(proposals)
rpn_rois_lod.append(num_proposals)
return rpn_rois, [rpn_rois_lod]
@ -294,7 +309,11 @@ def _generate_groundtruth(images_shape, class_nums, gt_nums):
gt_classes = np.random.randint(
low=1, high=class_nums, size=gt_nums).astype(np.int32)
gt_boxes = _generate_boxes(image_shape, gt_nums)
ground_truth.append(dict(gt_classes=gt_classes, boxes=gt_boxes))
is_crowd = np.zeros((gt_nums), dtype=np.int32)
is_crowd[0] = 1
ground_truth.append(
dict(
gt_classes=gt_classes, boxes=gt_boxes, is_crowd=is_crowd))
num_gts += len(gt_classes)
gts_lod.append(num_gts)
return ground_truth, [gts_lod]

@ -114,10 +114,10 @@ def box_coder(all_anchors, bbox_deltas, variances):
#anchor_loc: width, height, center_x, center_y
anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32)
anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0]
anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1]
anchor_loc[:, 2] = (all_anchors[:, 2] + all_anchors[:, 0]) / 2
anchor_loc[:, 3] = (all_anchors[:, 3] + all_anchors[:, 1]) / 2
anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] + 1
anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] + 1
anchor_loc[:, 2] = all_anchors[:, 0] + 0.5 * anchor_loc[:, 0]
anchor_loc[:, 3] = all_anchors[:, 1] + 0.5 * anchor_loc[:, 1]
#predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height
pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32)
@ -127,23 +127,29 @@ def box_coder(all_anchors, bbox_deltas, variances):
i, 0] + anchor_loc[i, 2]
pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[
i, 1] + anchor_loc[i, 3]
pred_bbox[i, 2] = math.exp(variances[i, 2] *
bbox_deltas[i, 2]) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(variances[i, 3] *
bbox_deltas[i, 3]) * anchor_loc[i, 1]
pred_bbox[i, 2] = math.exp(
min(variances[i, 2] * bbox_deltas[i, 2], math.log(
1000 / 16.0))) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(
min(variances[i, 3] * bbox_deltas[i, 3], math.log(
1000 / 16.0))) * anchor_loc[i, 1]
else:
for i in range(bbox_deltas.shape[0]):
pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[
i, 2]
pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[
i, 3]
pred_bbox[i, 2] = math.exp(bbox_deltas[i, 2]) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(bbox_deltas[i, 3]) * anchor_loc[i, 1]
pred_bbox[i, 2] = math.exp(
min(bbox_deltas[i, 2], math.log(1000 / 16.0))) * anchor_loc[i,
0]
pred_bbox[i, 3] = math.exp(
min(bbox_deltas[i, 3], math.log(1000 / 16.0))) * anchor_loc[i,
1]
proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2
proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2
proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2
proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2
proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 - 1
proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 - 1
return proposals
@ -170,13 +176,16 @@ def filter_boxes(boxes, min_size, im_info):
"""Only keep boxes with both sides >= min_size and center within the image.
"""
# Scale min_size to match image scale
min_size *= im_info[2]
im_scale = im_info[2]
min_size = max(min_size, 1.0)
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
ws_orig_scale = (boxes[:, 2] - boxes[:, 0]) / im_scale + 1
hs_orig_scale = (boxes[:, 3] - boxes[:, 1]) / im_scale + 1
x_ctr = boxes[:, 0] + ws / 2.
y_ctr = boxes[:, 1] + hs / 2.
keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_info[1]) &
(y_ctr < im_info[0]))[0]
keep = np.where((ws_orig_scale >= min_size) & (hs_orig_scale >= min_size) &
(x_ctr < im_info[1]) & (y_ctr < im_info[0]))[0]
return keep
@ -204,7 +213,7 @@ def iou(box_a, box_b):
xb = min(xmax_a, xmax_b)
yb = min(ymax_a, ymax_b)
inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0)
inter_area = max(xb - xa + 1, 0.0) * max(yb - ya + 1, 0.0)
iou_ratio = inter_area / (area_a + area_b - inter_area)

@ -65,8 +65,43 @@ class InferenceTranspiler(object):
if use_mkldnn:
self._fuse_conv_bias_mkldnn(program)
self._fuse_conv_relu_mkldnn(program)
self._fuse_conv_eltwise_mkldnn(program)
self._fuse_conv_relu_mkldnn(
program) # ResNet residual block merging
self._fuse_bn_relu_mkldnn(program)
def _fuse_conv_eltwise_mkldnn(self, program):
'''
Transpile the program fusing elementwise_add into conv for MKLDNN
program. Elementwise add following convolution OP can be fused by adding
'fuse_eltwise' attribute to convolution OP and replacing its output
Tensor with second parameter of elementwise_add.
The result of fuse is:
- before:
- conv->elementwise_add->any_other_op
- after:
- conv->any_other_op
:param program: program to transpile
:type program: Program
'''
self.block = program.block(0)
i = 0
while i < len(self.block.ops):
current_op = self.block.ops[i]
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'elementwise_add':
self._fuse_conv_eltwise(current_op, next_op)
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
self._adjust_input()
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def _fuse_conv_relu_mkldnn(self, program):
'''
Transpile the program by fused relu activation for MKLDNN program.
@ -88,9 +123,9 @@ class InferenceTranspiler(object):
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'relu':
# modify conv OP to include relu
# modify bnorm OP to include relu
current_op.set_attr("fuse_relu", True)
# remove conv OP
# remove relu OP
self.block._remove_op(i + 1)
i = i + 1
@ -409,6 +444,20 @@ class InferenceTranspiler(object):
outputs={"Output": out_var},
attrs=attrs)
def _fuse_conv_eltwise(self, conv_op, eltwise_op):
'''
fuse the conv op with elementwise_add
:param conv_op: convolution operator
:type conv_op: Operator
:param eltwise_op: operator adding data from skip connection
:type eltwise_op: Operator
'''
conv_op.set_attr("fuse_eltwise", True)
self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
def _adjust_input(self):
for i in range(len(self.block.ops)):
current_op = self.block.ops[i]

Loading…
Cancel
Save