[WIP] Move DataType enum inside VarType (#8447)

* Move Pod Types from DataType enum to Type enum

* Fixed data_type.h

* Fix type in TensorDesc

* Add comment to framework.proto

* Fixed type in data_type.h

* Updated format of type in data_type.h

* Fix var_desc.h

* Fix op_kernel_type.h

* Fixed data_type_transform_test.cc

* Fix operator.h

* Fixed data_type_transform.cc

* Fixed op_kernel_type_test.cc

* Fix operator.cc

* Fixed data_layout_transform_test.cc

* Fix var_desc.cc

* Fixed assign_value_op.cc

* Fixed assign_value_op.h

* fixed protobuf.cc

* Fix data_layout_transform_test.cc and op_kernel_type_test.cc

* Fixed rnn_memory_helper_op.cc

* Fix program_desc_test.cc

* Fixed fill_constant_batch_size_like_op.cc

* Fix operator_test.cc

* Fixed fill_constant_op.cc

* Fixed gaussian_random_op.cc

* Fixed uniform_random_op.cc

* Fixed edit_distance_op.cc

* Fixed fill_constant_batch_size_like_op.cc

* Fixed rnn_memory_helper_op.cc

* Fixed chunk_eval_op.cc

* Fixed assign_value_op.cc

* Fixed assign_value_op.h

* Fixed cast_op.h

* Fixed cast_op.h

* Fix fill constant op

* Fixed clang for assign_value_op.cc

* Fix one_hot_op.h

* Fix one_hot_op.cc

* Fix fill_op.cc

* Fixed sum_op.cc

* Fixed sum_op clang

* Fix uniform_random_op.cc

* Fix gaussian_random_op.cc

* Fix backward.cc

* Fix protobuf.cc

* Fixed prune_test.cc

* Fixed op_registry_test.cc

* Fix data_device_transform_test.cu

* Fix travis error

* Fixed one_hot_op.cu

* Fixed op_registry_test.cc

* Fixed nccl_op.cc

* Fixing python tests

* Revert "Fixing python tests"

This reverts commit fccaa4c5818ed9f379ea1ce4315066cc78076c64.

* Fixing Pybind to remove data type

* Fixing tensor.py

* Updated the new files

* Resolve error in merge conflict of fill_constant_batch_size_like_op.cc
Author: Abhinav Arora, committed by kavyasrinet (branch emailweixu-patch-1)
commit c7ad26d6a4 (parent 74e0eb7267)
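
Summary of the change: the top-level proto enum DataType is folded into VarType as a nested Type enum, so every C++ call site renames proto::DataType::X to proto::VarType::X, and dtype attributes (stored as int) now cast to proto::VarType::Type. Below is a minimal compilable sketch of the before/after pattern; the struct is a toy stand-in for the protoc-generated classes, not Paddle's headers:

// Toy stand-in for the protoc-generated classes; illustrates the rename only.
#include <iostream>

namespace proto {
// Before (roughly): enum DataType { BOOL = 0, ..., FP32 = 5, FP64 = 6 };
struct VarType {
  enum Type {
    // Pod Types keep their old values ...
    BOOL = 0, INT16 = 1, INT32 = 2, INT64 = 3, FP16 = 4, FP32 = 5, FP64 = 6,
    // ... and the former VarType values are renumbered after them.
    LOD_TENSOR = 7, SELECTED_ROWS = 8, FEED_MINIBATCH = 9, FETCH_LIST = 10,
    STEP_SCOPES = 11, LOD_RANK_TABLE = 12, LOD_TENSOR_ARRAY = 13,
    PLACE_LIST = 14, READER = 15
  };
};
}  // namespace proto

int main() {
  // Old call sites used proto::DataType::FP32. New call sites:
  proto::VarType::Type dtype = proto::VarType::FP32;
  // Attribute-driven dtypes (stored as int) cast to the nested enum:
  int attr = 5;
  dtype = static_cast<proto::VarType::Type>(attr);
  std::cout << dtype << std::endl;  // prints 5
}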

backward.cc
@@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
     auto* param = block_desc->FindVarRecursive(pname);
     auto* grad = block_desc->FindVar(arg);
     if (param == nullptr) {
-      grad->SetDataType(proto::DataType::FP32);
+      grad->SetDataType(proto::VarType::FP32);
     } else {
       grad->SetDataType(param->GetDataType());
     }

data_device_transform_test.cu
@@ -51,10 +51,10 @@ class TestOpWithKernel : public OperatorWithKernel {
       const ExecutionContext& ctx) const override {
     if (Attr<bool>("use_gpu")) {
       VLOG(3) << "force use gpu kernel";
-      return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
+      return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
     } else {
       VLOG(3) << "use default kernel";
-      return OpKernelType(proto::DataType::FP32,
+      return OpKernelType(proto::VarType::FP32,
                           ctx.Input<Tensor>("input")->place());
     }
   }

data_layout_transform_test.cc
@@ -27,9 +27,9 @@ TEST(DataTransform, DataLayoutFunction) {
   in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
   in.set_layout(DataLayout::kNHWC);
-  auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
+  auto kernel_nhwc = OpKernelType(proto::VarType::FP32, place,
                                   DataLayout::kNHWC, LibraryType::kPlain);
-  auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
+  auto kernel_ncwh = OpKernelType(proto::VarType::FP32, place,
                                   DataLayout::kNCHW, LibraryType::kPlain);
   TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);

data_type.h
@@ -20,35 +20,35 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-inline proto::DataType ToDataType(std::type_index type) {
+inline proto::VarType::Type ToDataType(std::type_index type) {
   using namespace paddle::framework::proto;
   if (typeid(float).hash_code() == type.hash_code()) {
-    return DataType::FP32;
+    return proto::VarType::FP32;
   } else if (typeid(double).hash_code() == type.hash_code()) {
-    return DataType::FP64;
+    return proto::VarType::FP64;
   } else if (typeid(int).hash_code() == type.hash_code()) {
-    return DataType::INT32;
+    return proto::VarType::INT32;
  } else if (typeid(int64_t).hash_code() == type.hash_code()) {
-    return DataType::INT64;
+    return proto::VarType::INT64;
   } else if (typeid(bool).hash_code() == type.hash_code()) {
-    return DataType::BOOL;
+    return proto::VarType::BOOL;
   } else {
     PADDLE_THROW("Not supported");
   }
 }
 
-inline std::type_index ToTypeIndex(proto::DataType type) {
+inline std::type_index ToTypeIndex(proto::VarType::Type type) {
   using namespace paddle::framework::proto;
   switch (type) {
-    case DataType::FP32:
+    case proto::VarType::FP32:
       return typeid(float);
-    case DataType::FP64:
+    case proto::VarType::FP64:
       return typeid(double);
-    case DataType::INT32:
+    case proto::VarType::INT32:
       return typeid(int);
-    case DataType::INT64:
+    case proto::VarType::INT64:
       return typeid(int64_t);
-    case DataType::BOOL:
+    case proto::VarType::BOOL:
       return typeid(bool);
     default:
       PADDLE_THROW("Not support type %d", type);
@@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) {
 }
 
 template <typename Visitor>
-inline void VisitDataType(proto::DataType type, Visitor visitor) {
+inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
   using namespace paddle::framework::proto;
   switch (type) {
-    case DataType::FP32:
+    case proto::VarType::FP32:
      visitor.template operator()<float>();
       break;
-    case DataType::FP64:
+    case proto::VarType::FP64:
       visitor.template operator()<double>();
       break;
-    case DataType::INT32:
+    case proto::VarType::INT32:
       visitor.template operator()<int>();
       break;
-    case DataType::INT64:
+    case proto::VarType::INT64:
       visitor.template operator()<int64_t>();
       break;
-    case DataType::BOOL:
+    case proto::VarType::BOOL:
       visitor.template operator()<bool>();
       break;
     default:
@@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
   }
 }
 
-inline std::string DataTypeToString(const proto::DataType type) {
+inline std::string DataTypeToString(const proto::VarType::Type type) {
   using namespace paddle::framework::proto;
   switch (type) {
-    case DataType::FP16:
+    case proto::VarType::FP16:
      return "float16";
-    case DataType::FP32:
+    case proto::VarType::FP32:
       return "float32";
-    case DataType::FP64:
+    case proto::VarType::FP64:
       return "float64";
-    case DataType::INT16:
+    case proto::VarType::INT16:
       return "int16";
-    case DataType::INT32:
+    case proto::VarType::INT32:
       return "int32";
-    case DataType::INT64:
+    case proto::VarType::INT64:
       return "int64";
-    case DataType::BOOL:
+    case proto::VarType::BOOL:
       return "bool";
     default:
       PADDLE_THROW("Not support type %d", type);
@@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) {
 }
 
 inline std::ostream& operator<<(std::ostream& out,
-                                const proto::DataType& type) {
+                                const proto::VarType::Type& type) {
   out << DataTypeToString(type);
   return out;
 }
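
For readers unfamiliar with the visitor dispatch that VisitDataType above relies on, here is a self-contained sketch of the pattern: a switch maps a runtime enum value to a compile-time template instantiation. All names here are illustrative, not Paddle's:

#include <iostream>

enum class Type { FP32, FP64, INT32 };

struct PrintSizeVisitor {
  // Called with the concrete C++ type selected at runtime.
  template <typename T>
  void operator()() const {
    std::cout << "element size: " << sizeof(T) << " bytes\n";
  }
};

template <typename Visitor>
void VisitType(Type type, Visitor visitor) {
  switch (type) {
    case Type::FP32:
      visitor.template operator()<float>();
      break;
    case Type::FP64:
      visitor.template operator()<double>();
      break;
    case Type::INT32:
      visitor.template operator()<int>();
      break;
  }
}

int main() {
  VisitType(Type::FP64, PrintSizeVisitor{});  // prints "element size: 8 bytes"
}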

data_type_transform.cc
@@ -65,19 +65,19 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
   auto ctx = pool.Get(in.place());
 
   switch (src_type) {
-    case proto::DataType::FP32:
+    case proto::VarType::FP32:
       framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
       break;
-    case proto::DataType::FP64:
+    case proto::VarType::FP64:
       framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
       break;
-    case proto::DataType::INT32:
+    case proto::VarType::INT32:
       framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx));
       break;
-    case proto::DataType::INT64:
+    case proto::VarType::INT64:
       framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx));
       break;
-    case proto::DataType::BOOL:
+    case proto::VarType::BOOL:
       framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
       break;
     default:

data_type_transform_test.cc
@@ -32,11 +32,11 @@ TEST(DataTypeTransform, CPUTransform) {
     ptr[i] = i / 3;
   }
 
-  auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place,
+  auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);
-  auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place,
+  auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);
-  auto kernel_int32 = OpKernelType(proto::DataType::INT32, place,
+  auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
                                    DataLayout::kAnyLayout, LibraryType::kPlain);
 
   TransDataType(kernel_fp32, kernel_fp64, in, &out);

framework.proto
@@ -91,7 +91,9 @@ message OpProto {
   required string comment = 5;
 }
 
-enum DataType {
+message VarType {
+  enum Type {
+    // Pod Types
     BOOL = 0;
     INT16 = 1;
     INT32 = 2;
@@ -99,25 +101,24 @@ enum DataType {
     FP16 = 4;
     FP32 = 5;
     FP64 = 6;
-}
 
-message VarType {
-  enum Type {
-    LOD_TENSOR = 1;
-    SELECTED_ROWS = 2;
-    FEED_MINIBATCH = 3;
-    FETCH_LIST = 4;
-    STEP_SCOPES = 5;
-    LOD_RANK_TABLE = 6;
-    LOD_TENSOR_ARRAY = 7;
-    PLACE_LIST = 8;
-    READER = 9;
+    // Other types that may need additional descriptions
+    LOD_TENSOR = 7;
+    SELECTED_ROWS = 8;
+    FEED_MINIBATCH = 9;
+    FETCH_LIST = 10;
+    STEP_SCOPES = 11;
+    LOD_RANK_TABLE = 12;
+    LOD_TENSOR_ARRAY = 13;
+    PLACE_LIST = 14;
+    READER = 15;
   }
 
   required Type type = 1;
 
   message TensorDesc {
-    required DataType data_type = 1;
+    // Should only be PODType. Is enforced in C++
+    required Type data_type = 1;
     repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
   }
 
   optional TensorDesc selected_rows = 2;
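
Two observations on this hunk: the former VarType values (LOD_TENSOR = 1 ... READER = 9) are renumbered to 7..15, so ProgramDescs serialized before this change would presumably need regeneration; and the new TensorDesc comment defers the "POD only" constraint to C++. A hedged sketch of what such a C++-side check could look like (IsPODType is a hypothetical helper, not from this diff):

#include <cassert>

namespace proto {
struct VarType {
  enum Type { BOOL = 0, INT16, INT32, INT64, FP16, FP32, FP64, LOD_TENSOR = 7 };
};
}  // namespace proto

// After the move, POD data types occupy the contiguous range [BOOL, FP64],
// so "should only be PODType" reduces to a range check.
inline bool IsPODType(proto::VarType::Type t) {
  return t >= proto::VarType::BOOL && t <= proto::VarType::FP64;
}

int main() {
  assert(IsPODType(proto::VarType::FP32));         // data type: accepted
  assert(!IsPODType(proto::VarType::LOD_TENSOR));  // descriptor type: rejected
  return 0;
}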

op_kernel_type.h
@@ -40,12 +40,12 @@ struct OpKernelType {
   // place, data_type, library_type kinds less than 2^8
   constexpr static int LEFT_SHIFT = 8;
 
-  proto::DataType data_type_;
+  proto::VarType::Type data_type_;
   DataLayout data_layout_;
   platform::Place place_;
   LibraryType library_type_;
 
-  OpKernelType(proto::DataType data_type, platform::Place place,
+  OpKernelType(proto::VarType::Type data_type, platform::Place place,
                DataLayout data_layout = DataLayout::kAnyLayout,
                LibraryType library_type = LibraryType::kPlain)
       : data_type_(data_type),
@@ -53,7 +53,7 @@ struct OpKernelType {
         place_(place),
         library_type_(library_type) {}
 
-  OpKernelType(proto::DataType data_type,
+  OpKernelType(proto::VarType::Type data_type,
                const platform::DeviceContext& dev_ctx,
                DataLayout data_layout = DataLayout::kAnyLayout,
                LibraryType library_type = LibraryType::kPlain)
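
The LEFT_SHIFT constant above hints at how OpKernelType is hashed: each field kind stays below 2^8, so shifting by 8 bits per field packs (data_type, layout, place, library) into one unique integer. A sketch of that idea under those assumptions; the field encodings and function name here are illustrative, not Paddle's:

#include <cstdint>
#include <iostream>

constexpr int kLeftShift = 8;

// Packs four small field codes (each assumed < 2^8) into a unique key.
uint64_t PackKernelKey(int data_type, int layout, int place, int library) {
  uint64_t key = data_type;
  key = (key << kLeftShift) | layout;
  key = (key << kLeftShift) | place;
  key = (key << kLeftShift) | library;
  return key;
}

int main() {
  // FP32 = 5 (from VarType::Type) combined with illustrative field codes.
  std::cout << std::hex << PackKernelKey(5, 0, 1, 0) << "\n";  // prints 5000100
}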

op_kernel_type_test.cc
@@ -18,7 +18,7 @@ limitations under the License. */
 TEST(OpKernelType, ToString) {
   using OpKernelType = paddle::framework::OpKernelType;
-  using DataType = paddle::framework::proto::DataType;
+  using DataType = paddle::framework::proto::VarType;
   using CPUPlace = paddle::platform::CPUPlace;
   using DataLayout = paddle::framework::DataLayout;
   using LibraryType = paddle::framework::LibraryType;
@@ -33,7 +33,7 @@ TEST(OpKernelType, ToString) {
 TEST(OpKernelType, Hash) {
   using OpKernelType = paddle::framework::OpKernelType;
-  using DataType = paddle::framework::proto::DataType;
+  using DataType = paddle::framework::proto::VarType;
   using CPUPlace = paddle::platform::CPUPlace;
   using CUDAPlace = paddle::platform::CUDAPlace;
   using DataLayout = paddle::framework::DataLayout;

op_registry_test.cc
@@ -226,7 +226,7 @@ class OpWithKernelTest : public OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
+    return framework::OpKernelType(proto::VarType::FP32, ctx.device_context());
   }
 };
@@ -290,8 +290,8 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
-        framework::LibraryType::kCUDNN);
+    return framework::OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0),
+                                   DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
   }
 };

operator.cc
@@ -569,7 +569,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 }
 
-proto::DataType OperatorWithKernel::IndicateDataType(
+proto::VarType::Type OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
   auto& scope = ctx.scope();
   int data_type = -1;
@@ -595,7 +595,7 @@ proto::DataType OperatorWithKernel::IndicateDataType(
     }
   }
   PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
-  return static_cast<proto::DataType>(data_type);
+  return static_cast<proto::VarType::Type>(data_type);
 }
 
 OpKernelType OperatorWithKernel::GetExpectedKernelType(

operator.h
@@ -394,9 +394,9 @@ class OperatorWithKernel : public OperatorBase {
       const OpKernelType& expected_kernel_type) const;
 
  private:
-  // indicate kernel DataType by input data. Defaultly all input data must be
+  // indicate kernel DataType by input data. By default all input data must be
   // same.
-  proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
+  proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
 
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
 };
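
The comment above describes the contract of IndicateDataType: scan the input tensors, take the first data type seen, and require every other input to agree. A self-contained sketch of that idea (illustrative only; the real body in operator.cc is partially elided in this diff):

#include <cassert>
#include <vector>

enum class Type { FP32 = 5, FP64 = 6 };

Type IndicateType(const std::vector<Type>& input_types) {
  int data_type = -1;
  for (Type t : input_types) {
    int tmp = static_cast<int>(t);
    assert(data_type == -1 || data_type == tmp);  // all inputs must agree
    data_type = tmp;
  }
  assert(data_type != -1);  // "DataType should be indicated by input"
  return static_cast<Type>(data_type);
}

int main() {
  assert(IndicateType({Type::FP32, Type::FP32}) == Type::FP32);
  return 0;
}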

operator_test.cc
@@ -119,7 +119,7 @@ class OpWithKernelTest : public OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
   OpKernelType GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
-    return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
+    return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
   }
 };

program_desc_test.cc
@@ -24,13 +24,13 @@ TEST(ProgramDesc, copy_ctor) {
   auto* x = global_block->Var("X");
   x->SetType(proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
-  x->SetDataType(proto::FP32);
+  x->SetDataType(proto::VarType::FP32);
   x->SetShape({1000, 784});
 
   auto* y = global_block->Var("Y");
   y->SetType(proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
-  y->SetDataType(proto::FP32);
+  y->SetDataType(proto::VarType::FP32);
   y->SetShape({784, 100});
 
   auto* op = global_block->AppendOp();
@@ -86,13 +86,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
   auto* x = global_block->Var("X");
   x->SetType(proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
-  x->SetDataType(proto::FP32);
+  x->SetDataType(proto::VarType::FP32);
   x->SetShape({1000, 784});
 
   auto* y = global_block->Var("Y");
   y->SetType(proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
-  y->SetDataType(proto::FP32);
+  y->SetDataType(proto::VarType::FP32);
   y->SetShape({784, 100});
 
   auto* op = global_block->AppendOp();

prune_test.cc
@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
   for (auto kv : outputs) {
     for (auto v : kv.second) {
       auto var = block->Var(v);
-      var->SetDataType(paddle::framework::proto::DataType::FP32);
+      var->SetDataType(paddle::framework::proto::VarType::FP32);
     }
   }

var_desc.cc
@@ -87,12 +87,12 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
   return res;
 }
 
-void VarDesc::SetDataType(proto::DataType data_type) {
+void VarDesc::SetDataType(proto::VarType::Type data_type) {
   mutable_tensor_desc()->set_data_type(data_type);
 }
 
 void VarDesc::SetDataTypes(
-    const std::vector<proto::DataType> &multiple_data_type) {
+    const std::vector<proto::VarType::Type> &multiple_data_type) {
   if (multiple_data_type.size() != GetTensorDescNum()) {
     VLOG(3) << "WARNING: The number of given data types("
             << multiple_data_type.size()
@@ -108,13 +108,13 @@ void VarDesc::SetDataTypes(
   }
 }
 
-proto::DataType VarDesc::GetDataType() const {
+proto::VarType::Type VarDesc::GetDataType() const {
   return tensor_desc().data_type();
 }
 
-std::vector<proto::DataType> VarDesc::GetDataTypes() const {
+std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
   std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
-  std::vector<proto::DataType> res;
+  std::vector<proto::VarType::Type> res;
   res.reserve(descs.size());
   for (const auto &tensor_desc : descs) {
     res.push_back(tensor_desc.data_type());

var_desc.h
@@ -80,13 +80,14 @@ class VarDesc {
   std::vector<std::vector<int64_t>> GetShapes() const;
 
-  void SetDataType(proto::DataType data_type);
+  void SetDataType(proto::VarType::Type data_type);
 
-  void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);
+  void SetDataTypes(
+      const std::vector<proto::VarType::Type> &multiple_data_type);
 
-  proto::DataType GetDataType() const;
+  proto::VarType::Type GetDataType() const;
 
-  std::vector<proto::DataType> GetDataTypes() const;
+  std::vector<proto::VarType::Type> GetDataTypes() const;
 
   void SetLoDLevel(int32_t lod_level);

assign_value_op.cc
@@ -36,7 +36,8 @@ class AssignValueOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::proto::DataType(ctx.Attr<int>("dtype")), ctx.GetPlace());
+        framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
+        ctx.GetPlace());
   }
 };
@@ -49,8 +50,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
              "(vector<int>) "
              "Shape of values.");
     AddAttr<int>("dtype", "data type of values")
-        .InEnum({framework::proto::DataType::INT32,
-                 framework::proto::DataType::FP32});
+        .InEnum({framework::proto::VarType::INT32,
+                 framework::proto::VarType::FP32});
     AddAttr<std::vector<float>>("fp32_values", "store the float values")
         .SetDefault({});
     AddAttr<std::vector<int>>("int32_values", "store the int values")

assign_value_op.h
@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> {
     int dtype = ctx.Attr<int>("dtype");
     const char* value_name = nullptr;
     switch (dtype) {
-      case framework::proto::DataType::INT32:
+      case framework::proto::VarType::INT32:
         value_name = "int32_values";
         break;
-      case framework::proto::DataType::FP32:
+      case framework::proto::VarType::FP32:
         value_name = "fp32_values";
         break;
       default:

cast_op.h
@@ -55,7 +55,8 @@ class CastOpKernel : public framework::OpKernel<InT> {
     auto* in = context.Input<framework::Tensor>("X");
     auto* out = context.Output<framework::Tensor>("Out");
     framework::VisitDataType(
-        static_cast<framework::proto::DataType>(context.Attr<int>("out_dtype")),
+        static_cast<framework::proto::VarType::Type>(
+            context.Attr<int>("out_dtype")),
         CastOpFunctor<DeviceContext, InT>(
             in, out, context.template device_context<DeviceContext>()));
   }
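
cast_op.h thus performs a double dispatch: the kernel is already instantiated for the input type InT, and VisitDataType selects the output type from the out_dtype attribute at runtime. A self-contained sketch of the same pattern, with all names illustrative rather than Paddle's:

#include <iostream>
#include <vector>

enum class Type { FP32, INT32 };

template <typename InT>
struct CastFunctor {
  const std::vector<InT>& in;
  // OutT is chosen at runtime by the visitor dispatch below.
  template <typename OutT>
  void operator()() const {
    for (InT v : in) std::cout << static_cast<OutT>(v) << " ";
    std::cout << "\n";
  }
};

template <typename Visitor>
void VisitType(Type t, Visitor v) {
  if (t == Type::FP32) v.template operator()<float>();
  else v.template operator()<int>();
}

int main() {
  std::vector<double> data{1.7, 2.2};
  VisitType(Type::INT32, CastFunctor<double>{data});  // prints "1 2"
}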

chunk_eval_op.cc
@@ -57,7 +57,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(framework::proto::DataType::FP32,
+    return framework::OpKernelType(framework::proto::VarType::FP32,
                                    platform::CPUPlace());
   }
 };

edit_distance_op.cc
@@ -42,7 +42,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(framework::proto::DataType::FP32,
+    return framework::OpKernelType(framework::proto::VarType::FP32,
                                    ctx.device_context());
   }
 };

fill_constant_batch_size_like_op.cc
@@ -24,7 +24,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
+        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.device_context());
   }
 };
@@ -36,7 +36,7 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
     AddAttr<int>("dtype",
                  "(int, default 5 (FP32)) "
                  "Output data type")
-        .SetDefault(framework::proto::DataType::FP32);
+        .SetDefault(framework::proto::VarType::FP32);
     AddAttr<float>("value", "(float, default 0) The value to be filled")
         .SetDefault(0.0f);
     AddComment(R"DOC(

fill_constant_op.cc
@@ -38,7 +38,7 @@ class FillConstantOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope &scope,
                const platform::Place &dev_place) const override {
     auto data_type =
-        static_cast<framework::proto::DataType>(Attr<int>("dtype"));
+        static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
     auto value = Attr<float>("value");
     auto force_cpu = Attr<bool>("force_cpu");
     auto &out =
@@ -64,7 +64,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("dtype",
                  "(int, default 5 (FP32)) "
                  "Output data type")
-        .SetDefault(framework::proto::DataType::FP32);
+        .SetDefault(framework::proto::VarType::FP32);
     AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
     AddAttr<float>("value", "(float, default 0) The value to be filled")
         .SetDefault(0.0f);

fill_op.cc
@@ -51,7 +51,8 @@ class FillOp : public framework::OperatorBase {
                        "Cannot find variable %s", Output("Out"))
             .GetMutable<framework::LoDTensor>());
     out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
-    auto dtype = static_cast<framework::proto::DataType>(Attr<int>("dtype"));
+    auto dtype =
+        static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
     platform::CPUPlace cpu;
     auto force_cpu = Attr<bool>("force_cpu");
     out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
@@ -93,7 +94,7 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by
         "value", "The float values of tensor, which are flatten in row major");
     AddAttr<std::vector<int>>("shape", "The shape of output tensor");
     AddAttr<int>("dtype", "The data type of output tensor, Default is float")
-        .SetDefault(framework::proto::DataType::FP32);
+        .SetDefault(framework::proto::VarType::FP32);
     AddAttr<bool>("force_cpu",
                   "Whether the output tensor must be at CPU memory or not. "
                   "Default is false.")

Some files were not shown because too many files have changed in this diff.
