Move framework.proto to proto namespace (#6718)

* Move framework.proto to proto namespace

* Fix compile

* Fix compile

* Fix Compile
del_some_in_makelist
Yu Yang 8 years ago committed by GitHub
parent a87f4963ed
commit e445b3ff20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -53,7 +53,7 @@ Kernel实现 | CPU、CUDA共享Kernel实现在`.h`文件中否则CPU
```cpp ```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker { class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), 2D tensor of size (M x K)"); AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)"); AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
@ -82,7 +82,7 @@ The equation is: Out = X * Y
template <typename AttrType> template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient(); AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); AddOutput("Out", "The output tensor of scale operator.").NotInGradient();

@ -50,7 +50,7 @@ First, define `ProtoMaker` to describe the Operator's input, output, and additio
```cpp ```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker { class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), 2D tensor of size (M x K)"); AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)"); AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
@ -79,7 +79,7 @@ An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/de
template <typename AttrType> template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient(); AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); AddOutput("Out", "The output tensor of scale operator.").NotInGradient();

@ -19,42 +19,42 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: { case proto::AttrType::BOOLEAN: {
return attr_desc.b(); return attr_desc.b();
} }
case framework::AttrType::INT: { case proto::AttrType::INT: {
return attr_desc.i(); return attr_desc.i();
} }
case framework::AttrType::FLOAT: { case proto::AttrType::FLOAT: {
return attr_desc.f(); return attr_desc.f();
} }
case framework::AttrType::STRING: { case proto::AttrType::STRING: {
return attr_desc.s(); return attr_desc.s();
} }
case framework::AttrType::BOOLEANS: { case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size()); std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) { for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i); val[i] = attr_desc.bools(i);
} }
return val; return val;
} }
case framework::AttrType::INTS: { case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size()); std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) { for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i); val[i] = attr_desc.ints(i);
} }
return val; return val;
} }
case framework::AttrType::FLOATS: { case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size()); std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) { for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i); val[i] = attr_desc.floats(i);
} }
return val; return val;
} }
case framework::AttrType::STRINGS: { case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size()); std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) { for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i); val[i] = attr_desc.strings(i);

@ -27,12 +27,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T> template <typename T>
inline AttrType AttrTypeID() { inline proto::AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
return static_cast<AttrType>(tmp.which() - 1); return static_cast<proto::AttrType>(tmp.which() - 1);
} }
Attribute GetAttrValue(const OpDesc::Attr& attr_desc); Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader { class AttrReader {
public: public:

@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
auto* param = block_desc->FindVarRecursive(pname); auto* param = block_desc->FindVarRecursive(pname);
auto* grad = block_desc->FindVar(arg); auto* grad = block_desc->FindVar(arg);
if (param == nullptr) { if (param == nullptr) {
grad->SetDataType(DataType::FP32); grad->SetDataType(proto::DataType::FP32);
} else { } else {
grad->SetDataType(param->GetDataType()); grad->SetDataType(param->GetDataType());
} }

@ -166,7 +166,7 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
class SumOpMaker : public framework::OpProtoAndCheckerMaker { class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.").AsDuplicable(); AddInput("X", "the input tensors of sum operator.").AsDuplicable();
AddOutput("Out", "the output tensor of sum operator."); AddOutput("Out", "the output tensor of sum operator.");

@ -128,22 +128,22 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx())); return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
} }
BlockDesc *BlockDescBind::Proto() { proto::BlockDesc *BlockDescBind::Proto() {
Flush(); Flush();
return desc_; return desc_;
} }
BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) BlockDescBind::BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) { : prog_(prog), desc_(desc), need_update_(false) {
for (const VarDesc &var_desc : desc_->vars()) { for (const proto::VarDesc &var_desc : desc_->vars()) {
vars_[var_desc.name()].reset(new VarDescBind(var_desc)); vars_[var_desc.name()].reset(new VarDescBind(var_desc));
} }
for (const OpDesc &op_desc : desc_->ops()) { for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDescBind(op_desc, prog)); ops_.emplace_back(new OpDescBind(op_desc, prog));
} }
} }
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc, BlockDescBind::BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
ProgramDescBind *prog) ProgramDescBind *prog)
: prog_(prog), desc_(desc) { : prog_(prog), desc_(desc) {
need_update_ = true; need_update_ = true;

@ -36,9 +36,9 @@ class ProgramDescBind;
class BlockDescBind { class BlockDescBind {
public: public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc); BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc);
BlockDescBind(const BlockDescBind &other, BlockDesc *desc, BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
ProgramDescBind *prog); ProgramDescBind *prog);
~BlockDescBind() { ~BlockDescBind() {
@ -88,7 +88,7 @@ class BlockDescBind {
void Flush(); void Flush();
BlockDesc *Proto(); proto::BlockDesc *Proto();
ProgramDescBind *Program() { return this->prog_; } ProgramDescBind *Program() { return this->prog_; }
@ -97,8 +97,8 @@ class BlockDescBind {
void ClearPBVars(); void ClearPBVars();
private: private:
ProgramDescBind *prog_; // not_own ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own proto::BlockDesc *desc_; // not_own
bool need_update_; bool need_update_;
std::deque<std::unique_ptr<OpDescBind>> ops_; std::deque<std::unique_ptr<OpDescBind>> ops_;

@ -20,7 +20,8 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
inline DataType ToDataType(std::type_index type) { inline proto::DataType ToDataType(std::type_index type) {
using namespace paddle::framework::proto;
if (typeid(float).hash_code() == type.hash_code()) { if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32; return DataType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) { } else if (typeid(double).hash_code() == type.hash_code()) {
@ -36,7 +37,8 @@ inline DataType ToDataType(std::type_index type) {
} }
} }
inline std::type_index ToTypeIndex(DataType type) { inline std::type_index ToTypeIndex(proto::DataType type) {
using namespace paddle::framework::proto;
switch (type) { switch (type) {
case DataType::FP32: case DataType::FP32:
return typeid(float); return typeid(float);
@ -54,7 +56,8 @@ inline std::type_index ToTypeIndex(DataType type) {
} }
template <typename Visitor> template <typename Visitor>
inline void VisitDataType(DataType type, Visitor visitor) { inline void VisitDataType(proto::DataType type, Visitor visitor) {
using namespace paddle::framework::proto;
switch (type) { switch (type) {
case DataType::FP32: case DataType::FP32:
visitor.template operator()<float>(); visitor.template operator()<float>();

@ -90,7 +90,7 @@ struct OpInfoFiller<T, kOperator> {
template <typename T> template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> { struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new OpProto; info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker(); info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_); auto maker = T(info->proto_, info->checker_);
maker.Validate(); maker.Validate();

@ -41,20 +41,20 @@ Executor::Executor(const std::vector<platform::Place>& places) {
device_contexts_.swap(borrowed_contexts); device_contexts_.swap(borrowed_contexts);
} }
static void CreateTensor(Variable* var, VarDesc::VarType var_type) { static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == VarDesc::LOD_TENSOR) { if (var_type == proto::VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
} else if (var_type == VarDesc::SELECTED_ROWS) { } else if (var_type == proto::VarDesc::SELECTED_ROWS) {
var->GetMutable<SelectedRows>(); var->GetMutable<SelectedRows>();
} else if (var_type == VarDesc::FEED_MINIBATCH) { } else if (var_type == proto::VarDesc::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::FETCH_LIST) { } else if (var_type == proto::VarDesc::FETCH_LIST) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::STEP_SCOPES) { } else if (var_type == proto::VarDesc::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>(); var->GetMutable<std::vector<framework::Scope>>();
} else if (var_type == VarDesc::LOD_RANK_TABLE) { } else if (var_type == proto::VarDesc::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == VarDesc::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>(); var->GetMutable<LoDTensorArray>();
} else { } else {
PADDLE_THROW( PADDLE_THROW(

@ -14,7 +14,7 @@ limitations under the License. */
syntax = "proto2"; syntax = "proto2";
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;
package paddle.framework; package paddle.framework.proto;
enum AttrType { enum AttrType {
INT = 0; INT = 0;

@ -197,7 +197,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
{ // the 2nd field, tensor description { // the 2nd field, tensor description
// int32_t size // int32_t size
// void* protobuf message // void* protobuf message
framework::TensorDesc desc; proto::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type())); desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims()); auto dims = framework::vectorize(tensor.dims());
auto *pb_dims = desc.mutable_dims(); auto *pb_dims = desc.mutable_dims();
@ -262,7 +262,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
uint32_t version; uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version)); is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
framework::TensorDesc desc; proto::TensorDesc desc;
{ // int32_t size { // int32_t size
// proto buffer // proto buffer
int32_t size; int32_t size;
@ -281,16 +281,16 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
void *buf; void *buf;
platform::Place cpu = platform::CPUPlace(); platform::Place cpu = platform::CPUPlace();
switch (desc.data_type()) { switch (desc.data_type()) {
case framework::FP32: case proto::FP32:
buf = tensor->mutable_data<float>(cpu); buf = tensor->mutable_data<float>(cpu);
break; break;
case framework::FP64: case proto::FP64:
buf = tensor->mutable_data<double>(cpu); buf = tensor->mutable_data<double>(cpu);
break; break;
case framework::INT32: case proto::INT32:
buf = tensor->mutable_data<int>(cpu); buf = tensor->mutable_data<int>(cpu);
break; break;
case framework::INT64: case proto::INT64:
buf = tensor->mutable_data<int64_t>(cpu); buf = tensor->mutable_data<int64_t>(cpu);
break; break;
default: default:

@ -58,11 +58,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
PADDLE_ENFORCE_LT(j, Outputs(out).size()); PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]); auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]); auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != VarDesc::LOD_TENSOR) { if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
VLOG(3) << "input " << in << " is not LodTensor"; VLOG(3) << "input " << in << " is not LodTensor";
return; return;
} }
PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR, PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j, "The %d-th output of Output(%s) must be LoDTensor.", j,
out); out);
out_var->SetLoDLevel(in_var->GetLodLevel()); out_var->SetLoDLevel(in_var->GetLodLevel());
@ -70,7 +70,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override; bool IsRuntime() const override;
protected: protected:
VarDesc::VarType GetVarType(const std::string &name) const override; proto::VarDesc::VarType GetVarType(const std::string &name) const override;
DDim GetDim(const std::string &name) const override; DDim GetDim(const std::string &name) const override;
@ -90,12 +90,12 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
need_update_ = true; need_update_ = true;
} }
OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog) OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
: desc_(desc), need_update_(false) { : desc_(desc), need_update_(false) {
// restore inputs_ // restore inputs_
int input_size = desc_.inputs_size(); int input_size = desc_.inputs_size();
for (int i = 0; i < input_size; ++i) { for (int i = 0; i < input_size; ++i) {
const OpDesc::Var &var = desc_.inputs(i); const proto::OpDesc::Var &var = desc_.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()]; std::vector<std::string> &args = inputs_[var.parameter()];
int argu_size = var.arguments_size(); int argu_size = var.arguments_size();
args.reserve(argu_size); args.reserve(argu_size);
@ -106,7 +106,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
// restore outputs_ // restore outputs_
int output_size = desc_.outputs_size(); int output_size = desc_.outputs_size();
for (int i = 0; i < output_size; ++i) { for (int i = 0; i < output_size; ++i) {
const OpDesc::Var &var = desc_.outputs(i); const proto::OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()]; std::vector<std::string> &args = outputs_[var.parameter()];
int argu_size = var.arguments_size(); int argu_size = var.arguments_size();
args.reserve(argu_size); args.reserve(argu_size);
@ -115,9 +115,9 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
} }
} }
// restore attrs_ // restore attrs_
for (const OpDesc::Attr &attr : desc_.attrs()) { for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name(); std::string attr_name = attr.name();
if (attr.type() != AttrType::BLOCK) { if (attr.type() != proto::AttrType::BLOCK) {
attrs_[attr_name] = GetAttrValue(attr); attrs_[attr_name] = GetAttrValue(attr);
} else { } else {
auto bid = attr.block_idx(); auto bid = attr.block_idx();
@ -126,7 +126,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
} }
} }
OpDesc *OpDescBind::Proto() { proto::OpDesc *OpDescBind::Proto() {
Flush(); Flush();
return &desc_; return &desc_;
} }
@ -175,10 +175,10 @@ void OpDescBind::SetOutput(const std::string &param_name,
this->outputs_[param_name] = args; this->outputs_[param_name] = args;
} }
AttrType OpDescBind::GetAttrType(const std::string &name) const { proto::AttrType OpDescBind::GetAttrType(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return static_cast<AttrType>(it->second.which() - 1); return static_cast<proto::AttrType>(it->second.which() - 1);
} }
std::vector<std::string> OpDescBind::AttrNames() const { std::vector<std::string> OpDescBind::AttrNames() const {
@ -253,8 +253,8 @@ void OpDescBind::RenameInput(const std::string &old_name,
} }
struct SetAttrDescVisitor : public boost::static_visitor<void> { struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_; mutable proto::OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); } void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); } void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); } void operator()(const std::string &v) const { attr_->set_s(v); }
@ -272,7 +272,9 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<bool> &v) const { void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools()); VectorToRepeated(v, attr_->mutable_bools());
} }
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); } void operator()(proto::BlockDesc *desc) const {
attr_->set_block_idx(desc->idx());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };
@ -297,7 +299,7 @@ void OpDescBind::Flush() {
auto *attr_desc = desc_.add_attrs(); auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first); attr_desc->set_name(attr.first);
attr_desc->set_type( attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1)); static_cast<proto::AttrType>(attr.second.which() - 1));
SetAttrDescVisitor visitor(attr_desc); SetAttrDescVisitor visitor(attr_desc);
boost::apply_visitor(visitor, attr.second); boost::apply_visitor(visitor, attr.second);
} }
@ -375,7 +377,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
for (auto &out_pair : this->outputs_) { for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) { for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name) block->FindRecursiveOrCreateVar(out_var_name)
->SetType(VarDesc::LOD_TENSOR); ->SetType(proto::VarDesc::LOD_TENSOR);
} }
} }
} }
@ -484,7 +486,7 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
} }
bool CompileTimeInferShapeContext::IsRuntime() const { return false; } bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
VarDesc::VarType CompileTimeInferShapeContext::GetVarType( proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
const std::string &name) const { const std::string &name) const {
return block_.FindVarRecursive(name)->GetType(); return block_.FindVarRecursive(name)->GetType();
} }

@ -33,9 +33,9 @@ class OpDescBind {
OpDescBind(const std::string &type, const VariableNameMap &inputs, OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs); const VariableNameMap &outputs, const AttributeMap &attrs);
OpDescBind(const OpDesc &desc, ProgramDescBind *prog); OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog);
OpDesc *Proto(); proto::OpDesc *Proto();
std::string Type() const { return desc_.type(); } std::string Type() const { return desc_.type(); }
@ -59,7 +59,7 @@ class OpDescBind {
return attrs_.find(name) != attrs_.end(); return attrs_.find(name) != attrs_.end();
} }
AttrType GetAttrType(const std::string &name) const; proto::AttrType GetAttrType(const std::string &name) const;
std::vector<std::string> AttrNames() const; std::vector<std::string> AttrNames() const;
@ -126,7 +126,7 @@ class OpDescBind {
return ret_val; return ret_val;
} }
OpDesc desc_; proto::OpDesc desc_;
VariableNameMap inputs_; VariableNameMap inputs_;
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;

@ -34,7 +34,7 @@ class InferShapeBase {
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
GradOpMakerFN grad_op_maker_; GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr}; proto::OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_; InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_; InferShapeFN infer_shape_;
@ -43,7 +43,7 @@ struct OpInfo {
return proto_ != nullptr && checker_ != nullptr; return proto_ != nullptr && checker_ != nullptr;
} }
const OpProto& Proto() const { const proto::OpProto& Proto() const {
PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered"); PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered");
PADDLE_ENFORCE(proto_->IsInitialized(), PADDLE_ENFORCE(proto_->IsInitialized(),
"Operator Proto must be initialized in op info"); "Operator Proto must be initialized in op info");

@ -22,6 +22,8 @@ namespace framework {
// this class not only make proto but also init attribute checkers. // this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker { class OpProtoAndCheckerMaker {
public: public:
using OpProto = proto::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {} : proto_(proto), op_checker_(op_checker) {}
@ -80,7 +82,7 @@ class OpProtoAndCheckerMaker {
class NOPMaker : public OpProtoAndCheckerMaker { class NOPMaker : public OpProtoAndCheckerMaker {
public: public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {} : OpProtoAndCheckerMaker(proto, op_checker) {}
}; };

@ -18,7 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public: public:
TestAttrProtoMaker(paddle::framework::OpProto* proto, TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker) paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op"); AddAttr<float>("scale", "scale of test op");
@ -27,7 +27,7 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
}; };
TEST(ProtoMaker, DuplicatedAttr) { TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::OpProto op_proto; paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker; paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
@ -35,7 +35,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public: public:
TestInOutProtoMaker(paddle::framework::OpProto* proto, TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker) paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op"); AddInput("input", "input of test op");
@ -44,7 +44,7 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
}; };
TEST(ProtoMaker, DuplicatedInOut) { TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpProto op_proto; paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker; paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);

@ -31,7 +31,8 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
} }
static VariableNameMap ConvertOpDescVarsToVarNameMap( static VariableNameMap ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) { const google::protobuf::RepeatedPtrField<proto::OpDesc::Var>&
op_desc_vars) {
VariableNameMap ret_val; VariableNameMap ret_val;
for (auto& var : op_desc_vars) { for (auto& var : op_desc_vars) {
auto& var_names = ret_val[var.parameter()]; auto& var_names = ret_val[var.parameter()];
@ -43,7 +44,8 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
return ret_val; return ret_val;
} }
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) { std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const proto::OpDesc& op_desc) {
VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be" VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
"used in unit tests. Use CreateOp(const OpDescBind& op_desc) " "used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
"instead."; "instead.";

@ -77,7 +77,7 @@ class OpRegistry {
const VariableNameMap& outputs, const VariableNameMap& outputs,
AttributeMap attrs); AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
}; };

@ -51,7 +51,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static void BuildVar(const std::string& param_name, static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments, std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) { paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name); var->set_parameter(param_name);
for (auto& arg_name : arguments) { for (auto& arg_name : arguments) {
var->add_arguments(arg_name); var->add_arguments(arg_name);
@ -63,7 +63,7 @@ REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker); paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) { TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc; paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs()); BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs()); BuildVar("output", {"bb"}, op_desc.add_outputs());
@ -71,7 +71,7 @@ TEST(OpRegistry, CreateOp) {
float scale = 3.3; float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(scale); attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
@ -83,14 +83,14 @@ TEST(OpRegistry, CreateOp) {
} }
TEST(OpRegistry, IllegalAttr) { TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc; paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs()); BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs()); BuildVar("output", {"bb"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(-2.0); attr->set_f(-2.0);
bool caught = false; bool caught = false;
@ -108,7 +108,7 @@ TEST(OpRegistry, IllegalAttr) {
} }
TEST(OpRegistry, DefaultValue) { TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc; paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs()); BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs()); BuildVar("output", {"bb"}, op_desc.add_outputs());
@ -123,7 +123,7 @@ TEST(OpRegistry, DefaultValue) {
} }
TEST(OpRegistry, CustomChecker) { TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc; paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("my_test_op"); op_desc.set_type("my_test_op");
BuildVar("input", {"ii"}, op_desc.add_inputs()); BuildVar("input", {"ii"}, op_desc.add_inputs());
BuildVar("output", {"oo"}, op_desc.add_outputs()); BuildVar("output", {"oo"}, op_desc.add_outputs());
@ -145,7 +145,7 @@ TEST(OpRegistry, CustomChecker) {
// set 'test_attr' set to an illegal value // set 'test_attr' set to an illegal value
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(3); attr->set_i(3);
caught = false; caught = false;
try { try {
@ -164,7 +164,7 @@ TEST(OpRegistry, CustomChecker) {
op_desc.mutable_attrs()->Clear(); op_desc.mutable_attrs()->Clear();
attr = op_desc.mutable_attrs()->Add(); attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(4); attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;

@ -377,7 +377,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
} }
VarDesc::VarType GetVarType(const std::string& name) const override { proto::VarDesc::VarType GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name); auto* var = scope_.FindVar(name);
return ToVarType(var->Type()); return ToVarType(var->Type());
} }
@ -417,7 +417,7 @@ OpKernelType OperatorWithKernel::GetKernelType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace()); return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
} }
DataType OperatorWithKernel::IndicateDataType( proto::DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
auto& scope = ctx.scope(); auto& scope = ctx.scope();
int data_type = -1; int data_type = -1;
@ -443,7 +443,7 @@ DataType OperatorWithKernel::IndicateDataType(
} }
} }
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input"); PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type); return static_cast<proto::DataType>(data_type);
} }
} // namespace framework } // namespace framework

@ -358,12 +358,13 @@ struct OpKernelType {
}; };
platform::Place place_; platform::Place place_;
DataType data_type_; proto::DataType data_type_;
OpKernelType(DataType data_type, platform::Place place) OpKernelType(proto::DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {} : place_(place), data_type_(data_type) {}
OpKernelType(DataType data_type, const platform::DeviceContext& dev_ctx) OpKernelType(proto::DataType data_type,
const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {} : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelType& o) const { bool operator==(const OpKernelType& o) const {
@ -409,7 +410,7 @@ class OperatorWithKernel : public OperatorBase {
private: private:
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. Defaultly all input data must be
// same. // same.
DataType IndicateDataType(const ExecutionContext& ctx) const; proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
}; };
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key); std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);

@ -58,7 +58,7 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static void BuildVar(const std::string& param_name, static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments, std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) { paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name); var->set_parameter(param_name);
for (auto& arg_name : arguments) { for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name; *var->mutable_arguments()->Add() = arg_name;
@ -70,14 +70,14 @@ REGISTER_OP_WITHOUT_GRADIENT(
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
TEST(OperatorBase, all) { TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc; paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("test_operator"); op_desc.set_type("test_operator");
BuildVar("input", {"IN1"}, op_desc.add_inputs()); BuildVar("input", {"IN1"}, op_desc.add_inputs());
BuildVar("output", {"OUT1"}, op_desc.add_outputs()); BuildVar("output", {"OUT1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14); attr->set_f(3.14);
paddle::platform::CPUDeviceContext device_context; paddle::platform::CPUDeviceContext device_context;
@ -115,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetKernelType(const ExecutionContext& ctx) const override { OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(DataType::FP32, ctx.GetPlace()); return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
} }
}; };
@ -195,14 +195,14 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input // test with single input
TEST(OpKernel, all) { TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc; paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("op_with_kernel"); op_desc.set_type("op_with_kernel");
BuildVar("x", {"IN1"}, op_desc.add_inputs()); BuildVar("x", {"IN1"}, op_desc.add_inputs());
BuildVar("y", {"OUT1"}, op_desc.add_outputs()); BuildVar("y", {"OUT1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14); attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
@ -224,7 +224,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) { TEST(OpKernel, multi_inputs) {
using namespace paddle::framework; using namespace paddle::framework;
OpDesc op_desc; proto::OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel"); op_desc.set_type("op_multi_inputs_with_kernel");
BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs()); BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
BuildVar("k", {"k0"}, op_desc.add_inputs()); BuildVar("k", {"k0"}, op_desc.add_inputs());
@ -232,7 +232,7 @@ TEST(OpKernel, multi_inputs) {
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14); attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;

@ -26,7 +26,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
return blocks_.back().get(); return blocks_.back().get();
} }
ProgramDesc *ProgramDescBind::Proto() { proto::ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Flush(); block->Flush();
} }
@ -49,7 +49,7 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
} }
} }
ProgramDescBind::ProgramDescBind(const ProgramDesc &desc) { ProgramDescBind::ProgramDescBind(const proto::ProgramDesc &desc) {
desc_ = desc; desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc)); blocks_.emplace_back(new BlockDescBind(this, &block_desc));

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save