Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into crop_layer
commit
de5ded6bbd
@ -1,6 +1,8 @@
|
|||||||
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
|
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
|
||||||
|
target_link_libraries(paddle_go_optimizer stdc++ m)
|
||||||
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
|
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
|
||||||
if(WITH_TESTING)
|
if(WITH_TESTING)
|
||||||
# TODO: add unit test
|
# FIXME: this test requires pserver which is not managed by the test
|
||||||
|
# we need some kind of e2e testing machanism.
|
||||||
# add_subdirectory(test)
|
# add_subdirectory(test)
|
||||||
endif()
|
endif()
|
||||||
|
@ -0,0 +1,119 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <boost/variant.hpp>
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/framework/enforce.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
|
||||||
|
std::vector<float>, std::vector<std::string>>
|
||||||
|
Attribute;
|
||||||
|
typedef std::unordered_map<std::string, Attribute> AttributeMap;
|
||||||
|
|
||||||
|
// check whether a value(attribute) fit a certain limit
|
||||||
|
template <typename T>
|
||||||
|
class LargerThanChecker {
|
||||||
|
public:
|
||||||
|
LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
|
||||||
|
void operator()(T& value) const {
|
||||||
|
PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
T lower_bound_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// we can provide users more common Checker, like 'LessThanChecker',
|
||||||
|
// 'BetweenChecker'...
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class DefaultValueSetter {
|
||||||
|
public:
|
||||||
|
DefaultValueSetter(T default_value) : default_value_(default_value) {}
|
||||||
|
void operator()(T& value) const { value = default_value_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
T default_value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// check whether a certain attribute fit its limits
|
||||||
|
// an attribute can have more than one limits
|
||||||
|
template <typename T>
|
||||||
|
class TypedAttrChecker {
|
||||||
|
typedef std::function<void(T&)> ValueChecker;
|
||||||
|
|
||||||
|
public:
|
||||||
|
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
|
||||||
|
|
||||||
|
TypedAttrChecker& LargerThan(const T& lower_bound) {
|
||||||
|
value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// we can add more common limits, like LessThan(), Between()...
|
||||||
|
|
||||||
|
TypedAttrChecker& SetDefault(const T& default_value) {
|
||||||
|
PADDLE_ENFORCE(default_value_setter_.empty(),
|
||||||
|
"%s can't have more than one default value!", attr_name_);
|
||||||
|
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow users provide their own checker
|
||||||
|
TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
|
||||||
|
value_checkers_.push_back(checker);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(AttributeMap& attr_map) const {
|
||||||
|
if (!attr_map.count(attr_name_)) {
|
||||||
|
// user do not set this attr
|
||||||
|
PADDLE_ENFORCE(!default_value_setter_.empty(),
|
||||||
|
"Attribute '%s' is required!", attr_name_);
|
||||||
|
// default_value_setter_ has no more than one element
|
||||||
|
T val;
|
||||||
|
(default_value_setter_[0])(val);
|
||||||
|
attr_map[attr_name_] = val;
|
||||||
|
}
|
||||||
|
Attribute& attr = attr_map.at(attr_name_);
|
||||||
|
T& attr_value = boost::get<T>(attr);
|
||||||
|
for (const auto& checker : value_checkers_) {
|
||||||
|
checker(attr_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string attr_name_;
|
||||||
|
std::vector<ValueChecker> value_checkers_;
|
||||||
|
std::vector<ValueChecker> default_value_setter_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// check whether op's all attributes fit their own limits
|
||||||
|
class OpAttrChecker {
|
||||||
|
typedef std::function<void(AttributeMap&)> AttrChecker;
|
||||||
|
|
||||||
|
public:
|
||||||
|
template <typename T>
|
||||||
|
TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name) {
|
||||||
|
attr_checkers_.push_back(TypedAttrChecker<T>(attr_name));
|
||||||
|
AttrChecker& checker = attr_checkers_.back();
|
||||||
|
return *(checker.target<TypedAttrChecker<T>>());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Check(AttributeMap& attr_map) const {
|
||||||
|
for (const auto& checker : attr_checkers_) {
|
||||||
|
checker(attr_map);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<AttrChecker> attr_checkers_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,122 @@
|
|||||||
|
#include "paddle/framework/op_registry.h"
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
TEST(OpRegistry, CreateOp) {
|
||||||
|
paddle::framework::OpDesc op_desc;
|
||||||
|
op_desc.set_type("cos_sim");
|
||||||
|
op_desc.add_inputs("aa");
|
||||||
|
op_desc.add_outputs("bb");
|
||||||
|
|
||||||
|
auto attr = op_desc.mutable_attrs()->Add();
|
||||||
|
attr->set_name("scale");
|
||||||
|
attr->set_type(paddle::framework::AttrType::FLOAT);
|
||||||
|
attr->set_f(3.3);
|
||||||
|
|
||||||
|
paddle::framework::OpBase* op =
|
||||||
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||||
|
std::string debug_str = op->Run();
|
||||||
|
std::string str = "CosineOp runs! scale = " + std::to_string(3.3);
|
||||||
|
ASSERT_EQ(str.size(), debug_str.size());
|
||||||
|
for (size_t i = 0; i < debug_str.length(); ++i) {
|
||||||
|
ASSERT_EQ(debug_str[i], str[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OpRegistry, IllegalAttr) {
|
||||||
|
paddle::framework::OpDesc op_desc;
|
||||||
|
op_desc.set_type("cos_sim");
|
||||||
|
op_desc.add_inputs("aa");
|
||||||
|
op_desc.add_outputs("bb");
|
||||||
|
|
||||||
|
auto attr = op_desc.mutable_attrs()->Add();
|
||||||
|
attr->set_name("scale");
|
||||||
|
attr->set_type(paddle::framework::AttrType::FLOAT);
|
||||||
|
attr->set_f(-2.0);
|
||||||
|
|
||||||
|
bool caught = false;
|
||||||
|
try {
|
||||||
|
paddle::framework::OpBase* op __attribute__((unused)) =
|
||||||
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||||
|
} catch (paddle::framework::EnforceNotMet err) {
|
||||||
|
caught = true;
|
||||||
|
std::string msg = "larger_than check fail";
|
||||||
|
const char* err_msg = err.what();
|
||||||
|
for (size_t i = 0; i < msg.length(); ++i) {
|
||||||
|
ASSERT_EQ(err_msg[i], msg[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(caught);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OpRegistry, DefaultValue) {
|
||||||
|
paddle::framework::OpDesc op_desc;
|
||||||
|
op_desc.set_type("cos_sim");
|
||||||
|
op_desc.add_inputs("aa");
|
||||||
|
op_desc.add_outputs("bb");
|
||||||
|
|
||||||
|
paddle::framework::OpBase* op =
|
||||||
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||||
|
std::string debug_str = op->Run();
|
||||||
|
float default_value = 1.0;
|
||||||
|
std::string str = "CosineOp runs! scale = " + std::to_string(default_value);
|
||||||
|
ASSERT_EQ(str.size(), debug_str.size());
|
||||||
|
for (size_t i = 0; i < debug_str.length(); ++i) {
|
||||||
|
ASSERT_EQ(debug_str[i], str[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OpRegistry, CustomChecker) {
|
||||||
|
paddle::framework::OpDesc op_desc;
|
||||||
|
op_desc.set_type("my_test_op");
|
||||||
|
op_desc.add_inputs("ii");
|
||||||
|
op_desc.add_outputs("oo");
|
||||||
|
|
||||||
|
// attr 'test_attr' is not set
|
||||||
|
bool caught = false;
|
||||||
|
try {
|
||||||
|
paddle::framework::OpBase* op __attribute__((unused)) =
|
||||||
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||||
|
} catch (paddle::framework::EnforceNotMet err) {
|
||||||
|
caught = true;
|
||||||
|
std::string msg = "Attribute 'test_attr' is required!";
|
||||||
|
const char* err_msg = err.what();
|
||||||
|
for (size_t i = 0; i < msg.length(); ++i) {
|
||||||
|
ASSERT_EQ(err_msg[i], msg[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(caught);
|
||||||
|
|
||||||
|
// set 'test_attr' set to an illegal value
|
||||||
|
auto attr = op_desc.mutable_attrs()->Add();
|
||||||
|
attr->set_name("test_attr");
|
||||||
|
attr->set_type(paddle::framework::AttrType::INT);
|
||||||
|
attr->set_i(3);
|
||||||
|
caught = false;
|
||||||
|
try {
|
||||||
|
paddle::framework::OpBase* op __attribute__((unused)) =
|
||||||
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||||
|
} catch (paddle::framework::EnforceNotMet err) {
|
||||||
|
caught = true;
|
||||||
|
std::string msg = "'test_attr' must be even!";
|
||||||
|
const char* err_msg = err.what();
|
||||||
|
for (size_t i = 0; i < msg.length(); ++i) {
|
||||||
|
ASSERT_EQ(err_msg[i], msg[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(caught);
|
||||||
|
|
||||||
|
// set 'test_attr' set to a legal value
|
||||||
|
op_desc.mutable_attrs()->Clear();
|
||||||
|
attr = op_desc.mutable_attrs()->Add();
|
||||||
|
attr->set_name("test_attr");
|
||||||
|
attr->set_type(paddle::framework::AttrType::INT);
|
||||||
|
attr->set_i(4);
|
||||||
|
paddle::framework::OpBase* op =
|
||||||
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||||
|
std::string debug_str = op->Run();
|
||||||
|
std::string str = "MyTestOp runs! test_attr = " + std::to_string(4);
|
||||||
|
ASSERT_EQ(str.size(), debug_str.size());
|
||||||
|
for (size_t i = 0; i < debug_str.length(); ++i) {
|
||||||
|
ASSERT_EQ(debug_str[i], str[i]);
|
||||||
|
}
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue