merge develop

Branch: yu239-patch-1
Author: Yang Yang, 8 years ago
Commit: ec01f635f5
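This merge pulls in the develop-branch refactor that retires the standalone `proto::DataType` enum. Its values move into the nested `proto::VarType::Type` in framework.proto, every C++ call site switches from `proto::DataType::*` to `proto::VarType::*`, APIs such as `ToDataType`, `VisitDataType`, `OpKernelType`, `IndicateDataType`, and the `VarDesc` data-type accessors now take `proto::VarType::Type`, and `SizeOfType` additionally handles `platform::float16`.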

@@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
     auto* param = block_desc->FindVarRecursive(pname);
     auto* grad = block_desc->FindVar(arg);
     if (param == nullptr) {
-      grad->SetDataType(proto::DataType::FP32);
+      grad->SetDataType(proto::VarType::FP32);
     } else {
       grad->SetDataType(param->GetDataType());
     }

@@ -51,10 +51,10 @@ class TestOpWithKernel : public OperatorWithKernel {
       const ExecutionContext& ctx) const override {
     if (Attr<bool>("use_gpu")) {
       VLOG(3) << "force use gpu kernel";
-      return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
+      return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
     } else {
       VLOG(3) << "use default kernel";
-      return OpKernelType(proto::DataType::FP32,
+      return OpKernelType(proto::VarType::FP32,
                           ctx.Input<Tensor>("input")->place());
     }
   }

@@ -27,9 +27,9 @@ TEST(DataTransform, DataLayoutFunction) {
   in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
   in.set_layout(DataLayout::kNHWC);
-  auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
+  auto kernel_nhwc = OpKernelType(proto::VarType::FP32, place,
                                   DataLayout::kNHWC, LibraryType::kPlain);
-  auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
+  auto kernel_ncwh = OpKernelType(proto::VarType::FP32, place,
                                   DataLayout::kNCHW, LibraryType::kPlain);
   TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);

@@ -20,35 +20,35 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-inline proto::DataType ToDataType(std::type_index type) {
+inline proto::VarType::Type ToDataType(std::type_index type) {
   using namespace paddle::framework::proto;
   if (typeid(float).hash_code() == type.hash_code()) {
-    return DataType::FP32;
+    return proto::VarType::FP32;
   } else if (typeid(double).hash_code() == type.hash_code()) {
-    return DataType::FP64;
+    return proto::VarType::FP64;
   } else if (typeid(int).hash_code() == type.hash_code()) {
-    return DataType::INT32;
+    return proto::VarType::INT32;
   } else if (typeid(int64_t).hash_code() == type.hash_code()) {
-    return DataType::INT64;
+    return proto::VarType::INT64;
   } else if (typeid(bool).hash_code() == type.hash_code()) {
-    return DataType::BOOL;
+    return proto::VarType::BOOL;
   } else {
     PADDLE_THROW("Not supported");
   }
 }
 
-inline std::type_index ToTypeIndex(proto::DataType type) {
+inline std::type_index ToTypeIndex(proto::VarType::Type type) {
   using namespace paddle::framework::proto;
   switch (type) {
-    case DataType::FP32:
+    case proto::VarType::FP32:
       return typeid(float);
-    case DataType::FP64:
+    case proto::VarType::FP64:
       return typeid(double);
-    case DataType::INT32:
+    case proto::VarType::INT32:
       return typeid(int);
-    case DataType::INT64:
+    case proto::VarType::INT64:
       return typeid(int64_t);
-    case DataType::BOOL:
+    case proto::VarType::BOOL:
       return typeid(bool);
     default:
       PADDLE_THROW("Not support type %d", type);
@@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) {
   }
 }
 
 template <typename Visitor>
-inline void VisitDataType(proto::DataType type, Visitor visitor) {
+inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
   using namespace paddle::framework::proto;
   switch (type) {
-    case DataType::FP32:
+    case proto::VarType::FP32:
       visitor.template operator()<float>();
       break;
-    case DataType::FP64:
+    case proto::VarType::FP64:
       visitor.template operator()<double>();
       break;
-    case DataType::INT32:
+    case proto::VarType::INT32:
      visitor.template operator()<int>();
       break;
-    case DataType::INT64:
+    case proto::VarType::INT64:
       visitor.template operator()<int64_t>();
       break;
-    case DataType::BOOL:
+    case proto::VarType::BOOL:
       visitor.template operator()<bool>();
       break;
     default:
@@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
   }
 }
 
-inline std::string DataTypeToString(const proto::DataType type) {
+inline std::string DataTypeToString(const proto::VarType::Type type) {
   using namespace paddle::framework::proto;
   switch (type) {
-    case DataType::FP16:
+    case proto::VarType::FP16:
       return "float16";
-    case DataType::FP32:
+    case proto::VarType::FP32:
       return "float32";
-    case DataType::FP64:
+    case proto::VarType::FP64:
       return "float64";
-    case DataType::INT16:
+    case proto::VarType::INT16:
       return "int16";
-    case DataType::INT32:
+    case proto::VarType::INT32:
       return "int32";
-    case DataType::INT64:
+    case proto::VarType::INT64:
       return "int64";
-    case DataType::BOOL:
+    case proto::VarType::BOOL:
       return "bool";
     default:
       PADDLE_THROW("Not support type %d", type);
@@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) {
 }
 
 inline std::ostream& operator<<(std::ostream& out,
-                                const proto::DataType& type) {
+                                const proto::VarType::Type& type) {
   out << DataTypeToString(type);
   return out;
 }
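The `VisitDataType` helper changed above is the bridge between the runtime enum and compile-time C++ types: the enum value selects which instantiation of the visitor's templated call operator runs. A minimal self-contained sketch of the pattern (the `VarType` stub below mirrors the enum values from this diff; it is not Paddle's generated framework.pb.h):

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

namespace proto {
struct VarType {
  enum Type { BOOL = 0, INT32 = 2, INT64 = 3, FP32 = 5, FP64 = 6 };
};
}  // namespace proto

template <typename Visitor>
void VisitDataType(proto::VarType::Type type, Visitor visitor) {
  switch (type) {
    case proto::VarType::FP32:
      visitor.template operator()<float>();
      break;
    case proto::VarType::FP64:
      visitor.template operator()<double>();
      break;
    case proto::VarType::INT32:
      visitor.template operator()<int>();
      break;
    case proto::VarType::INT64:
      visitor.template operator()<int64_t>();
      break;
    case proto::VarType::BOOL:
      visitor.template operator()<bool>();
      break;
    default:
      throw std::runtime_error("Not supported");
  }
}

// Example visitor: prints the size of whichever C++ type was dispatched.
struct SizeReporter {
  template <typename T>
  void operator()() const {
    std::cout << sizeof(T) << " bytes\n";
  }
};

int main() {
  VisitDataType(proto::VarType::FP32, SizeReporter{});   // 4 bytes
  VisitDataType(proto::VarType::INT64, SizeReporter{});  // 8 bytes
}
```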

@@ -65,19 +65,19 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
   auto ctx = pool.Get(in.place());
 
   switch (src_type) {
-    case proto::DataType::FP32:
+    case proto::VarType::FP32:
       framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
       break;
-    case proto::DataType::FP64:
+    case proto::VarType::FP64:
       framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
       break;
-    case proto::DataType::INT32:
+    case proto::VarType::INT32:
       framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx));
       break;
-    case proto::DataType::INT64:
+    case proto::VarType::INT64:
       framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx));
       break;
-    case proto::DataType::BOOL:
+    case proto::VarType::BOOL:
       framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
       break;
     default:
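Note the two-level dispatch in `TransDataType`: the outer `switch` on `src_type` fixes the source element type for `CastDataType<T>`, and `VisitDataType(dst_type, ...)` then resolves the destination type inside the functor, so N source and M destination types are covered with N + M cases rather than N × M.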

@@ -32,11 +32,11 @@ TEST(DataTypeTransform, CPUTransform) {
     ptr[i] = i / 3;
   }
 
-  auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place,
+  auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);
-  auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place,
+  auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);
-  auto kernel_int32 = OpKernelType(proto::DataType::INT32, place,
+  auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
                                    DataLayout::kAnyLayout, LibraryType::kPlain);
 
   TransDataType(kernel_fp32, kernel_fp64, in, &out);

@@ -91,34 +91,35 @@ message OpProto {
   required string comment = 5;
 }
 
-enum DataType {
-  BOOL = 0;
-  INT16 = 1;
-  INT32 = 2;
-  INT64 = 3;
-  FP16 = 4;
-  FP32 = 5;
-  FP64 = 6;
-}
-
 message VarType {
   enum Type {
-    LOD_TENSOR = 1;
-    SELECTED_ROWS = 2;
-    FEED_MINIBATCH = 3;
-    FETCH_LIST = 4;
-    STEP_SCOPES = 5;
-    LOD_RANK_TABLE = 6;
-    LOD_TENSOR_ARRAY = 7;
-    PLACE_LIST = 8;
-    READER = 9;
-    NCCL_COM = 10;
+    // Pod Types
+    BOOL = 0;
+    INT16 = 1;
+    INT32 = 2;
+    INT64 = 3;
+    FP16 = 4;
+    FP32 = 5;
+    FP64 = 6;
+
+    // Other types that may need additional descriptions
+    LOD_TENSOR = 7;
+    SELECTED_ROWS = 8;
+    FEED_MINIBATCH = 9;
+    FETCH_LIST = 10;
+    STEP_SCOPES = 11;
+    LOD_RANK_TABLE = 12;
+    LOD_TENSOR_ARRAY = 13;
+    PLACE_LIST = 14;
+    READER = 15;
+    NCCL_COM = 16;
   }
 
   required Type type = 1;
 
   message TensorDesc {
-    required DataType data_type = 1;
+    // Should only be PODType. Is enforced in C++
+    required Type data_type = 1;
     repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
   }
 
   optional TensorDesc selected_rows = 2;
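Two details of this proto change are easy to miss. First, the POD entries keep the numeric values of the removed `DataType` enum (BOOL = 0 through FP64 = 6), so integer `dtype` attributes, such as the default of 5 documented as FP32 in the fill_constant ops below, decode to the same type before and after the change. Second, the structural types are renumbered (`LOD_TENSOR` moves from 1 to 7, and so on), so `ProgramDesc` protos serialized before this commit will not decode to the same var types afterwards. A tiny sketch of the attribute round-trip, with the enum values stubbed from the hunk above:

```cpp
#include <cassert>

// Stub of the nested enum; values copied from the framework.proto hunk.
struct VarType {
  enum Type { BOOL = 0, INT16 = 1, INT32 = 2, INT64 = 3,
              FP16 = 4, FP32 = 5, FP64 = 6, LOD_TENSOR = 7 };
};

int main() {
  // Op attributes store the dtype as a plain int; 5 still means FP32,
  // exactly as it did under the old top-level DataType enum.
  int dtype_attr = 5;
  assert(static_cast<VarType::Type>(dtype_attr) == VarType::FP32);
  return 0;
}
```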

@@ -40,12 +40,12 @@ struct OpKernelType {
   // place, data_type, library_type kinds less than 2^8
   constexpr static int LEFT_SHIFT = 8;
 
-  proto::DataType data_type_;
+  proto::VarType::Type data_type_;
   DataLayout data_layout_;
   platform::Place place_;
   LibraryType library_type_;
 
-  OpKernelType(proto::DataType data_type, platform::Place place,
+  OpKernelType(proto::VarType::Type data_type, platform::Place place,
                DataLayout data_layout = DataLayout::kAnyLayout,
                LibraryType library_type = LibraryType::kPlain)
       : data_type_(data_type),
@@ -53,7 +53,7 @@ struct OpKernelType {
         place_(place),
         library_type_(library_type) {}
 
-  OpKernelType(proto::DataType data_type,
+  OpKernelType(proto::VarType::Type data_type,
                const platform::DeviceContext& dev_ctx,
                DataLayout data_layout = DataLayout::kAnyLayout,
                LibraryType library_type = LibraryType::kPlain)
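The `LEFT_SHIFT = 8` constant reflects the assumption stated in the comment above: each of place, data type, and library type has fewer than 2^8 variants, so the fields can be packed into disjoint byte-wide slots of a single integer key before hashing. A hypothetical illustration of that packing; the field order, ids, and function name here are ours, not Paddle's actual hash:

```cpp
#include <cstdint>
#include <iostream>

// Pack four small ids into one collision-free key by shifting each
// field into its own 8-bit slot (assumes every id is < 256).
uint64_t PackKernelKey(int data_type, int layout, int place, int library) {
  constexpr int kLeftShift = 8;
  uint64_t key = static_cast<uint64_t>(data_type);
  key = (key << kLeftShift) | static_cast<uint64_t>(layout);
  key = (key << kLeftShift) | static_cast<uint64_t>(place);
  key = (key << kLeftShift) | static_cast<uint64_t>(library);
  return key;
}

int main() {
  // Illustrative ids: FP32 (5), some layout (2), some place (0), kPlain (0).
  std::cout << std::hex << PackKernelKey(5, 2, 0, 0) << "\n";  // 5020000
}
```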

@@ -18,7 +18,7 @@ limitations under the License. */
 TEST(OpKernelType, ToString) {
   using OpKernelType = paddle::framework::OpKernelType;
-  using DataType = paddle::framework::proto::DataType;
+  using DataType = paddle::framework::proto::VarType;
   using CPUPlace = paddle::platform::CPUPlace;
   using DataLayout = paddle::framework::DataLayout;
   using LibraryType = paddle::framework::LibraryType;
@@ -33,7 +33,7 @@ TEST(OpKernelType, ToString) {
 TEST(OpKernelType, Hash) {
   using OpKernelType = paddle::framework::OpKernelType;
-  using DataType = paddle::framework::proto::DataType;
+  using DataType = paddle::framework::proto::VarType;
   using CPUPlace = paddle::platform::CPUPlace;
   using CUDAPlace = paddle::platform::CUDAPlace;
   using DataLayout = paddle::framework::DataLayout;

@@ -226,7 +226,7 @@ class OpWithKernelTest : public OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
+    return framework::OpKernelType(proto::VarType::FP32, ctx.device_context());
   }
 };
@@ -290,9 +290,9 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
-        framework::LibraryType::kCUDNN);
+    return framework::OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0),
+                                   DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
   }
 };

@@ -569,7 +569,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 }
 
-proto::DataType OperatorWithKernel::IndicateDataType(
+proto::VarType::Type OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
   auto& scope = ctx.scope();
   int data_type = -1;
@@ -595,7 +595,7 @@ proto::DataType OperatorWithKernel::IndicateDataType(
     }
   }
   PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
-  return static_cast<proto::DataType>(data_type);
+  return static_cast<proto::VarType::Type>(data_type);
 }
 
 OpKernelType OperatorWithKernel::GetExpectedKernelType(

@@ -394,9 +394,9 @@ class OperatorWithKernel : public OperatorBase {
       const OpKernelType& expected_kernel_type) const;
 
  private:
-  // indicate kernel DataType by input data. Defaultly all input data must be
+  // indicate kernel DataType by input data. By default all input data must be
   // same.
-  proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
+  proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
 
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
 };

@@ -119,7 +119,7 @@ class OpWithKernelTest : public OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
   OpKernelType GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
-    return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
+    return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
   }
 };

@@ -24,13 +24,13 @@ TEST(ProgramDesc, copy_ctor) {
   auto* x = global_block->Var("X");
   x->SetType(proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
-  x->SetDataType(proto::FP32);
+  x->SetDataType(proto::VarType::FP32);
   x->SetShape({1000, 784});
 
   auto* y = global_block->Var("Y");
   y->SetType(proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
-  y->SetDataType(proto::FP32);
+  y->SetDataType(proto::VarType::FP32);
   y->SetShape({784, 100});
 
   auto* op = global_block->AppendOp();
@@ -86,13 +86,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
   auto* x = global_block->Var("X");
   x->SetType(proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
-  x->SetDataType(proto::FP32);
+  x->SetDataType(proto::VarType::FP32);
   x->SetShape({1000, 784});
 
   auto* y = global_block->Var("Y");
   y->SetType(proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
-  y->SetDataType(proto::FP32);
+  y->SetDataType(proto::VarType::FP32);
   y->SetShape({784, 100});
 
   auto* op = global_block->AppendOp();

@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
   for (auto kv : outputs) {
     for (auto v : kv.second) {
       auto var = block->Var(v);
-      var->SetDataType(paddle::framework::proto::DataType::FP32);
+      var->SetDataType(paddle::framework::proto::VarType::FP32);
     }
   }

@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace framework {
@@ -52,7 +53,9 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
 };
 
 static inline size_t SizeOfType(std::type_index type) {
-  SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t> functor;
+  SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t,
+                    platform::float16>
+      functor;
   size_t size = functor(type);
   PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
   return size;
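`SizeOfTypeFunctor` is a recursive type-list matcher: each specialization checks its head type and recurses into the tail, and the empty list returns 0 so the `PADDLE_ENFORCE` above can reject unsupported types; this hunk simply appends `platform::float16` to the list. A self-contained sketch of the pattern, with `float16` stubbed as a two-byte struct:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <typeindex>

struct float16 { uint16_t bits; };  // stand-in for platform::float16

template <typename... Ts>
struct SizeOfTypeFunctor;

// Base case: empty type list, nothing matched.
template <>
struct SizeOfTypeFunctor<> {
  size_t operator()(std::type_index) const { return 0UL; }
};

// Recursive case: match the head type or try the tail.
template <typename Head, typename... Tail>
struct SizeOfTypeFunctor<Head, Tail...> {
  size_t operator()(std::type_index type) const {
    if (type == std::type_index(typeid(Head))) return sizeof(Head);
    return SizeOfTypeFunctor<Tail...>{}(type);
  }
};

int main() {
  SizeOfTypeFunctor<int, float, double, int64_t, bool, float16> functor;
  std::cout << functor(typeid(float16)) << "\n";  // 2
  std::cout << functor(typeid(double)) << "\n";   // 8
}
```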

@@ -87,12 +87,12 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
   return res;
 }
 
-void VarDesc::SetDataType(proto::DataType data_type) {
+void VarDesc::SetDataType(proto::VarType::Type data_type) {
   mutable_tensor_desc()->set_data_type(data_type);
 }
 
 void VarDesc::SetDataTypes(
-    const std::vector<proto::DataType> &multiple_data_type) {
+    const std::vector<proto::VarType::Type> &multiple_data_type) {
   if (multiple_data_type.size() != GetTensorDescNum()) {
     VLOG(3) << "WARNING: The number of given data types("
             << multiple_data_type.size()
@@ -108,13 +108,13 @@ void VarDesc::SetDataTypes(
   }
 }
 
-proto::DataType VarDesc::GetDataType() const {
+proto::VarType::Type VarDesc::GetDataType() const {
   return tensor_desc().data_type();
 }
 
-std::vector<proto::DataType> VarDesc::GetDataTypes() const {
+std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
   std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
-  std::vector<proto::DataType> res;
+  std::vector<proto::VarType::Type> res;
   res.reserve(descs.size());
   for (const auto &tensor_desc : descs) {
     res.push_back(tensor_desc.data_type());

@@ -80,13 +80,14 @@ class VarDesc {
   std::vector<std::vector<int64_t>> GetShapes() const;
 
-  void SetDataType(proto::DataType data_type);
+  void SetDataType(proto::VarType::Type data_type);
 
-  void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);
+  void SetDataTypes(
+      const std::vector<proto::VarType::Type> &multiple_data_type);
 
-  proto::DataType GetDataType() const;
+  proto::VarType::Type GetDataType() const;
 
-  std::vector<proto::DataType> GetDataTypes() const;
+  std::vector<proto::VarType::Type> GetDataTypes() const;
 
   void SetLoDLevel(int32_t lod_level);

@@ -36,7 +36,8 @@ class AssignValueOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::proto::DataType(ctx.Attr<int>("dtype")), ctx.GetPlace());
+        framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
+        ctx.GetPlace());
   }
 };
@@ -49,8 +50,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
              "(vector<int>) "
              "Shape of values.");
     AddAttr<int>("dtype", "data type of values")
-        .InEnum({framework::proto::DataType::INT32,
-                 framework::proto::DataType::FP32});
+        .InEnum({framework::proto::VarType::INT32,
+                 framework::proto::VarType::FP32});
     AddAttr<std::vector<float>>("fp32_values", "store the float values")
         .SetDefault({});
     AddAttr<std::vector<int>>("int32_values", "store the int values")

@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> {
     int dtype = ctx.Attr<int>("dtype");
     const char* value_name = nullptr;
     switch (dtype) {
-      case framework::proto::DataType::INT32:
+      case framework::proto::VarType::INT32:
         value_name = "int32_values";
         break;
-      case framework::proto::DataType::FP32:
+      case framework::proto::VarType::FP32:
         value_name = "fp32_values";
         break;
       default:

@@ -55,7 +55,8 @@ class CastOpKernel : public framework::OpKernel<InT> {
     auto* in = context.Input<framework::Tensor>("X");
     auto* out = context.Output<framework::Tensor>("Out");
     framework::VisitDataType(
-        static_cast<framework::proto::DataType>(context.Attr<int>("out_dtype")),
+        static_cast<framework::proto::VarType::Type>(
+            context.Attr<int>("out_dtype")),
         CastOpFunctor<DeviceContext, InT>(
             in, out, context.template device_context<DeviceContext>()));
   }
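This kernel shows the attribute-to-enum path end to end: `out_dtype` arrives as a plain `int` attribute, the `static_cast` turns it into the new `proto::VarType::Type`, and `VisitDataType` then instantiates the cast functor for the requested destination type, while the kernel's own `InT` template parameter already fixes the source type.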

@@ -57,7 +57,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(framework::proto::DataType::FP32,
+    return framework::OpKernelType(framework::proto::VarType::FP32,
                                    platform::CPUPlace());
   }
 };

@@ -42,7 +42,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(framework::proto::DataType::FP32,
+    return framework::OpKernelType(framework::proto::VarType::FP32,
                                    ctx.device_context());
   }
 };

@@ -24,7 +24,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
+        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.device_context());
   }
 };
@@ -36,7 +36,7 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
     AddAttr<int>("dtype",
                  "(int, default 5 (FP32)) "
                  "Output data type")
-        .SetDefault(framework::proto::DataType::FP32);
+        .SetDefault(framework::proto::VarType::FP32);
     AddAttr<float>("value", "(float, default 0) The value to be filled")
         .SetDefault(0.0f);
     AddComment(R"DOC(
AddComment(R"DOC( AddComment(R"DOC(

@@ -38,7 +38,7 @@ class FillConstantOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope &scope,
                const platform::Place &dev_place) const override {
     auto data_type =
-        static_cast<framework::proto::DataType>(Attr<int>("dtype"));
+        static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
     auto value = Attr<float>("value");
     auto force_cpu = Attr<bool>("force_cpu");
     auto &out =
@@ -64,7 +64,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("dtype",
                  "(int, default 5 (FP32)) "
                  "Output data type")
-        .SetDefault(framework::proto::DataType::FP32);
+        .SetDefault(framework::proto::VarType::FP32);
     AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
     AddAttr<float>("value", "(float, default 0) The value to be filled")
         .SetDefault(0.0f);

Some files were not shown because too many files have changed in this diff.