|
|
|
@ -17,6 +17,7 @@ limitations under the License. */
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/framework/attribute.h"
|
|
|
|
|
#include "paddle/framework/op_info.h"
|
|
|
|
|
#include "paddle/framework/var_desc.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -61,48 +62,22 @@ class OpDescBind {
|
|
|
|
|
void SetBlockAttr(const std::string &name, BlockDescBind &block);
|
|
|
|
|
|
|
|
|
|
// Only be used in C++
|
|
|
|
|
void SetAttrMap(const std::unordered_map<std::string, Attribute> &attr_map);
|
|
|
|
|
void SetAttrMap(const AttributeMap &attr_map);
|
|
|
|
|
|
|
|
|
|
Attribute GetAttr(const std::string &name) const;
|
|
|
|
|
|
|
|
|
|
int GetBlockAttr(const std::string &name) const;
|
|
|
|
|
|
|
|
|
|
// Only be used in C++
|
|
|
|
|
const std::unordered_map<std::string, Attribute> &GetAttrMap() const;
|
|
|
|
|
const AttributeMap &GetAttrMap() const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
struct SetAttrDescVisitor : public boost::static_visitor<void> {
|
|
|
|
|
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
|
|
|
|
|
mutable OpDesc::Attr *attr_;
|
|
|
|
|
void operator()(int v) const { attr_->set_i(v); }
|
|
|
|
|
void operator()(float v) const { attr_->set_f(v); }
|
|
|
|
|
void operator()(const std::string &v) const { attr_->set_s(v); }
|
|
|
|
|
void operator()(bool b) const { attr_->set_b(b); }
|
|
|
|
|
|
|
|
|
|
void operator()(const std::vector<int> &v) const {
|
|
|
|
|
VectorToRepeated(v, attr_->mutable_ints());
|
|
|
|
|
}
|
|
|
|
|
void operator()(const std::vector<float> &v) const {
|
|
|
|
|
VectorToRepeated(v, attr_->mutable_floats());
|
|
|
|
|
}
|
|
|
|
|
void operator()(const std::vector<std::string> &v) const {
|
|
|
|
|
VectorToRepeated(v, attr_->mutable_strings());
|
|
|
|
|
}
|
|
|
|
|
void operator()(const std::vector<bool> &v) const {
|
|
|
|
|
VectorToRepeated(v, attr_->mutable_bools());
|
|
|
|
|
}
|
|
|
|
|
void operator()(BlockDesc *desc) const {
|
|
|
|
|
attr_->set_block_idx(desc->idx());
|
|
|
|
|
}
|
|
|
|
|
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void Sync();
|
|
|
|
|
|
|
|
|
|
OpDesc op_desc_;
|
|
|
|
|
std::unordered_map<std::string, std::vector<std::string>> inputs_;
|
|
|
|
|
std::unordered_map<std::string, std::vector<std::string>> outputs_;
|
|
|
|
|
std::unordered_map<std::string, Attribute> attrs_;
|
|
|
|
|
VariableNameMap inputs_;
|
|
|
|
|
VariableNameMap outputs_;
|
|
|
|
|
AttributeMap attrs_;
|
|
|
|
|
|
|
|
|
|
// need_update_ indicate there some local changes not be synchronized. If
|
|
|
|
|
// local changes should be synchronized, need_update_ should be set to true.
|
|
|
|
|