Merge branch 'develop' into grid_sampler

fix_recordio_link
Kaipeng Deng 7 years ago committed by GitHub
commit a3b26e8528
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -68,7 +68,6 @@ option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better d
option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(WITH_INFERENCE "Compile fluid inference library" ON)
option(ON_INFER "Turn on inference optimization." OFF) option(ON_INFER "Turn on inference optimization." OFF)
option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
@ -305,6 +304,9 @@ if(WITH_DOC)
endif() endif()
if (ON_INFER) if (ON_INFER)
message(WARNING "On inference mode, will take place some specific optimization.") message(STATUS "On inference mode, will take place some specific optimization.")
add_definitions(-DPADDLE_ON_INFERENCE) add_definitions(-DPADDLE_ON_INFERENCE)
else()
#TODO(luotao), combine this warning with `make inference_lib_dist` command.
message(WARNING "On inference mode, will take place some specific optimization. Turn on the ON_INFER flag when building inference_lib only.")
endif() endif()

@ -7,7 +7,11 @@ set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include")
IF(WITH_STATIC_LIB) IF(WITH_STATIC_LIB)
SET(BUILD_CMD make lib) SET(BUILD_CMD make lib)
ELSE() ELSE()
SET(BUILD_CMD sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib) IF(APPLE)
SET(BUILD_CMD sed -i \"\" "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib)
ELSE(APPLE)
SET(BUILD_CMD sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib)
ENDIF(APPLE)
ENDIF() ENDIF()
ExternalProject_Add( ExternalProject_Add(

@ -14,9 +14,6 @@
# make package for paddle fluid shared and static library # make package for paddle fluid shared and static library
function(copy TARGET) function(copy TARGET)
if (NOT ON_INFER)
message(WARNING "Turn on the ON_INFER flag when building inference_lib only.")
endif()
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DSTS DEPS) set(multiValueArgs SRCS DSTS DEPS)

@ -24,6 +24,7 @@ if(NOT WITH_FLUID_ONLY)
endif() endif()
add_subdirectory(testing) add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
if(NOT MOBILE_INFERENCE AND NOT RPI AND NOT WITH_C_API) if(NOT MOBILE_INFERENCE AND NOT RPI AND NOT WITH_C_API)
add_subdirectory(fluid) add_subdirectory(fluid)
endif() endif()

@ -178,6 +178,8 @@ paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, k
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)

@ -9,8 +9,6 @@ add_subdirectory(pybind)
add_subdirectory(recordio) add_subdirectory(recordio)
endif(NOT WIN32) endif(NOT WIN32)
if(WITH_INFERENCE) # NOTE: please add subdirectory inference at last.
# NOTE: please add subdirectory inference at last. add_subdirectory(inference)
add_subdirectory(inference) add_subdirectory(train)
add_subdirectory(train)
endif()

@ -64,6 +64,13 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
case proto::AttrType::LONG: { case proto::AttrType::LONG: {
return attr_desc.l(); return attr_desc.l();
} }
case proto::AttrType::LONGS: {
std::vector<int64_t> val(attr_desc.longs_size());
for (int i = 0; i < attr_desc.longs_size(); ++i) {
val[i] = attr_desc.longs(i);
}
return val;
}
default: default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type()); PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
} }

@ -26,6 +26,113 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<std::vector<int64_t>> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
std::vector<int64_t>* operator()(Attribute& attr) const {
if (attr.type() == typeid(std::vector<int>)) { // NOLINT
std::vector<int> val = boost::get<std::vector<int>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
} else if (attr.type() == typeid(std::vector<float>)) { // NOLINT
std::vector<float> val = boost::get<std::vector<float>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
}
std::vector<int64_t>* attr_value = nullptr;
try {
attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <typename T> template <typename T>
inline proto::AttrType AttrTypeID() { inline proto::AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
@ -42,7 +149,11 @@ class AttrReader {
inline const T& Get(const std::string& name) const { inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name); name);
return boost::get<T>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
} }
private: private:
@ -82,7 +193,7 @@ class DefaultValueSetter {
public: public:
explicit DefaultValueSetter(T default_value) explicit DefaultValueSetter(T default_value)
: default_value_(default_value) {} : default_value_(default_value) {}
void operator()(T& value) const { value = default_value_; } void operator()(T& value) const { value = default_value_; } // NOLINT
private: private:
T default_value_; T default_value_;
@ -117,84 +228,6 @@ class EnumInContainer {
std::unordered_set<T> container_; std::unordered_set<T> container_;
}; };
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// check whether a certain attribute fit its limits // check whether a certain attribute fit its limits
// an attribute can have more than one limits // an attribute can have more than one limits
template <typename T> template <typename T>
@ -235,7 +268,7 @@ class TypedAttrChecker {
return *this; return *this;
} }
void operator()(AttributeMap& attr_map) const { void operator()(AttributeMap& attr_map) const { // NOLINT
if (!attr_map.count(attr_name_)) { if (!attr_map.count(attr_name_)) {
// user do not set this attr // user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(), PADDLE_ENFORCE(!default_value_setter_.empty(),
@ -271,7 +304,7 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>()); return *(checker.target<TypedAttrChecker<T>>());
} }
void Check(AttributeMap& attr_map) const { void Check(AttributeMap& attr_map) const { // NOLINT
for (const auto& checker : attr_checkers_) { for (const auto& checker : attr_checkers_) {
checker(attr_map); checker(attr_map);
} }

@ -59,6 +59,10 @@ void BroadcastOpHandle::BroadcastOneVar(
var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
if (UNLIKELY(!in_tensor.IsInitialized())) {
VLOG(3) << "in var " << in_var_handle.name_ << "not inited, return!";
return;
}
InitOutputValue(in_var_handle, out_var_handles); InitOutputValue(in_var_handle, out_var_handles);

@ -722,7 +722,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
} }
if (node->Op()->Type() == "split_byref" || if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") { node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {

@ -35,6 +35,7 @@ enum AttrType {
BLOCK = 8; BLOCK = 8;
LONG = 9; LONG = 9;
BLOCKS = 10; BLOCKS = 10;
LONGS = 11;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
@ -55,6 +56,7 @@ message OpDesc {
optional int32 block_idx = 12; optional int32 block_idx = 12;
optional int64 l = 13; optional int64 l = 13;
repeated int32 blocks_idx = 14; repeated int32 blocks_idx = 14;
repeated int64 longs = 15;
}; };
message Var { message Var {

@ -419,8 +419,15 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
} }
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx()); VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
} }
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(int64_t v) const { attr_->set_l(v); } void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(const std::vector<int64_t> &v) const {
VectorToRepeated(v, attr_->mutable_longs());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };

@ -33,7 +33,7 @@ enum class OpRole {
// used for distributed training. // used for distributed training.
kDist = 0x0008, kDist = 0x0008,
// Tag all learning rate scheduler operators. // Tag all learning rate scheduler operators.
kLRSched = 0x0016, kLRSched = 0x0010,
kLoss = 0x0100, kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and // The default value of op's role. This should be only used for unittests and

@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>(); return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
} }
static const Tensor* GetTensorFromVar(Variable* var) { const Tensor* GetTensorFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>(); return var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {

@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(Variable* var);
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;

@ -187,6 +187,10 @@ void ParallelExecutor::BCastParamsToDevices(
} }
auto &main_tensor = main_var->Get<LoDTensor>(); auto &main_tensor = main_var->Get<LoDTensor>();
if (!main_tensor.IsInitialized()) {
VLOG(3) << "one in var not inited, return!";
continue;
}
auto &dims = main_tensor.dims(); auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) { if (paddle::platform::is_gpu_place(main_tensor.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
@ -299,10 +303,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
} }
ParallelExecutor::~ParallelExecutor() { ParallelExecutor::~ParallelExecutor() {
const auto dev_ctxs = for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().GetAllDeviceContexts(); platform::DeviceContextPool::Instance().Get(p)->Wait();
for (auto &dev_ctx : dev_ctxs) {
dev_ctx->Wait();
} }
if (member_->own_local_scope_) { if (member_->own_local_scope_) {

@ -36,7 +36,7 @@ using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*, int64_t, std::vector<bool>, BlockDesc*, int64_t,
std::vector<BlockDesc*>>; std::vector<BlockDesc*>, std::vector<int64_t>>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;

@ -61,8 +61,6 @@ cc_test(test_paddle_inference_api
inference_api_test(test_api_impl SRC api_impl_tester.cc inference_api_test(test_api_impl SRC api_impl_tester.cc
ARGS test_word2vec test_image_classification) ARGS test_word2vec test_image_classification)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps} paddle_inference_api cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps} paddle_inference_api
ARGS --dirname=${PYTHON_TESTS_DIR}/book) ARGS --dirname=${PYTHON_TESTS_DIR}/book)

@ -22,9 +22,9 @@ limitations under the License. */
#include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/inference/tests/test_helper.h"
#ifdef __clang__ #ifdef __clang__
#define ACC_DIFF 4e-3 #define ACC_DIFF 4e-2
#else #else
#define ACC_DIFF 1e-3 #define ACC_DIFF 1e-2
#endif #endif
DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(dirname, "", "Directory of the inference model.");
@ -187,7 +187,7 @@ void MainThreadsWord2Vec(bool use_gpu) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (int tid = 0; tid < num_jobs; ++tid) { for (int tid = 0; tid < num_jobs; ++tid) {
threads.emplace_back([&, tid]() { threads.emplace_back([&, tid]() {
auto predictor = main_predictor->Clone(); auto predictor = CreatePaddlePredictor(config);
auto& local_inputs = paddle_tensor_feeds[tid]; auto& local_inputs = paddle_tensor_feeds[tid];
std::vector<PaddleTensor> local_outputs; std::vector<PaddleTensor> local_outputs;
ASSERT_TRUE(predictor->Run(local_inputs, &local_outputs)); ASSERT_TRUE(predictor->Run(local_inputs, &local_outputs));
@ -245,7 +245,7 @@ void MainThreadsImageClassification(bool use_gpu) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (int tid = 0; tid < num_jobs; ++tid) { for (int tid = 0; tid < num_jobs; ++tid) {
threads.emplace_back([&, tid]() { threads.emplace_back([&, tid]() {
auto predictor = main_predictor->Clone(); auto predictor = CreatePaddlePredictor(config);
auto& local_inputs = paddle_tensor_feeds[tid]; auto& local_inputs = paddle_tensor_feeds[tid];
std::vector<PaddleTensor> local_outputs; std::vector<PaddleTensor> local_outputs;
ASSERT_TRUE(predictor->Run(local_inputs, &local_outputs)); ASSERT_TRUE(predictor->Run(local_inputs, &local_outputs));
@ -271,7 +271,7 @@ TEST(inference_api_native, word2vec_cpu_threads) {
MainThreadsWord2Vec(false /*use_gpu*/); MainThreadsWord2Vec(false /*use_gpu*/);
} }
TEST(inference_api_native, image_classification_cpu) { TEST(inference_api_native, image_classification_cpu) {
MainThreadsImageClassification(false /*use_gpu*/); MainImageClassification(false /*use_gpu*/);
} }
TEST(inference_api_native, image_classification_cpu_threads) { TEST(inference_api_native, image_classification_cpu_threads) {
MainThreadsImageClassification(false /*use_gpu*/); MainThreadsImageClassification(false /*use_gpu*/);
@ -279,15 +279,17 @@ TEST(inference_api_native, image_classification_cpu_threads) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TEST(inference_api_native, word2vec_gpu) { MainWord2Vec(true /*use_gpu*/); } TEST(inference_api_native, word2vec_gpu) { MainWord2Vec(true /*use_gpu*/); }
TEST(inference_api_native, word2vec_gpu_threads) { // Turn off temporarily for the unstable result.
MainThreadsWord2Vec(true /*use_gpu*/); // TEST(inference_api_native, word2vec_gpu_threads) {
} // MainThreadsWord2Vec(true /*use_gpu*/);
// }
TEST(inference_api_native, image_classification_gpu) { TEST(inference_api_native, image_classification_gpu) {
MainThreadsImageClassification(true /*use_gpu*/); MainImageClassification(true /*use_gpu*/);
}
TEST(inference_api_native, image_classification_gpu_threads) {
MainThreadsImageClassification(true /*use_gpu*/);
} }
// Turn off temporarily for the unstable result.
// TEST(inference_api_native, image_classification_gpu_threads) {
// MainThreadsImageClassification(true /*use_gpu*/);
// }
#endif #endif

@ -60,8 +60,7 @@ for WITH_STATIC_LIB in ON OFF; do
-DWITH_MKL=$TURN_ON_MKL \ -DWITH_MKL=$TURN_ON_MKL \
-DDEMO_NAME=simple_on_word2vec \ -DDEMO_NAME=simple_on_word2vec \
-DWITH_GPU=$TEST_GPU_CPU \ -DWITH_GPU=$TEST_GPU_CPU \
-DWITH_STATIC_LIB=$WITH_STATIC_LIB \ -DWITH_STATIC_LIB=$WITH_STATIC_LIB
-DON_INFER=ON
make -j make -j
word2vec_model=${PADDLE_ROOT}'/build/python/paddle/fluid/tests/book/word2vec.inference.model' word2vec_model=${PADDLE_ROOT}'/build/python/paddle/fluid/tests/book/word2vec.inference.model'
if [ -d $word2vec_model ]; then if [ -d $word2vec_model ]; then
@ -81,8 +80,7 @@ for WITH_STATIC_LIB in ON OFF; do
-DWITH_MKL=$TURN_ON_MKL \ -DWITH_MKL=$TURN_ON_MKL \
-DDEMO_NAME=vis_demo \ -DDEMO_NAME=vis_demo \
-DWITH_GPU=$TEST_GPU_CPU \ -DWITH_GPU=$TEST_GPU_CPU \
-DWITH_STATIC_LIB=$WITH_STATIC_LIB \ -DWITH_STATIC_LIB=$WITH_STATIC_LIB
-DON_INFER=ON
make -j make -j
for use_gpu in $use_gpu_list; do for use_gpu in $use_gpu_list; do
for vis_demo_name in $vis_demo_list; do for vis_demo_name in $vis_demo_list; do
@ -108,8 +106,7 @@ for WITH_STATIC_LIB in ON OFF; do
-DWITH_STATIC_LIB=$WITH_STATIC_LIB \ -DWITH_STATIC_LIB=$WITH_STATIC_LIB \
-DUSE_TENSORRT=$USE_TENSORRT \ -DUSE_TENSORRT=$USE_TENSORRT \
-DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \ -DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \
-DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR \ -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR
-DON_INFER=ON
make -j make -j
./trt_mobilenet_demo \ ./trt_mobilenet_demo \
--modeldir=$DATA_DIR/mobilenet/model \ --modeldir=$DATA_DIR/mobilenet/model \

@ -301,6 +301,7 @@ op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding) op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op) op_library(unstack_op DEPS stack_op)
op_library(fake_quantize_op DEPS memory) op_library(fake_quantize_op DEPS memory)
op_library(crf_decoding_op DEPS jit_kernel)
op_library(fusion_lstm_op DEPS jit_kernel) op_library(fusion_lstm_op DEPS jit_kernel)
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(conv_op DEPS vol2col depthwise_conv im2col)

@ -0,0 +1,97 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/add_position_encoding_op.h"
namespace paddle {
namespace operators {
class AddPositionEncodingOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of add_position_encoding_op should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Out(Output) of add_position_encoding_op should not be null.");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Out must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Out@GRAD must not be null.");
auto out_dims = ctx->GetInputDim("Out");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
}
}
};
class AddPositionEncodingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of AddPositionEncoding operator");
AddOutput("Out", "Output of AddPositionEncoding operator");
AddAttr<float>("alpha", "The scale of Original Embedding.")
.SetDefault(1.0f)
.AddCustomChecker([](const float& alpha) {
PADDLE_ENFORCE(alpha >= 0.0f, "'alpha' must be above 0.0.");
});
AddAttr<float>("beta", "The scale of Position Embedding.")
.SetDefault(1.0f)
.AddCustomChecker([](const float& beta) {
PADDLE_ENFORCE(beta >= 0.0f, "'beta' must be between 0.0.");
});
AddComment(R"DOC(
Add Position Encoding Operator.
The add position encoding calculates the output based on the input, alpha, beta.
The size of each dimension of the parameters checked in the infer-shape.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plt = paddle::platform;
REGISTER_OPERATOR(add_position_encoding, ops::AddPositionEncodingOp,
ops::AddPositionEncodingOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad);
REGISTER_OP_CPU_KERNEL(
add_position_encoding,
ops::AddPositionEncodingKernel<plt::CPUDeviceContext, float>,
ops::AddPositionEncodingKernel<plt::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
add_position_encoding_grad,
ops::AddPositionEncodingGradKernel<plt::CPUDeviceContext, float>,
ops::AddPositionEncodingGradKernel<plt::CPUDeviceContext, double>);

@ -0,0 +1,105 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class AddPositionEncodingKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::LoDTensor>("X");
auto& x_lod = X->lod();
auto* src_ptr = X->data<T>();
auto* Out = context.Output<framework::LoDTensor>("Out");
auto* dst_ptr = Out->mutable_data<T>(context.GetPlace());
float alpha = context.Attr<float>("alpha");
float beta = context.Attr<float>("beta");
auto x_dim = X->dims();
int batch_size = 0;
int max_seq_len = 0;
int enc_size = 0;
if (x_lod.empty()) {
PADDLE_ENFORCE(
x_dim.size() == 3UL,
"The input X of Add Position Encoding should be 3-D Tensor!");
batch_size = x_dim[0];
max_seq_len = x_dim[1];
enc_size = x_dim[2];
} else {
PADDLE_ENFORCE(
x_dim.size() == 2UL,
"The input X of Add Position Encoding should be 2-D LoDTensor!");
PADDLE_ENFORCE(
x_lod.size() == 1UL,
"The Add Position Encoding Op only supports lod_level == 1!");
batch_size = x_lod[0].size() - 1;
max_seq_len = -1;
enc_size = x_dim[1];
}
PADDLE_ENFORCE(enc_size % 2 == 0, "Only support even encode size!");
const int half_size = enc_size / 2;
for (int i = 0; i < batch_size; ++i) {
const int max_length =
x_lod.empty() ? max_seq_len : x_lod[0][i + 1] - x_lod[0][i];
for (int j = 0; j < max_length; ++j) {
for (int k = 0; k < half_size; ++k) {
const double val = (half_size > 1)
? j / pow(10000.0, double(k) / (half_size - 1))
: j / 10000.0;
dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta;
dst_ptr[half_size + k] =
src_ptr[half_size + k] * alpha + cos(val) * beta;
}
src_ptr += enc_size;
dst_ptr += enc_size;
}
}
}
};
template <typename DeviceContext, typename T>
class AddPositionEncodingGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dOut =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto dout = framework::EigenVector<T>::Flatten(*dOut);
auto* dX =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto dx = framework::EigenVector<T>::Flatten(*dX);
float alpha = context.Attr<float>("alpha");
auto* place =
context.template device_context<DeviceContext>().eigen_device();
dx.device(*place) = dout * static_cast<T>(alpha);
}
};
} // namespace operators
} // namespace paddle

@ -16,6 +16,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {
@ -69,9 +70,6 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
auto emission_dims = emission_weights.dims(); auto emission_dims = emission_weights.dims();
const size_t seq_len = emission_dims[0]; const size_t seq_len = emission_dims[0];
const size_t tag_num = emission_dims[1]; const size_t tag_num = emission_dims[1];
const size_t state_trans_base_idx = 2;
const T* x = emission_weights.data<T>(); const T* x = emission_weights.data<T>();
const T* w = transition_weights.data<T>(); const T* w = transition_weights.data<T>();
int64_t* path = decoded_path->data<int64_t>(); int64_t* path = decoded_path->data<int64_t>();
@ -84,221 +82,10 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track; Tensor track;
int* track_value = int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace()); track.mutable_data<int>(emission_dims, platform::CPUPlace());
const auto& ker = math::jitkernel::KernelPool::Instance()
#ifdef __AVX__ .template Get<math::jitkernel::CRFDecodeKernel<T>>(
// It use the AVX or AVX512 instruction to deal the data as the vector of 8 or static_cast<int>(tag_num));
// 16 elements per iteration. Then it can implement the parallel processing. ker->Compute(static_cast<int>(seq_len), x, w, alpha_value, track_value);
// Only optimize for float type.
#ifdef __AVX512F__
size_t step_size = 16;
#else
size_t step_size = 8;
#endif
if (std::is_same<T, float>::value && (tag_num >= step_size)) {
size_t steps = tag_num / step_size;
size_t remain = tag_num % step_size;
int last_offset = static_cast<int>(remain) - static_cast<int>(step_size);
// Setup the alpha initial value.
size_t i_offset = 0;
for (size_t i = 0; i <= steps; ++i) {
#ifdef __AVX512F__
// Declare the variable for the content of weights, input and alpha
// values.
__m512 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm512_loadu_ps((const float*)(w + i_offset));
x_content = _mm512_loadu_ps((const float*)(x + i_offset));
alpha_content = _mm512_add_ps(w_content, x_content);
// Save the alpha value.
_mm512_storeu_ps(reinterpret_cast<float*>(alpha_value + i_offset),
alpha_content);
#else
// Declare the variable for the content of weights, input and alpha
// values.
__m256 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm256_loadu_ps((const float*)(w + i_offset));
x_content = _mm256_loadu_ps((const float*)(x + i_offset));
alpha_content = _mm256_add_ps(w_content, x_content);
// Save the alpha value.
_mm256_storeu_ps(reinterpret_cast<float*>(alpha_value + i_offset),
alpha_content);
#endif
i_offset += step_size;
if (i == steps - 1) {
if (remain > 0) {
i_offset += last_offset;
} else {
break;
}
}
}
// Use the column-major strategy to get the location of maximum score.
size_t seq_offset = 0;
for (size_t k = 1; k < seq_len; ++k) {
size_t j_offset = 0;
for (size_t j = 0; j <= steps; ++j) {
#ifdef __AVX512F__
// Initialize the variables of maximum score and location.
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max());
__m512i max_j = _mm512_setzero_si512();
#else
// Initialize the variables of maximum score and location.
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<T>::max());
__m256i max_j = _mm256_set1_epi32(0);
#endif
// Calculate the offset of transition_weights.
size_t trans_offset = state_trans_base_idx * tag_num + j_offset;
for (size_t i = 0; i < tag_num; ++i) {
#ifdef __AVX512F__
// Initalize the content of alpha variable with related offset.
__m512 alpha_content =
_mm512_set1_ps(*(const float*)(alpha_value + seq_offset + i));
// Obtain the content of weights from un-aligned address.
__m512 w_content =
_mm512_loadu_ps((const float*)(w + trans_offset));
__m512 score_v = _mm512_add_ps(alpha_content, w_content);
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS);
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm512_mask_set1_epi32(max_j, mask, i);
// Update the max_score value.
max_score = _mm512_max_ps(max_score, score_v);
#else
// Initalize the content of alpha variable with related offset.
__m256 alpha_content = _mm256_broadcast_ss(
(const float*)(alpha_value + seq_offset + i));
// Obtain the content of weights from un-aligned address.
__m256 w_content =
_mm256_loadu_ps((const float*)(w + trans_offset));
__m256 score_v = _mm256_add_ps(alpha_content, w_content);
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS);
#ifdef __AVX2__
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm256_or_si256(
_mm256_andnot_si256((__m256i)mask, max_j),
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i)));
#else
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0);
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1);
__m128i lo_mask = _mm256_extractf128_si256((__m256i)mask, 0);
__m128i hi_mask = _mm256_extractf128_si256((__m256i)mask, 1);
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j);
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j);
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i));
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i));
lo_max_j = _mm_or_si128(lo_mask, lo_max_j);
hi_max_j = _mm_or_si128(hi_mask, hi_max_j);
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0);
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1);
#endif
// Update the max_score value.
max_score = _mm256_max_ps(max_score, score_v);
#endif
trans_offset += tag_num;
}
#ifdef __AVX512F__
// Update the alpha and track values.
__m512 x_content = _mm512_loadu_ps(
(const float*)(x + seq_offset + tag_num + j_offset));
max_score = _mm512_add_ps(max_score, x_content);
_mm512_storeu_ps(reinterpret_cast<float*>(alpha_value + seq_offset +
tag_num + j_offset),
max_score);
_mm512_storeu_si512(
reinterpret_cast<__m512i*>(track_value + seq_offset + tag_num +
j_offset),
max_j);
#else
// Update the alpha and track values.
__m256 x_content = _mm256_loadu_ps(
(const float*)(x + seq_offset + tag_num + j_offset));
max_score = _mm256_add_ps(max_score, x_content);
_mm256_storeu_ps(reinterpret_cast<float*>(alpha_value + seq_offset +
tag_num + j_offset),
max_score);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(track_value + seq_offset + tag_num +
j_offset),
max_j);
#endif
// Calculate the offset of next step
j_offset += step_size;
if (j == steps - 1) {
if (remain > 0) {
j_offset += last_offset;
} else {
break;
}
}
}
seq_offset += tag_num;
}
} else {
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
for (size_t k = 1; k < seq_len; ++k) {
for (size_t i = 0; i < tag_num; ++i) {
T max_score = -std::numeric_limits<T>::max();
int max_j = 0;
for (size_t j = 0; j < tag_num; ++j) {
T score = alpha_value[(k - 1) * tag_num + j] +
w[(j + state_trans_base_idx) * tag_num + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
track_value[k * tag_num + i] = max_j;
}
}
}
#else
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
for (size_t k = 1; k < seq_len; ++k) {
for (size_t i = 0; i < tag_num; ++i) {
T max_score = -std::numeric_limits<T>::max();
int max_j = 0;
for (size_t j = 0; j < tag_num; ++j) {
T score = alpha_value[(k - 1) * tag_num + j] +
w[(j + state_trans_base_idx) * tag_num + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
track_value[k * tag_num + i] = max_j;
}
}
#endif
T max_score = -std::numeric_limits<T>::max(); T max_score = -std::numeric_limits<T>::max();
int max_i = 0; int max_i = 0;
for (size_t i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {

@ -439,31 +439,88 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker { class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
// TODO(buxingyuan): Add Document AddInput(
AddInput("RpnRois", "RpnRois."); "RpnRois",
AddInput("GtClasses", "GtClasses."); "(LoDTensor), This input is a 2D LoDTensor with shape [N, 4]. "
AddInput("IsCrowd", "IsCrowd."); "N is the number of the GenerateProposalOp's output, "
AddInput("GtBoxes", "GtBoxes."); "each element is a bounding box with [xmin, ymin, xmax, ymax] format.");
AddInput("ImInfo", "ImInfo."); AddInput("GtClasses",
"(LoDTensor), This input is a 2D LoDTensor with shape [M, 1]. "
AddOutput("Rois", "Rois."); "M is the number of groundtruth, "
AddOutput("LabelsInt32", "LabelsInt32."); "each element is a class label of groundtruth.");
AddOutput("BboxTargets", "BboxTargets."); AddInput(
AddOutput("BboxInsideWeights", "BboxInsideWeights."); "IsCrowd",
AddOutput("BboxOutsideWeights", "BboxOutsideWeights."); "(LoDTensor), This input is a 2D LoDTensor with shape [M, 1]. "
"M is the number of groundtruth, "
AddAttr<int>("batch_size_per_im", "batch_size_per_im"); "each element is a flag indicates whether a groundtruth is crowd.");
AddAttr<float>("fg_fraction", "fg_fraction"); AddInput(
AddAttr<float>("fg_thresh", "fg_thresh"); "GtBoxes",
AddAttr<float>("bg_thresh_hi", "bg_thresh_hi"); "(LoDTensor), This input is a 2D LoDTensor with shape [M, 4]. "
AddAttr<float>("bg_thresh_lo", "bg_thresh_lo"); "M is the number of groundtruth, "
AddAttr<std::vector<float>>("bbox_reg_weights", "bbox_reg_weights"); "each element is a bounding box with [xmin, ymin, xmax, ymax] format.");
AddAttr<int>("class_nums", "class_nums"); AddInput("ImInfo",
AddAttr<bool>("use_random", "use_random").SetDefault(true); "(Tensor), This input is a 2D Tensor with shape [B, 3]. "
"B is the number of input images, "
"each element consists of im_height, im_width, im_scale.");
AddOutput(
"Rois",
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 4]. "
"P usuall equal to batch_size_per_im * batch_size, "
"each element is a bounding box with [xmin, ymin, xmax, ymax] format.");
AddOutput("LabelsInt32",
"(LoDTensor), This output is a 2D LoDTensor with shape [P], "
"each element repersents a class label of a roi");
AddOutput("BboxTargets",
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 4 * "
"class_nums], "
"each element repersents a box label of a roi");
AddOutput(
"BboxInsideWeights",
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 4 * "
"class_nums], "
"each element indicates whether a box should contribute to loss.");
AddOutput(
"BboxOutsideWeights",
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 4 * "
"class_nums], "
"each element indicates whether a box should contribute to loss.");
AddAttr<int>("batch_size_per_im", "Batch size of rois per images.");
AddAttr<float>("fg_fraction",
"Foreground fraction in total batch_size_per_im.");
AddAttr<float>(
"fg_thresh",
"Overlap threshold which is used to chose foreground sample.");
AddAttr<float>("bg_thresh_hi",
"Overlap threshold upper bound which is used to chose "
"background sample.");
AddAttr<float>("bg_thresh_lo",
"Overlap threshold lower bound which is used to chose "
"background sample.");
AddAttr<std::vector<float>>("bbox_reg_weights", "Box regression weights.");
AddAttr<int>("class_nums", "Class number.");
AddAttr<bool>(
"use_random",
"Use random sampling to choose foreground and background boxes.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Generate Proposals Labels Operator. This operator can be, for given the GenerateProposalOp output bounding boxes and groundtruth,
)DOC"); to sample foreground boxes and background boxes, and compute loss target.
RpnRois is the output boxes of RPN and was processed by generate_proposal_op, these boxes
were combined with groundtruth boxes and sampled according to batch_size_per_im and fg_fraction,
If an instance with a groundtruth overlap greater than fg_thresh, then it was considered as a foreground sample.
If an instance with a groundtruth overlap greater than bg_thresh_lo and lower than bg_thresh_hi,
then it was considered as a background sample.
After all foreground and background boxes are chosen (so called Rois),
then we apply random sampling to make sure
the number of foreground boxes is no more than batch_size_per_im * fg_fraction.
For each box in Rois, we assign the classification (class label) and regression targets (box label) to it.
Finally BboxInsideWeights and BboxOutsideWeights are used to specify whether it would contribute to training loss.
)DOC");
} }
}; };

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save