Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_CudnnHolder_bug
commit
db5e3dd767
@ -1,20 +1,35 @@
|
||||
# Location of the generated header inside the build tree. Every
# pass_library() call in this file appends one USE_PASS(...) line to it.
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
# Recreate the header from scratch at configure time: first the banner, then
# the include line. file(WRITE) concatenates its content arguments in order.
file(WRITE ${pass_file}
  "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n"
  "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
|
||||
# pass_library(<target> [DEPS <dep>...] [SRCS <src>...])
#
# Declares one IR pass and registers it for the inference API:
#   * builds <target> from <target>.cc through cc_library (a project helper
#     defined outside this file), always depending on graph_pattern_detector
#     and pass, plus any extra DEPS supplied by the caller;
#   * appends "USE_PASS(<target>);" to the generated header ${pass_file};
#   * prepends <target> to PASS_LIBRARY in the caller's scope.
#
# SRCS is accepted for call-site compatibility but is not used: the source
# file is always <target>.cc.
#
# Example:
#   pass_library(fc_fuse_pass)
#   pass_library(my_pass DEPS some_extra_lib)
function(pass_library TARGET)
  set(options "")
  set(oneValueArgs "")
  set(multiValueArgs SRCS DEPS)
  # The prefix matches this function's name, so results land in
  # pass_library_SRCS / pass_library_DEPS.
  cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
  # pass_library_DEPS is empty when the caller passes no DEPS, which leaves
  # the dependency list exactly as it was before DEPS was forwarded.
  cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass ${pass_library_DEPS})
  file(APPEND ${pass_file} "USE_PASS(${TARGET});\n")
  # A function has its own variable scope, so the accumulated list has to be
  # pushed back to the caller explicitly.
  set(PASS_LIBRARY ${TARGET} ${PASS_LIBRARY} PARENT_SCOPE)
endfunction()
|
||||
|
||||
# Core IR libraries.
#
# The individual passes (graph_viz_pass, graph_to_program_pass, fc_fuse_pass,
# attention_lstm_fuse_pass, infer_clean_graph_pass, fc_lstm_fuse_pass,
# seq_concat_fc_fuse_pass) are deliberately NOT declared here. Each of them is
# created by its pass_library(<name>) call further down in this file, and
# pass_library itself calls cc_library(<name> ...). Declaring the same name
# here as well would define every pass target twice.
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
|
||||
|
||||
# Register the IR passes one by one. The order is kept as listed because
# pass_library prepends each name to PASS_LIBRARY and appends one line per
# pass to the generated header.
foreach(ir_pass_name
    graph_to_program_pass
    graph_viz_pass
    fc_fuse_pass
    attention_lstm_fuse_pass
    infer_clean_graph_pass
    fc_lstm_fuse_pass
    seq_concat_fc_fuse_pass)
  pass_library(${ir_pass_name})
endforeach()
# Publish the collected pass list through an internal cache entry so that it
# is visible outside this directory's variable scope.
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
|
||||
|
||||
# Unit tests for the IR libraries declared in this file.
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
# test_fc_fuse_pass is declared exactly once; a second cc_test with the same
# name would define the test target twice. fc_fuse_pass is built by
# pass_library with DEPS graph_pattern_detector and pass, so those do not need
# to be repeated here.
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
|
||||
|
@ -1,6 +1,7 @@
|
||||
/*
 * Anonymous version node: symbols whose names match *paddle* or *Pass* are
 * marked global; every other symbol falls under the catch-all local pattern.
 */
{
  global:
    *paddle*;
    *Pass*;
  local:
    *;
};
|
||||
|
@ -0,0 +1,66 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
#pragma once
|
||||
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
/*
|
||||
* transform that computes target bounding-box regression deltas
|
||||
* given proposal boxes and ground-truth boxes.
|
||||
*/
|
||||
template <typename T>
|
||||
inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes,
|
||||
const framework::Tensor& gt_boxes, const T* weights,
|
||||
const bool normalized, framework::Tensor* box_delta) {
|
||||
auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
|
||||
auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
|
||||
auto trg = framework::EigenTensor<T, 2>::From(*box_delta);
|
||||
T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y;
|
||||
for (int64_t i = 0; i < box_num; ++i) {
|
||||
ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + (normalized == false);
|
||||
ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + (normalized == false);
|
||||
ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w;
|
||||
ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h;
|
||||
|
||||
gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + (normalized == false);
|
||||
gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + (normalized == false);
|
||||
gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w;
|
||||
gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h;
|
||||
|
||||
trg(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w;
|
||||
trg(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h;
|
||||
trg(i, 2) = std::log(gt_w / ex_w);
|
||||
trg(i, 3) = std::log(gt_h / ex_h);
|
||||
|
||||
if (weights) {
|
||||
trg(i, 0) = trg(i, 0) / weights[0];
|
||||
trg(i, 1) = trg(i, 1) / weights[1];
|
||||
trg(i, 2) = trg(i, 2) / weights[2];
|
||||
trg(i, 3) = trg(i, 3) / weights[3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * Copies `num` rows of `in_stride` elements each from `in` to `out`.
 * Output row i is taken from input row index[i], so rows may be repeated or
 * reordered.
 *
 * in:        source buffer; must hold every row referenced by `index`.
 * in_stride: number of T elements per row (used for both input and output).
 * index:     `num` row indices into `in`.
 * num:       number of rows to copy.
 * out:       destination buffer with room for num * in_stride elements.
 *
 * Rows are copied bytewise with memcpy. Nothing is copied when `num` or
 * `in_stride` is not positive.
 */
template <typename T>
void Gather(const T* in, const int in_stride, const int* index, const int num,
            T* out) {
  if (num <= 0 || in_stride <= 0) {
    return;
  }
  // Byte count and element offsets are computed in pointer-sized integers so
  // that large buffers cannot overflow `int` arithmetic.
  const std::size_t row_bytes = static_cast<std::size_t>(in_stride) * sizeof(T);
  for (int i = 0; i < num; ++i) {
    const std::ptrdiff_t src_offset =
        static_cast<std::ptrdiff_t>(index[i]) * in_stride;
    const std::ptrdiff_t dst_offset = static_cast<std::ptrdiff_t>(i) * in_stride;
    std::memcpy(out + dst_offset, in + src_offset, row_bytes);
  }
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue