Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into pixel_softmax_layer
commit
3aa679814f
@ -1,5 +1,8 @@
|
||||
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
|
||||
target_link_libraries(paddle_go_optimizer stdc++ m)
|
||||
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
|
||||
if(WITH_TESTING)
|
||||
add_subdirectory(test)
|
||||
# FIXME: this test requires pserver which is not managed by the test
|
||||
# we need some kind of e2e testing machanism.
|
||||
# add_subdirectory(test)
|
||||
endif()
|
||||
|
@ -1,2 +1,2 @@
|
||||
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient)
|
||||
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer)
|
||||
add_style_check_target(test_cclient test_cclient.c)
|
||||
|
@ -0,0 +1,119 @@
|
||||
#pragma once
|
||||
|
||||
#include <boost/variant.hpp>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "paddle/framework/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
|
||||
std::vector<float>, std::vector<std::string>>
|
||||
Attribute;
|
||||
typedef std::unordered_map<std::string, Attribute> AttributeMap;
|
||||
|
||||
// check whether a value(attribute) fit a certain limit
|
||||
template <typename T>
|
||||
class LargerThanChecker {
|
||||
public:
|
||||
LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
|
||||
void operator()(T& value) const {
|
||||
PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
|
||||
}
|
||||
|
||||
private:
|
||||
T lower_bound_;
|
||||
};
|
||||
|
||||
// we can provide users more common Checker, like 'LessThanChecker',
|
||||
// 'BetweenChecker'...
|
||||
|
||||
template <typename T>
|
||||
class DefaultValueSetter {
|
||||
public:
|
||||
DefaultValueSetter(T default_value) : default_value_(default_value) {}
|
||||
void operator()(T& value) const { value = default_value_; }
|
||||
|
||||
private:
|
||||
T default_value_;
|
||||
};
|
||||
|
||||
// check whether a certain attribute fit its limits
|
||||
// an attribute can have more than one limits
|
||||
template <typename T>
|
||||
class TypedAttrChecker {
|
||||
typedef std::function<void(T&)> ValueChecker;
|
||||
|
||||
public:
|
||||
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
|
||||
|
||||
TypedAttrChecker& LargerThan(const T& lower_bound) {
|
||||
value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
|
||||
return *this;
|
||||
}
|
||||
|
||||
// we can add more common limits, like LessThan(), Between()...
|
||||
|
||||
TypedAttrChecker& SetDefault(const T& default_value) {
|
||||
PADDLE_ENFORCE(default_value_setter_.empty(),
|
||||
"%s can't have more than one default value!", attr_name_);
|
||||
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
|
||||
return *this;
|
||||
}
|
||||
|
||||
// allow users provide their own checker
|
||||
TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
|
||||
value_checkers_.push_back(checker);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void operator()(AttributeMap& attr_map) const {
|
||||
if (!attr_map.count(attr_name_)) {
|
||||
// user do not set this attr
|
||||
PADDLE_ENFORCE(!default_value_setter_.empty(),
|
||||
"Attribute '%s' is required!", attr_name_);
|
||||
// default_value_setter_ has no more than one element
|
||||
T val;
|
||||
(default_value_setter_[0])(val);
|
||||
attr_map[attr_name_] = val;
|
||||
}
|
||||
Attribute& attr = attr_map.at(attr_name_);
|
||||
T& attr_value = boost::get<T>(attr);
|
||||
for (const auto& checker : value_checkers_) {
|
||||
checker(attr_value);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::string attr_name_;
|
||||
std::vector<ValueChecker> value_checkers_;
|
||||
std::vector<ValueChecker> default_value_setter_;
|
||||
};
|
||||
|
||||
// check whether op's all attributes fit their own limits
|
||||
class OpAttrChecker {
|
||||
typedef std::function<void(AttributeMap&)> AttrChecker;
|
||||
|
||||
public:
|
||||
template <typename T>
|
||||
TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name) {
|
||||
attr_checkers_.push_back(TypedAttrChecker<T>(attr_name));
|
||||
AttrChecker& checker = attr_checkers_.back();
|
||||
return *(checker.target<TypedAttrChecker<T>>());
|
||||
}
|
||||
|
||||
void Check(AttributeMap& attr_map) const {
|
||||
for (const auto& checker : attr_checkers_) {
|
||||
checker(attr_map);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<AttrChecker> attr_checkers_;
|
||||
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,122 @@
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
TEST(OpRegistry, CreateOp) {
|
||||
paddle::framework::OpDesc op_desc;
|
||||
op_desc.set_type("cos_sim");
|
||||
op_desc.add_inputs("aa");
|
||||
op_desc.add_outputs("bb");
|
||||
|
||||
auto attr = op_desc.mutable_attrs()->Add();
|
||||
attr->set_name("scale");
|
||||
attr->set_type(paddle::framework::AttrType::FLOAT);
|
||||
attr->set_f(3.3);
|
||||
|
||||
paddle::framework::OpBase* op =
|
||||
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||
std::string debug_str = op->Run();
|
||||
std::string str = "CosineOp runs! scale = " + std::to_string(3.3);
|
||||
ASSERT_EQ(str.size(), debug_str.size());
|
||||
for (size_t i = 0; i < debug_str.length(); ++i) {
|
||||
ASSERT_EQ(debug_str[i], str[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OpRegistry, IllegalAttr) {
|
||||
paddle::framework::OpDesc op_desc;
|
||||
op_desc.set_type("cos_sim");
|
||||
op_desc.add_inputs("aa");
|
||||
op_desc.add_outputs("bb");
|
||||
|
||||
auto attr = op_desc.mutable_attrs()->Add();
|
||||
attr->set_name("scale");
|
||||
attr->set_type(paddle::framework::AttrType::FLOAT);
|
||||
attr->set_f(-2.0);
|
||||
|
||||
bool caught = false;
|
||||
try {
|
||||
paddle::framework::OpBase* op __attribute__((unused)) =
|
||||
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||
} catch (paddle::framework::EnforceNotMet err) {
|
||||
caught = true;
|
||||
std::string msg = "larger_than check fail";
|
||||
const char* err_msg = err.what();
|
||||
for (size_t i = 0; i < msg.length(); ++i) {
|
||||
ASSERT_EQ(err_msg[i], msg[i]);
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(caught);
|
||||
}
|
||||
|
||||
TEST(OpRegistry, DefaultValue) {
|
||||
paddle::framework::OpDesc op_desc;
|
||||
op_desc.set_type("cos_sim");
|
||||
op_desc.add_inputs("aa");
|
||||
op_desc.add_outputs("bb");
|
||||
|
||||
paddle::framework::OpBase* op =
|
||||
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||
std::string debug_str = op->Run();
|
||||
float default_value = 1.0;
|
||||
std::string str = "CosineOp runs! scale = " + std::to_string(default_value);
|
||||
ASSERT_EQ(str.size(), debug_str.size());
|
||||
for (size_t i = 0; i < debug_str.length(); ++i) {
|
||||
ASSERT_EQ(debug_str[i], str[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OpRegistry, CustomChecker) {
|
||||
paddle::framework::OpDesc op_desc;
|
||||
op_desc.set_type("my_test_op");
|
||||
op_desc.add_inputs("ii");
|
||||
op_desc.add_outputs("oo");
|
||||
|
||||
// attr 'test_attr' is not set
|
||||
bool caught = false;
|
||||
try {
|
||||
paddle::framework::OpBase* op __attribute__((unused)) =
|
||||
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||
} catch (paddle::framework::EnforceNotMet err) {
|
||||
caught = true;
|
||||
std::string msg = "Attribute 'test_attr' is required!";
|
||||
const char* err_msg = err.what();
|
||||
for (size_t i = 0; i < msg.length(); ++i) {
|
||||
ASSERT_EQ(err_msg[i], msg[i]);
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(caught);
|
||||
|
||||
// set 'test_attr' set to an illegal value
|
||||
auto attr = op_desc.mutable_attrs()->Add();
|
||||
attr->set_name("test_attr");
|
||||
attr->set_type(paddle::framework::AttrType::INT);
|
||||
attr->set_i(3);
|
||||
caught = false;
|
||||
try {
|
||||
paddle::framework::OpBase* op __attribute__((unused)) =
|
||||
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||
} catch (paddle::framework::EnforceNotMet err) {
|
||||
caught = true;
|
||||
std::string msg = "'test_attr' must be even!";
|
||||
const char* err_msg = err.what();
|
||||
for (size_t i = 0; i < msg.length(); ++i) {
|
||||
ASSERT_EQ(err_msg[i], msg[i]);
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(caught);
|
||||
|
||||
// set 'test_attr' set to a legal value
|
||||
op_desc.mutable_attrs()->Clear();
|
||||
attr = op_desc.mutable_attrs()->Add();
|
||||
attr->set_name("test_attr");
|
||||
attr->set_type(paddle::framework::AttrType::INT);
|
||||
attr->set_i(4);
|
||||
paddle::framework::OpBase* op =
|
||||
paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||
std::string debug_str = op->Run();
|
||||
std::string str = "MyTestOp runs! test_attr = " + std::to_string(4);
|
||||
ASSERT_EQ(str.size(), debug_str.size());
|
||||
for (size_t i = 0; i < debug_str.length(); ++i) {
|
||||
ASSERT_EQ(debug_str[i], str[i]);
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue