From 8fd845e0fab40acc9539c4109feb3ed411f4dc8b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 30 Sep 2017 16:55:40 -0700 Subject: [PATCH 1/2] Unify Map in OpDescBind --- paddle/framework/op_desc.cc | 27 ++++++++++++++++++++++++++- paddle/framework/op_desc.h | 37 ++++++------------------------------- paddle/platform/enforce.h | 4 ++-- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 0c12c55dc0..33a064890c 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -112,6 +112,30 @@ const std::unordered_map &OpDescBind::GetAttrMap() return attrs_; } +struct SetAttrDescVisitor : public boost::static_visitor { + explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} + mutable OpDesc::Attr *attr_; + void operator()(int v) const { attr_->set_i(v); } + void operator()(float v) const { attr_->set_f(v); } + void operator()(const std::string &v) const { attr_->set_s(v); } + void operator()(bool b) const { attr_->set_b(b); } + + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_ints()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_floats()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_strings()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_bools()); + } + void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); } + void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } +}; + void OpDescBind::Sync() { if (need_update_) { this->op_desc_.mutable_inputs()->Clear(); @@ -134,7 +158,8 @@ void OpDescBind::Sync() { attr_desc->set_name(attr.first); attr_desc->set_type( static_cast(attr.second.which() - 1)); - boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second); + SetAttrDescVisitor visitor(attr_desc); + boost::apply_visitor(visitor, attr.second); } need_update_ = false; diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 0cf7d13971..e03b4d067f 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/framework/attribute.h" +#include "paddle/framework/op_info.h" #include "paddle/framework/var_desc.h" namespace paddle { @@ -61,48 +62,22 @@ class OpDescBind { void SetBlockAttr(const std::string &name, BlockDescBind &block); // Only be used in C++ - void SetAttrMap(const std::unordered_map &attr_map); + void SetAttrMap(const AttributeMap &attr_map); Attribute GetAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const; // Only be used in C++ - const std::unordered_map &GetAttrMap() const; + const AttributeMap &GetAttrMap() const; private: - struct SetAttrDescVisitor : public boost::static_visitor { - explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} - mutable OpDesc::Attr *attr_; - void operator()(int v) const { attr_->set_i(v); } - void operator()(float v) const { attr_->set_f(v); } - void operator()(const std::string &v) const { attr_->set_s(v); } - void operator()(bool b) const { attr_->set_b(b); } - - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_ints()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_floats()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_strings()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_bools()); - } - void operator()(BlockDesc *desc) const { - attr_->set_block_idx(desc->idx()); - } - void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } - }; - void Sync(); OpDesc op_desc_; - std::unordered_map> inputs_; - std::unordered_map> outputs_; - std::unordered_map attrs_; + VariableNameMap inputs_; + VariableNameMap outputs_; + AttributeMap attrs_; // need_update_ indicate there some local changes not be synchronized. If // local changes should be synchronized, need_update_ should be set to true. diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index b523ef03c0..52bd23039b 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -185,7 +185,7 @@ inline void throw_on_error(T e) { std::make_exception_ptr( \ std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ __FILE__, __LINE__); \ - } while (0) + } while (false) #define PADDLE_ENFORCE(...) \ do { \ @@ -195,7 +195,7 @@ inline void throw_on_error(T e) { throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ __FILE__, __LINE__); \ } \ - } while (0) + } while (false) /* * Some enforce helpers here, usage: From 2296d81cf9560437b368354229b7ceb22b67d234 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 2 Oct 2017 11:39:10 -0700 Subject: [PATCH 2/2] Use `type_defs.h` to resolve cyclic dependencies --- paddle/framework/attribute.h | 10 +--------- paddle/framework/op_desc.h | 2 +- paddle/framework/op_info.h | 7 +------ paddle/framework/type_defs.h | 38 ++++++++++++++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 16 deletions(-) create mode 100644 paddle/framework/type_defs.h diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index c7559cefb6..d13530e340 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -21,20 +21,12 @@ limitations under the License. */ #include #include "paddle/framework/framework.pb.h" +#include "paddle/framework/type_defs.h" #include "paddle/platform/enforce.h" -#include "paddle/platform/variant.h" namespace paddle { namespace framework { -// The order should be as same as framework.proto -typedef boost::variant, - std::vector, std::vector, bool, - std::vector, BlockDesc*> - Attribute; - -typedef std::unordered_map AttributeMap; - ProgramDesc& GetProgramDesc(); template diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index e03b4d067f..0af4169715 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include #include "paddle/framework/attribute.h" -#include "paddle/framework/op_info.h" +#include "paddle/framework/type_defs.h" #include "paddle/framework/var_desc.h" namespace paddle { diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 6d1ee4dece..470336d367 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -19,15 +19,10 @@ #include #include "paddle/framework/attribute.h" #include "paddle/framework/op_desc.h" +#include "paddle/framework/type_defs.h" namespace paddle { namespace framework { -class OperatorBase; -using VariableNameMap = std::map>; - -using OpCreator = std::function; class GradOpDescMakerBase { public: diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h new file mode 100644 index 0000000000..dec5066f1e --- /dev/null +++ b/paddle/framework/type_defs.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/platform/variant.h" + +namespace paddle { +namespace framework { +class OperatorBase; +using VariableNameMap = std::map>; + +// The order should be as same as framework.proto +using Attribute = + boost::variant, + std::vector, std::vector, bool, + std::vector, BlockDesc*>; + +using AttributeMap = std::unordered_map; + +using OpCreator = std::function; + +} // namespace framework +} // namespace paddle