Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into crop_layer
	
		
	
				
					
				
			
						commit
						de5ded6bbd
					
				@ -1,6 +1,8 @@
 | 
				
			|||||||
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
 | 
					cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
 | 
				
			||||||
 | 
					target_link_libraries(paddle_go_optimizer stdc++ m)
 | 
				
			||||||
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
 | 
					go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
 | 
				
			||||||
if(WITH_TESTING)
 | 
					if(WITH_TESTING)
 | 
				
			||||||
    # TODO: add unit test
 | 
					  # FIXME: this test requires pserver which is not managed by the test
 | 
				
			||||||
    #add_subdirectory(test)
 | 
					  # we need some kind of e2e testing machanism.
 | 
				
			||||||
 | 
					  # add_subdirectory(test)
 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,119 @@
 | 
				
			|||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <boost/variant.hpp>
 | 
				
			||||||
 | 
					#include <functional>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <unordered_map>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					#include "paddle/framework/enforce.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace framework {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
 | 
				
			||||||
 | 
					                       std::vector<float>, std::vector<std::string>>
 | 
				
			||||||
 | 
					    Attribute;
 | 
				
			||||||
 | 
					typedef std::unordered_map<std::string, Attribute> AttributeMap;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// check whether a value(attribute) fit a certain limit
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					class LargerThanChecker {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
 | 
				
			||||||
 | 
					  void operator()(T& value) const {
 | 
				
			||||||
 | 
					    PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  T lower_bound_;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// we can provide users more common Checker, like 'LessThanChecker',
 | 
				
			||||||
 | 
					// 'BetweenChecker'...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					class DefaultValueSetter {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  DefaultValueSetter(T default_value) : default_value_(default_value) {}
 | 
				
			||||||
 | 
					  void operator()(T& value) const { value = default_value_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  T default_value_;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// check whether a certain attribute fit its limits
 | 
				
			||||||
 | 
					// an attribute can have more than one limits
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					class TypedAttrChecker {
 | 
				
			||||||
 | 
					  typedef std::function<void(T&)> ValueChecker;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  TypedAttrChecker& LargerThan(const T& lower_bound) {
 | 
				
			||||||
 | 
					    value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
 | 
				
			||||||
 | 
					    return *this;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // we can add more common limits, like LessThan(), Between()...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  TypedAttrChecker& SetDefault(const T& default_value) {
 | 
				
			||||||
 | 
					    PADDLE_ENFORCE(default_value_setter_.empty(),
 | 
				
			||||||
 | 
					                   "%s can't have more than one default value!", attr_name_);
 | 
				
			||||||
 | 
					    default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
 | 
				
			||||||
 | 
					    return *this;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // allow users provide their own checker
 | 
				
			||||||
 | 
					  TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
 | 
				
			||||||
 | 
					    value_checkers_.push_back(checker);
 | 
				
			||||||
 | 
					    return *this;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void operator()(AttributeMap& attr_map) const {
 | 
				
			||||||
 | 
					    if (!attr_map.count(attr_name_)) {
 | 
				
			||||||
 | 
					      // user do not set this attr
 | 
				
			||||||
 | 
					      PADDLE_ENFORCE(!default_value_setter_.empty(),
 | 
				
			||||||
 | 
					                     "Attribute '%s' is required!", attr_name_);
 | 
				
			||||||
 | 
					      // default_value_setter_ has no more than one element
 | 
				
			||||||
 | 
					      T val;
 | 
				
			||||||
 | 
					      (default_value_setter_[0])(val);
 | 
				
			||||||
 | 
					      attr_map[attr_name_] = val;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    Attribute& attr = attr_map.at(attr_name_);
 | 
				
			||||||
 | 
					    T& attr_value = boost::get<T>(attr);
 | 
				
			||||||
 | 
					    for (const auto& checker : value_checkers_) {
 | 
				
			||||||
 | 
					      checker(attr_value);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  std::string attr_name_;
 | 
				
			||||||
 | 
					  std::vector<ValueChecker> value_checkers_;
 | 
				
			||||||
 | 
					  std::vector<ValueChecker> default_value_setter_;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// check whether op's all attributes fit their own limits
 | 
				
			||||||
 | 
					class OpAttrChecker {
 | 
				
			||||||
 | 
					  typedef std::function<void(AttributeMap&)> AttrChecker;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  template <typename T>
 | 
				
			||||||
 | 
					  TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name) {
 | 
				
			||||||
 | 
					    attr_checkers_.push_back(TypedAttrChecker<T>(attr_name));
 | 
				
			||||||
 | 
					    AttrChecker& checker = attr_checkers_.back();
 | 
				
			||||||
 | 
					    return *(checker.target<TypedAttrChecker<T>>());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void Check(AttributeMap& attr_map) const {
 | 
				
			||||||
 | 
					    for (const auto& checker : attr_checkers_) {
 | 
				
			||||||
 | 
					      checker(attr_map);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  std::vector<AttrChecker> attr_checkers_;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace framework
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,122 @@
 | 
				
			|||||||
 | 
					#include "paddle/framework/op_registry.h"
 | 
				
			||||||
 | 
					#include <gtest/gtest.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(OpRegistry, CreateOp) {
 | 
				
			||||||
 | 
					  paddle::framework::OpDesc op_desc;
 | 
				
			||||||
 | 
					  op_desc.set_type("cos_sim");
 | 
				
			||||||
 | 
					  op_desc.add_inputs("aa");
 | 
				
			||||||
 | 
					  op_desc.add_outputs("bb");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto attr = op_desc.mutable_attrs()->Add();
 | 
				
			||||||
 | 
					  attr->set_name("scale");
 | 
				
			||||||
 | 
					  attr->set_type(paddle::framework::AttrType::FLOAT);
 | 
				
			||||||
 | 
					  attr->set_f(3.3);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  paddle::framework::OpBase* op =
 | 
				
			||||||
 | 
					      paddle::framework::OpRegistry::CreateOp(op_desc);
 | 
				
			||||||
 | 
					  std::string debug_str = op->Run();
 | 
				
			||||||
 | 
					  std::string str = "CosineOp runs! scale = " + std::to_string(3.3);
 | 
				
			||||||
 | 
					  ASSERT_EQ(str.size(), debug_str.size());
 | 
				
			||||||
 | 
					  for (size_t i = 0; i < debug_str.length(); ++i) {
 | 
				
			||||||
 | 
					    ASSERT_EQ(debug_str[i], str[i]);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(OpRegistry, IllegalAttr) {
 | 
				
			||||||
 | 
					  paddle::framework::OpDesc op_desc;
 | 
				
			||||||
 | 
					  op_desc.set_type("cos_sim");
 | 
				
			||||||
 | 
					  op_desc.add_inputs("aa");
 | 
				
			||||||
 | 
					  op_desc.add_outputs("bb");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto attr = op_desc.mutable_attrs()->Add();
 | 
				
			||||||
 | 
					  attr->set_name("scale");
 | 
				
			||||||
 | 
					  attr->set_type(paddle::framework::AttrType::FLOAT);
 | 
				
			||||||
 | 
					  attr->set_f(-2.0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool caught = false;
 | 
				
			||||||
 | 
					  try {
 | 
				
			||||||
 | 
					    paddle::framework::OpBase* op __attribute__((unused)) =
 | 
				
			||||||
 | 
					        paddle::framework::OpRegistry::CreateOp(op_desc);
 | 
				
			||||||
 | 
					  } catch (paddle::framework::EnforceNotMet err) {
 | 
				
			||||||
 | 
					    caught = true;
 | 
				
			||||||
 | 
					    std::string msg = "larger_than check fail";
 | 
				
			||||||
 | 
					    const char* err_msg = err.what();
 | 
				
			||||||
 | 
					    for (size_t i = 0; i < msg.length(); ++i) {
 | 
				
			||||||
 | 
					      ASSERT_EQ(err_msg[i], msg[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  ASSERT_TRUE(caught);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(OpRegistry, DefaultValue) {
 | 
				
			||||||
 | 
					  paddle::framework::OpDesc op_desc;
 | 
				
			||||||
 | 
					  op_desc.set_type("cos_sim");
 | 
				
			||||||
 | 
					  op_desc.add_inputs("aa");
 | 
				
			||||||
 | 
					  op_desc.add_outputs("bb");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  paddle::framework::OpBase* op =
 | 
				
			||||||
 | 
					      paddle::framework::OpRegistry::CreateOp(op_desc);
 | 
				
			||||||
 | 
					  std::string debug_str = op->Run();
 | 
				
			||||||
 | 
					  float default_value = 1.0;
 | 
				
			||||||
 | 
					  std::string str = "CosineOp runs! scale = " + std::to_string(default_value);
 | 
				
			||||||
 | 
					  ASSERT_EQ(str.size(), debug_str.size());
 | 
				
			||||||
 | 
					  for (size_t i = 0; i < debug_str.length(); ++i) {
 | 
				
			||||||
 | 
					    ASSERT_EQ(debug_str[i], str[i]);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(OpRegistry, CustomChecker) {
 | 
				
			||||||
 | 
					  paddle::framework::OpDesc op_desc;
 | 
				
			||||||
 | 
					  op_desc.set_type("my_test_op");
 | 
				
			||||||
 | 
					  op_desc.add_inputs("ii");
 | 
				
			||||||
 | 
					  op_desc.add_outputs("oo");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // attr 'test_attr' is not set
 | 
				
			||||||
 | 
					  bool caught = false;
 | 
				
			||||||
 | 
					  try {
 | 
				
			||||||
 | 
					    paddle::framework::OpBase* op __attribute__((unused)) =
 | 
				
			||||||
 | 
					        paddle::framework::OpRegistry::CreateOp(op_desc);
 | 
				
			||||||
 | 
					  } catch (paddle::framework::EnforceNotMet err) {
 | 
				
			||||||
 | 
					    caught = true;
 | 
				
			||||||
 | 
					    std::string msg = "Attribute 'test_attr' is required!";
 | 
				
			||||||
 | 
					    const char* err_msg = err.what();
 | 
				
			||||||
 | 
					    for (size_t i = 0; i < msg.length(); ++i) {
 | 
				
			||||||
 | 
					      ASSERT_EQ(err_msg[i], msg[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  ASSERT_TRUE(caught);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // set 'test_attr' set to an illegal value
 | 
				
			||||||
 | 
					  auto attr = op_desc.mutable_attrs()->Add();
 | 
				
			||||||
 | 
					  attr->set_name("test_attr");
 | 
				
			||||||
 | 
					  attr->set_type(paddle::framework::AttrType::INT);
 | 
				
			||||||
 | 
					  attr->set_i(3);
 | 
				
			||||||
 | 
					  caught = false;
 | 
				
			||||||
 | 
					  try {
 | 
				
			||||||
 | 
					    paddle::framework::OpBase* op __attribute__((unused)) =
 | 
				
			||||||
 | 
					        paddle::framework::OpRegistry::CreateOp(op_desc);
 | 
				
			||||||
 | 
					  } catch (paddle::framework::EnforceNotMet err) {
 | 
				
			||||||
 | 
					    caught = true;
 | 
				
			||||||
 | 
					    std::string msg = "'test_attr' must be even!";
 | 
				
			||||||
 | 
					    const char* err_msg = err.what();
 | 
				
			||||||
 | 
					    for (size_t i = 0; i < msg.length(); ++i) {
 | 
				
			||||||
 | 
					      ASSERT_EQ(err_msg[i], msg[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  ASSERT_TRUE(caught);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // set 'test_attr' set to a legal value
 | 
				
			||||||
 | 
					  op_desc.mutable_attrs()->Clear();
 | 
				
			||||||
 | 
					  attr = op_desc.mutable_attrs()->Add();
 | 
				
			||||||
 | 
					  attr->set_name("test_attr");
 | 
				
			||||||
 | 
					  attr->set_type(paddle::framework::AttrType::INT);
 | 
				
			||||||
 | 
					  attr->set_i(4);
 | 
				
			||||||
 | 
					  paddle::framework::OpBase* op =
 | 
				
			||||||
 | 
					      paddle::framework::OpRegistry::CreateOp(op_desc);
 | 
				
			||||||
 | 
					  std::string debug_str = op->Run();
 | 
				
			||||||
 | 
					  std::string str = "MyTestOp runs! test_attr = " + std::to_string(4);
 | 
				
			||||||
 | 
					  ASSERT_EQ(str.size(), debug_str.size());
 | 
				
			||||||
 | 
					  for (size_t i = 0; i < debug_str.length(); ++i) {
 | 
				
			||||||
 | 
					    ASSERT_EQ(debug_str[i], str[i]);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
Some files were not shown because too many files have changed in this diff Show More
					Loading…
					
					
				
		Reference in new issue