Move framework.proto to proto namespace (#6718)

* Move framework.proto to proto namespace

* Fix compile

* Fix compile

* Fix Compile
del_some_in_makelist
Yu Yang 7 years ago committed by GitHub
parent a87f4963ed
commit e445b3ff20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -53,7 +53,7 @@ Kernel实现 | CPU、CUDA共享Kernel实现在`.h`文件中否则CPU
```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
@ -82,7 +82,7 @@ The equation is: Out = X * Y
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient();

@ -50,7 +50,7 @@ First, define `ProtoMaker` to describe the Operator's input, output, and additio
```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
@ -79,7 +79,7 @@ An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/de
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient();

@ -19,42 +19,42 @@ limitations under the License. */
namespace paddle {
namespace framework {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: {
case proto::AttrType::BOOLEAN: {
return attr_desc.b();
}
case framework::AttrType::INT: {
case proto::AttrType::INT: {
return attr_desc.i();
}
case framework::AttrType::FLOAT: {
case proto::AttrType::FLOAT: {
return attr_desc.f();
}
case framework::AttrType::STRING: {
case proto::AttrType::STRING: {
return attr_desc.s();
}
case framework::AttrType::BOOLEANS: {
case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
return val;
}
case framework::AttrType::INTS: {
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
return val;
}
case framework::AttrType::FLOATS: {
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
return val;
}
case framework::AttrType::STRINGS: {
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);

@ -27,12 +27,12 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
inline AttrType AttrTypeID() {
inline proto::AttrType AttrTypeID() {
Attribute tmp = T();
return static_cast<AttrType>(tmp.which() - 1);
return static_cast<proto::AttrType>(tmp.which() - 1);
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader {
public:

@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
auto* param = block_desc->FindVarRecursive(pname);
auto* grad = block_desc->FindVar(arg);
if (param == nullptr) {
grad->SetDataType(DataType::FP32);
grad->SetDataType(proto::DataType::FP32);
} else {
grad->SetDataType(param->GetDataType());
}

@ -166,7 +166,7 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.").AsDuplicable();
AddOutput("Out", "the output tensor of sum operator.");

@ -128,22 +128,22 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
}
BlockDesc *BlockDescBind::Proto() {
proto::BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}
BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
BlockDescBind::BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {
for (const VarDesc &var_desc : desc_->vars()) {
for (const proto::VarDesc &var_desc : desc_->vars()) {
vars_[var_desc.name()].reset(new VarDescBind(var_desc));
}
for (const OpDesc &op_desc : desc_->ops()) {
for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDescBind(op_desc, prog));
}
}
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
BlockDescBind::BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
ProgramDescBind *prog)
: prog_(prog), desc_(desc) {
need_update_ = true;

@ -36,9 +36,9 @@ class ProgramDescBind;
class BlockDescBind {
public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);
BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc);
BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
ProgramDescBind *prog);
~BlockDescBind() {
@ -88,7 +88,7 @@ class BlockDescBind {
void Flush();
BlockDesc *Proto();
proto::BlockDesc *Proto();
ProgramDescBind *Program() { return this->prog_; }
@ -97,8 +97,8 @@ class BlockDescBind {
void ClearPBVars();
private:
ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own
ProgramDescBind *prog_; // not_own
proto::BlockDesc *desc_; // not_own
bool need_update_;
std::deque<std::unique_ptr<OpDescBind>> ops_;

@ -20,7 +20,8 @@
namespace paddle {
namespace framework {
inline DataType ToDataType(std::type_index type) {
inline proto::DataType ToDataType(std::type_index type) {
using namespace paddle::framework::proto;
if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
@ -36,7 +37,8 @@ inline DataType ToDataType(std::type_index type) {
}
}
inline std::type_index ToTypeIndex(DataType type) {
inline std::type_index ToTypeIndex(proto::DataType type) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP32:
return typeid(float);
@ -54,7 +56,8 @@ inline std::type_index ToTypeIndex(DataType type) {
}
template <typename Visitor>
inline void VisitDataType(DataType type, Visitor visitor) {
inline void VisitDataType(proto::DataType type, Visitor visitor) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP32:
visitor.template operator()<float>();

@ -90,7 +90,7 @@ struct OpInfoFiller<T, kOperator> {
template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new OpProto;
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
maker.Validate();

@ -41,20 +41,20 @@ Executor::Executor(const std::vector<platform::Place>& places) {
device_contexts_.swap(borrowed_contexts);
}
static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
if (var_type == VarDesc::LOD_TENSOR) {
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == VarDesc::SELECTED_ROWS) {
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == VarDesc::FEED_MINIBATCH) {
} else if (var_type == proto::VarDesc::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::FETCH_LIST) {
} else if (var_type == proto::VarDesc::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::STEP_SCOPES) {
} else if (var_type == proto::VarDesc::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>();
} else if (var_type == VarDesc::LOD_RANK_TABLE) {
} else if (var_type == proto::VarDesc::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == VarDesc::LOD_TENSOR_ARRAY) {
} else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else {
PADDLE_THROW(

@ -14,7 +14,7 @@ limitations under the License. */
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle.framework;
package paddle.framework.proto;
enum AttrType {
INT = 0;

@ -197,7 +197,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::TensorDesc desc;
proto::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto *pb_dims = desc.mutable_dims();
@ -262,7 +262,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
framework::TensorDesc desc;
proto::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
@ -281,16 +281,16 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
void *buf;
platform::Place cpu = platform::CPUPlace();
switch (desc.data_type()) {
case framework::FP32:
case proto::FP32:
buf = tensor->mutable_data<float>(cpu);
break;
case framework::FP64:
case proto::FP64:
buf = tensor->mutable_data<double>(cpu);
break;
case framework::INT32:
case proto::INT32:
buf = tensor->mutable_data<int>(cpu);
break;
case framework::INT64:
case proto::INT64:
buf = tensor->mutable_data<int64_t>(cpu);
break;
default:

@ -58,11 +58,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != VarDesc::LOD_TENSOR) {
if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
VLOG(3) << "input " << in << " is not LodTensor";
return;
}
PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
out_var->SetLoDLevel(in_var->GetLodLevel());
@ -70,7 +70,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override;
protected:
VarDesc::VarType GetVarType(const std::string &name) const override;
proto::VarDesc::VarType GetVarType(const std::string &name) const override;
DDim GetDim(const std::string &name) const override;
@ -90,12 +90,12 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
need_update_ = true;
}
OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
int input_size = desc_.inputs_size();
for (int i = 0; i < input_size; ++i) {
const OpDesc::Var &var = desc_.inputs(i);
const proto::OpDesc::Var &var = desc_.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
@ -106,7 +106,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
// restore outputs_
int output_size = desc_.outputs_size();
for (int i = 0; i < output_size; ++i) {
const OpDesc::Var &var = desc_.outputs(i);
const proto::OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
@ -115,9 +115,9 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
}
}
// restore attrs_
for (const OpDesc::Attr &attr : desc_.attrs()) {
for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
if (attr.type() != AttrType::BLOCK) {
if (attr.type() != proto::AttrType::BLOCK) {
attrs_[attr_name] = GetAttrValue(attr);
} else {
auto bid = attr.block_idx();
@ -126,7 +126,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
}
}
OpDesc *OpDescBind::Proto() {
proto::OpDesc *OpDescBind::Proto() {
Flush();
return &desc_;
}
@ -175,10 +175,10 @@ void OpDescBind::SetOutput(const std::string &param_name,
this->outputs_[param_name] = args;
}
AttrType OpDescBind::GetAttrType(const std::string &name) const {
proto::AttrType OpDescBind::GetAttrType(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return static_cast<AttrType>(it->second.which() - 1);
return static_cast<proto::AttrType>(it->second.which() - 1);
}
std::vector<std::string> OpDescBind::AttrNames() const {
@ -253,8 +253,8 @@ void OpDescBind::RenameInput(const std::string &old_name,
}
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
mutable proto::OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
@ -272,7 +272,9 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); }
void operator()(proto::BlockDesc *desc) const {
attr_->set_block_idx(desc->idx());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
@ -297,7 +299,7 @@ void OpDescBind::Flush() {
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1));
static_cast<proto::AttrType>(attr.second.which() - 1));
SetAttrDescVisitor visitor(attr_desc);
boost::apply_visitor(visitor, attr.second);
}
@ -375,7 +377,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name)
->SetType(VarDesc::LOD_TENSOR);
->SetType(proto::VarDesc::LOD_TENSOR);
}
}
}
@ -484,7 +486,7 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
}
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
const std::string &name) const {
return block_.FindVarRecursive(name)->GetType();
}

@ -33,9 +33,9 @@ class OpDescBind {
OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);
OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog);
OpDesc *Proto();
proto::OpDesc *Proto();
std::string Type() const { return desc_.type(); }
@ -59,7 +59,7 @@ class OpDescBind {
return attrs_.find(name) != attrs_.end();
}
AttrType GetAttrType(const std::string &name) const;
proto::AttrType GetAttrType(const std::string &name) const;
std::vector<std::string> AttrNames() const;
@ -126,7 +126,7 @@ class OpDescBind {
return ret_val;
}
OpDesc desc_;
proto::OpDesc desc_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;

@ -34,7 +34,7 @@ class InferShapeBase {
struct OpInfo {
OpCreator creator_;
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
proto::OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_;
@ -43,7 +43,7 @@ struct OpInfo {
return proto_ != nullptr && checker_ != nullptr;
}
const OpProto& Proto() const {
const proto::OpProto& Proto() const {
PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered");
PADDLE_ENFORCE(proto_->IsInitialized(),
"Operator Proto must be initialized in op info");

@ -22,6 +22,8 @@ namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
using OpProto = proto::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
@ -80,7 +82,7 @@ class OpProtoAndCheckerMaker {
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};

@ -18,7 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::OpProto* proto,
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op");
@ -27,7 +27,7 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::OpProto op_proto;
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
@ -35,7 +35,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::OpProto* proto,
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
@ -44,7 +44,7 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpProto op_proto;
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);

@ -31,7 +31,8 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
}
static VariableNameMap ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
const google::protobuf::RepeatedPtrField<proto::OpDesc::Var>&
op_desc_vars) {
VariableNameMap ret_val;
for (auto& var : op_desc_vars) {
auto& var_names = ret_val[var.parameter()];
@ -43,7 +44,8 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
return ret_val;
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const proto::OpDesc& op_desc) {
VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
"used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
"instead.";

@ -77,7 +77,7 @@ class OpRegistry {
const VariableNameMap& outputs,
AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};

@ -51,7 +51,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
@ -63,7 +63,7 @@ REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
@ -71,7 +71,7 @@ TEST(OpRegistry, CreateOp) {
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
@ -83,14 +83,14 @@ TEST(OpRegistry, CreateOp) {
}
TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(-2.0);
bool caught = false;
@ -108,7 +108,7 @@ TEST(OpRegistry, IllegalAttr) {
}
TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
@ -123,7 +123,7 @@ TEST(OpRegistry, DefaultValue) {
}
TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("my_test_op");
BuildVar("input", {"ii"}, op_desc.add_inputs());
BuildVar("output", {"oo"}, op_desc.add_outputs());
@ -145,7 +145,7 @@ TEST(OpRegistry, CustomChecker) {
// set 'test_attr' set to an illegal value
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(3);
caught = false;
try {
@ -164,7 +164,7 @@ TEST(OpRegistry, CustomChecker) {
op_desc.mutable_attrs()->Clear();
attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;

@ -377,7 +377,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
VarDesc::VarType GetVarType(const std::string& name) const override {
proto::VarDesc::VarType GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
}
@ -417,7 +417,7 @@ OpKernelType OperatorWithKernel::GetKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
DataType OperatorWithKernel::IndicateDataType(
proto::DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
@ -443,7 +443,7 @@ DataType OperatorWithKernel::IndicateDataType(
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
return static_cast<proto::DataType>(data_type);
}
} // namespace framework

@ -358,12 +358,13 @@ struct OpKernelType {
};
platform::Place place_;
DataType data_type_;
proto::DataType data_type_;
OpKernelType(DataType data_type, platform::Place place)
OpKernelType(proto::DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}
OpKernelType(DataType data_type, const platform::DeviceContext& dev_ctx)
OpKernelType(proto::DataType data_type,
const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelType& o) const {
@ -409,7 +410,7 @@ class OperatorWithKernel : public OperatorBase {
private:
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
DataType IndicateDataType(const ExecutionContext& ctx) const;
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
};
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);

@ -58,7 +58,7 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
@ -70,14 +70,14 @@ REGISTER_OP_WITHOUT_GRADIENT(
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("test_operator");
BuildVar("input", {"IN1"}, op_desc.add_inputs());
BuildVar("output", {"OUT1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext device_context;
@ -115,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(DataType::FP32, ctx.GetPlace());
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
}
};
@ -195,14 +195,14 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
BuildVar("x", {"IN1"}, op_desc.add_inputs());
BuildVar("y", {"OUT1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context;
@ -224,7 +224,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;
OpDesc op_desc;
proto::OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
BuildVar("k", {"k0"}, op_desc.add_inputs());
@ -232,7 +232,7 @@ TEST(OpKernel, multi_inputs) {
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context;

@ -26,7 +26,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
return blocks_.back().get();
}
ProgramDesc *ProgramDescBind::Proto() {
proto::ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Flush();
}
@ -49,7 +49,7 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
}
}
ProgramDescBind::ProgramDescBind(const ProgramDesc &desc) {
ProgramDescBind::ProgramDescBind(const proto::ProgramDesc &desc) {
desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc));

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save