Merge pull request #14962 from sneaxiy/rewrite_variable_type

Rewrite variable type
revert-15207-remove_op_handle_lock_and_fix_var
Zeng Jinle 6 years ago committed by GitHub
commit c0bcff00dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -68,18 +68,23 @@ cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memor
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_library(var_type_traits SRCS var_type_traits DEPS lod_tensor selected_rows framework_proto)
if (WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda)
endif()
cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)
cc_library(scope SRCS scope.cc DEPS glog threadpool var_type_traits)
cc_library(scope_pool SRCS scope_pool.cc DEPS scope)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_test(variable_test SRCS variable_test.cc DEPS tensor var_type_traits)
cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
DEPS operator op_registry device_context math_function)
DEPS operator op_registry device_context math_function scope)
if(WITH_GPU)
if (WIN32)

@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"

@ -88,7 +88,7 @@ void EagerDeletionOpHandle::RunImpl() {
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
var->Type().name(), name);
framework::ToTypeName(var->Type()), name);
}
}

@ -24,7 +24,7 @@ static void VisitVariable(Variable* var, Func* func) {
} else if (var->IsType<SelectedRows>()) {
(*func)(var->GetMutable<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var->Type().name());
PADDLE_THROW("Not supported type %s", ToTypeName(var->Type()));
}
}
@ -35,7 +35,7 @@ static void VisitVariable(const Variable& var, Func* func) {
} else if (var.IsType<SelectedRows>()) {
(*func)(var.Get<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var.Type().name());
PADDLE_THROW("Not supported type %s", ToTypeName(var.Type()));
}
}

@ -119,7 +119,7 @@ static void DeleteUnusedTensors(
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
var->Type().name(), name);
framework::ToTypeName(var->Type()), name);
}
}
}

@ -380,7 +380,7 @@ const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
return &(var.Get<SelectedRows>().value());
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var.Type().name());
ToTypeName(var.Type()));
}
}
@ -391,7 +391,7 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
return var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
ToTypeName(var->Type()));
}
}
@ -485,7 +485,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"should be LoDTensor, but the received type is %s",
var->Type().name());
ToTypeName(var->Type()));
return &(var->Get<LoDTensor>());
});
return res;
@ -504,7 +504,7 @@ const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"%s should be LoDTensor, but the received type is %s",
sub_name, var->Type().name());
sub_name, ToTypeName(var->Type()));
return &(var->Get<LoDTensor>());
});
return res;
@ -533,7 +533,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"%s should be LoDTensor, but the received type is %s",
sub_name, var->Type().name());
sub_name, ToTypeName(var->Type()));
return var->GetMutable<LoDTensor>();
});
return res;
@ -775,7 +775,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
PADDLE_THROW(
"Only LoDTensor/SelectedRows support 'GetDim', but Variables "
"type_id is %s.",
var->Type().name());
ToTypeName(var->Type()));
}
}
@ -798,7 +798,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
ToTypeName(var->Type()));
}
}

@ -165,11 +165,9 @@ std::string Scope::Rename(const std::string& origin_name) const {
Variable* Scope::VarInternal(const std::string& name) {
auto* v = FindVarLocally(name);
if (v != nullptr) return v;
v = new Variable();
vars_[name].reset(v);
vars_.emplace(name, std::unique_ptr<Variable>(v));
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
}

@ -19,52 +19,50 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace framework {
template <typename T>
inline bool IsType(const std::type_index& type_index) {
return type_index == std::type_index(typeid(T));
inline bool IsType(const std::type_index& type) {
return type == typeid(T);
}
inline proto::VarType::Type ToVarType(std::type_index type) {
if (IsType<LoDTensor>(type)) {
return proto::VarType_Type_LOD_TENSOR;
} else if (IsType<LoDRankTable>(type)) {
return proto::VarType_Type_LOD_RANK_TABLE;
} else if (IsType<LoDTensorArray>(type)) {
return proto::VarType_Type_LOD_TENSOR_ARRAY;
} else if (IsType<SelectedRows>(type)) {
return proto::VarType_Type_SELECTED_ROWS;
} else if (IsType<ReaderHolder>(type)) {
return proto::VarType_Type_READER;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
inline proto::VarType::Type ToVarType(int type) {
switch (type) {
case proto::VarType::LOD_TENSOR:
case proto::VarType::SELECTED_ROWS:
case proto::VarType::LOD_RANK_TABLE:
case proto::VarType::LOD_TENSOR_ARRAY:
case proto::VarType::READER:
return static_cast<proto::VarType::Type>(type);
default:
PADDLE_THROW("ToVarType:Unsupported type %d", type);
}
}
template <typename Visitor>
inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
case proto::VarType_Type_LOD_TENSOR:
switch (var.Type()) {
case proto::VarType::LOD_TENSOR:
visitor(var.Get<LoDTensor>());
return;
case proto::VarType_Type_LOD_RANK_TABLE:
case proto::VarType::LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>());
return;
case proto::VarType_Type_LOD_TENSOR_ARRAY:
case proto::VarType::LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>());
return;
case proto::VarType_Type_SELECTED_ROWS:
case proto::VarType::SELECTED_ROWS:
visitor(var.Get<SelectedRows>());
return;
case proto::VarType_Type_READER:
case proto::VarType::READER:
visitor(var.Get<ReaderHolder>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
PADDLE_THROW("Not supported visit type, %s", ToTypeName(var.Type()));
}
}

@ -108,7 +108,7 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test2_out")->GetType());
}

@ -0,0 +1,119 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/macros.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#endif
#include <cudnn.h>
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#endif
namespace paddle {
namespace framework {
// Besides registering variable type id, it is helpful to register a
// var_id -> std::type_index map (for example, get type names according to id)
namespace detail {
template <int kStart, int kEnd, bool kStop>
struct VarIdToTypeIndexMapInitializerImpl {
template <typename MapType1, typename MapType2>
static void Init(MapType1 *id_to_type, MapType2 *type_to_id) {
using Type =
typename std::tuple_element<kStart, VarTypeRegistry::ArgTuple>::type;
static_assert(!std::is_same<Type, void>::value, "Type cannot be void");
constexpr int kId = VarTypeTrait<Type>::kId;
auto type = std::type_index(typeid(Type));
PADDLE_ENFORCE(id_to_type->count(kId) == 0,
"Registered duplicate type id %d for type %s", kId,
type.name());
PADDLE_ENFORCE(type_to_id->count(type) == 0,
"Registered duplicate type_index %s for id %d", type.name(),
kId);
id_to_type->emplace(kId, type);
type_to_id->emplace(type, kId);
VarIdToTypeIndexMapInitializerImpl<kStart + 1, kEnd,
kStart + 1 == kEnd>::Init(id_to_type,
type_to_id);
}
};
template <int kStart, int kEnd>
struct VarIdToTypeIndexMapInitializerImpl<kStart, kEnd, true> {
template <typename MapType1, typename MapType2>
static void Init(MapType1 *, MapType2 *) {}
};
// VarIdToTypeIndexMapInitializer is designed to initialize var_id ->
// std::type_index map and std::type_index -> var_id map
using VarIdToTypeIndexMapInitializer =
VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum,
VarTypeRegistry::kRegisteredTypeNum ==
0>;
struct VarIdToTypeIndexMapHolder {
DISABLE_COPY_AND_ASSIGN(VarIdToTypeIndexMapHolder);
public:
static const std::type_index &ToTypeIndex(int var_id) {
auto it = Instance().id_to_type_map_.find(var_id);
PADDLE_ENFORCE(it != Instance().id_to_type_map_.end(),
"VarId %d is not registered.", var_id);
return it->second;
}
static int ToTypeId(const std::type_index &type) {
auto it = Instance().type_to_id_map_.find(type);
PADDLE_ENFORCE(it != Instance().type_to_id_map_.end(),
"VarType %s is not registered.", type.name());
return it->second;
}
private:
VarIdToTypeIndexMapHolder() {
VarIdToTypeIndexMapInitializer::Init(&id_to_type_map_, &type_to_id_map_);
}
static const VarIdToTypeIndexMapHolder &Instance() {
static const VarIdToTypeIndexMapHolder instance;
return instance;
}
std::unordered_map<int, std::type_index> id_to_type_map_;
std::unordered_map<std::type_index, int> type_to_id_map_;
};
} // namespace detail
const std::type_index &ToTypeIndex(int var_id) {
return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
}
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
int ToTypeId(const std::type_index &type) {
return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
}
} // namespace framework
} // namespace paddle

@ -0,0 +1,195 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <tuple>
#include <typeindex>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include <cudnn.h>
#ifndef _WIN32
#include <nccl.h>
#endif
#endif
// Users should add forward declarations here
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
class Communicator;
#endif
#endif
} // namespace platform
namespace framework {
class Tensor;
class LoDTensor;
class SelectedRows;
class LoDRankTable;
class ReaderHolder;
class Scope;
} // namespace framework
namespace operators {
template <typename T>
class AlgorithmsCache;
class CudnnRNNCache;
namespace reader {
class LoDTensorBlockingQueueHolder;
} // namespace reader
} // namespace operators
} // namespace paddle
namespace paddle {
namespace framework {
const char *ToTypeName(int var_id);
const std::type_index &ToTypeIndex(int var_id);
int ToTypeId(const std::type_index &type);
namespace detail {
template <bool kStop, int kStart, int kEnd, typename T1, typename T2,
typename... Args>
struct TypePosFinderImpl {
static constexpr int kPos =
std::is_same<T1, T2>::value
? kStart
: TypePosFinderImpl<kStart + 2 == kEnd, kStart + 1, kEnd, T1,
Args...>::kPos;
};
template <int kStart, int kEnd, typename T1, typename T2>
struct TypePosFinderImpl<true, kStart, kEnd, T1, T2> {
static constexpr int kPos = std::is_same<T1, T2>::value ? kStart : -1;
};
// TypePosFinder helps to find the position in which T is inside Args...
// If T is not inside Args..., kPos would be -1
template <typename T, typename... Args>
struct TypePosFinder {
static constexpr int kPos =
TypePosFinderImpl<sizeof...(Args) == 1, 0, sizeof...(Args), T,
Args...>::kPos;
};
template <typename... Args>
struct VarTypeRegistryImpl {
static constexpr size_t kRegisteredTypeNum = sizeof...(Args);
using ArgTuple = std::tuple<Args...>;
// TypePos() returns the position in which T is inside Args...
// If T is not inside Args..., return -1
template <typename T>
static constexpr int TypePos() {
return TypePosFinder<T, Args...>::kPos;
}
// IsRegistered() returns whether T is registered inside RegistryImpl
template <typename T>
static constexpr bool IsRegistered() {
return TypePos<T>() >= 0;
}
};
} // namespace detail
#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \
template <> \
struct VarTypeTrait<type> { \
static_assert(VarTypeRegistry::IsRegistered<type>(), \
"Must be registered type"); \
using Type = type; \
static constexpr int kId = static_cast<int>(proto_id); \
}
/**
* The following codes are designed to register variable types.
* Only registered types can be stored in Variable.
* This registry mechanism is designed to speed up Variable.
*
* Caution: If you want to add more var types, please consider carefully
* whether you really need to add it.
*/
// Users should add other variable types below.
// Paddle would generate unique Ids for each registered variable types.
using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder,
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
ncclUniqueId, platform::Communicator,
#endif
operators::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
operators::CudnnRNNCache,
#endif
int, float>;
template <typename T>
struct VarTypeTrait {
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
using Type = T;
/**
* Unique VarType Id generation.
*
* The auto-generated id should not be the same as any protobuf id defined in
* framework.proto. Therefore, we generate id by adding the type pos and
* maximum protobuf id (i.e., proto::VarType::TUPLE).
*
* However, we may need more protobuf id in the future.
* To avoid changing this auto id generation algorithm frequently, we
* generate id by adding the type pos and twice of maximum protobuf id (i.e.,
* proto::VarType::TUPLE).
*/
static constexpr int kId = VarTypeRegistry::TypePos<T>() +
static_cast<int>(proto::VarType::TUPLE) * 2;
};
// Users should set some of variable type ids to be what is defined in
// framework.proto below
REG_PROTO_VAR_TYPE_TRAIT(LoDTensor, proto::VarType::LOD_TENSOR);
REG_PROTO_VAR_TYPE_TRAIT(SelectedRows, proto::VarType::SELECTED_ROWS);
REG_PROTO_VAR_TYPE_TRAIT(std::vector<Scope *>, proto::VarType::STEP_SCOPES);
REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST);
REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER);
REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32);
REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32);
/** End of variable type registration */
template <typename T>
inline constexpr bool IsRegisteredVarType() {
return VarTypeRegistry::IsRegistered<T>();
}
#undef REG_PROTO_VAR_TYPE_TRAIT
} // namespace framework
} // namespace paddle

@ -0,0 +1,120 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#endif
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#endif
namespace paddle {
namespace framework {
template <int kPos, int kEnd, bool kStop>
struct TypeIndexChecker {
template <typename SetType1, typename SetType2>
static void Check(SetType1 *var_id_set, SetType2 *type_index_set) {
using Type =
typename std::tuple_element<kPos, VarTypeRegistry::ArgTuple>::type;
static_assert(std::is_same<typename VarTypeTrait<Type>::Type, Type>::value,
"Type must be the same");
constexpr auto kId = VarTypeTrait<Type>::kId;
std::type_index actual_type(typeid(Type));
EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
EXPECT_EQ(ToTypeIndex(kId), actual_type);
EXPECT_EQ(ToTypeId(actual_type), kId);
EXPECT_EQ(ToTypeIndex(ToTypeId(actual_type)), actual_type);
EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId);
EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT
EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT
var_id_set->insert(kId);
type_index_set->insert(std::type_index(typeid(Type)));
TypeIndexChecker<kPos + 1, kEnd, kPos + 1 == kEnd>::Check(var_id_set,
type_index_set);
}
};
template <int kPos, int kEnd>
struct TypeIndexChecker<kPos, kEnd, true> {
template <typename SetType1, typename SetType2>
static void Check(SetType1 *, SetType2 *) {}
};
TEST(var_type_traits, check_no_duplicate_registry) {
constexpr size_t kRegisteredNum = VarTypeRegistry::kRegisteredTypeNum;
std::unordered_set<int> var_id_set;
std::unordered_set<std::type_index> type_index_set;
TypeIndexChecker<0, kRegisteredNum, kRegisteredNum == 0>::Check(
&var_id_set, &type_index_set);
}
template <typename T>
bool CheckVarId(int proto_id) {
static_assert(std::is_same<typename VarTypeTrait<T>::Type, T>::value,
"Type must be the same");
return VarTypeTrait<T>::kId == proto_id;
}
TEST(var_type_traits, check_proto_type_id) {
ASSERT_TRUE(CheckVarId<LoDTensor>(proto::VarType::LOD_TENSOR));
ASSERT_TRUE(CheckVarId<SelectedRows>(proto::VarType::SELECTED_ROWS));
ASSERT_TRUE(CheckVarId<std::vector<Scope *>>(proto::VarType::STEP_SCOPES));
ASSERT_TRUE(CheckVarId<LoDRankTable>(proto::VarType::LOD_RANK_TABLE));
ASSERT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY));
ASSERT_TRUE(CheckVarId<platform::PlaceList>(proto::VarType::PLACE_LIST));
ASSERT_TRUE(CheckVarId<ReaderHolder>(proto::VarType::READER));
ASSERT_TRUE(CheckVarId<int>(proto::VarType::INT32));
ASSERT_TRUE(CheckVarId<float>(proto::VarType::FP32));
ASSERT_EQ(proto::VarType_Type_LOD_TENSOR, proto::VarType::LOD_TENSOR);
ASSERT_EQ(proto::VarType_Type_SELECTED_ROWS, proto::VarType::SELECTED_ROWS);
ASSERT_EQ(proto::VarType_Type_STEP_SCOPES, proto::VarType::STEP_SCOPES);
ASSERT_EQ(proto::VarType_Type_LOD_RANK_TABLE, proto::VarType::LOD_RANK_TABLE);
ASSERT_EQ(proto::VarType_Type_LOD_TENSOR_ARRAY,
proto::VarType::LOD_TENSOR_ARRAY);
ASSERT_EQ(proto::VarType_Type_PLACE_LIST, proto::VarType::PLACE_LIST);
ASSERT_EQ(proto::VarType_Type_READER, proto::VarType::READER);
ASSERT_EQ(proto::VarType_Type_FEED_MINIBATCH, proto::VarType::FEED_MINIBATCH);
ASSERT_EQ(proto::VarType_Type_FETCH_LIST, proto::VarType::FETCH_LIST);
ASSERT_EQ(proto::VarType_Type_RAW, proto::VarType::RAW);
ASSERT_EQ(proto::VarType_Type_TUPLE, proto::VarType::TUPLE);
ASSERT_EQ(proto::VarType_Type_INT32, proto::VarType::INT32);
ASSERT_EQ(proto::VarType_Type_FP32, proto::VarType::FP32);
}
TEST(var_type_traits, test_registry) {
using Registry = detail::VarTypeRegistryImpl<int8_t, int32_t, size_t, double>;
ASSERT_TRUE(Registry::TypePos<int8_t>() == 0);
ASSERT_TRUE(Registry::TypePos<int32_t>() == 1);
ASSERT_TRUE(Registry::TypePos<size_t>() == 2);
ASSERT_TRUE(Registry::TypePos<double>() == 3);
ASSERT_TRUE(Registry::TypePos<float>() == -1);
}
} // namespace framework
} // namespace paddle

@ -18,7 +18,7 @@
#include <typeindex>
#include <typeinfo>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/framework/var_type_traits.h"
namespace paddle {
namespace framework {
@ -27,10 +27,14 @@ class Variable {
public:
template <typename T>
const T& Get() const {
static_assert(
IsRegisteredVarType<T>(),
"Not registered type. Please register T inside var_type_traits.h");
PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing");
PADDLE_ENFORCE(IsType<T>(),
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
"Variable must be type %s, the holding type is %s",
typeid(T).name(), holder_->Type().name());
ToTypeName(VarTypeTrait<T>::kId),
ToTypeName(holder_->Type()));
return *static_cast<const T*>(holder_->Ptr());
}
@ -39,61 +43,61 @@ class Variable {
template <typename T>
T* GetMutable() {
if (!holder_) {
holder_.reset(new PlaceholderImpl<T>(new T()));
holder_.reset(new PlaceholderImpl<T>());
} else {
PADDLE_ENFORCE(IsType<T>(),
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
"Variable must be type %s, the holding type is %s",
typeid(T).name(), holder_->Type().name());
ToTypeName(VarTypeTrait<T>::kId),
ToTypeName(holder_->Type()));
}
return static_cast<T*>(holder_->Ptr());
}
template <typename T>
bool IsType() const {
return holder_ != nullptr &&
std::type_index(typeid(T)) == std::type_index(holder_->Type());
return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
}
void Clear() { holder_.reset(); }
std::type_index Type() const {
int Type() const {
PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
return holder_->Type();
}
private:
struct Placeholder {
virtual ~Placeholder() {}
virtual const std::type_info& Type() const = 0;
virtual void* Ptr() const = 0;
virtual ~Placeholder() = default;
inline int Type() const { return type_; }
inline const void* Ptr() const { return ptr_; }
inline void* Ptr() { return ptr_; }
protected:
inline void Init(void* p, int type) {
ptr_ = p;
type_ = type;
}
void* ptr_;
int type_;
};
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
template <typename T>
struct PlaceholderImpl : public Placeholder {
explicit PlaceholderImpl(T* ptr) : ptr_(ptr), type_(typeid(T)) {}
virtual const std::type_info& Type() const { return type_; }
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
static_assert(
IsRegisteredVarType<T>(),
"Not registered type. Please register T inside var_type_traits.h");
PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
std::unique_ptr<T> ptr_;
const std::type_info& type_;
private:
T obj_;
};
std::unique_ptr<Placeholder>
holder_; // pointers to a PlaceholderImpl object indeed.
// name_ is only meaningful with a Scope and accessible by it.
//
// NOTE: Please don't expose name_ by adding methods like
// Variable::Name or Scope::VarName! A variable could have a human
// readable name or an auto-generated scope-unique name. In the
// former case, the caller knows the name and doesn't need to access
// the name; in the latter case, the variable should be identified
// by its address but not the unreadable name.
friend class Scope;
const std::string* name_;
// pointers to a PlaceholderImpl object indeed.
std::unique_ptr<Placeholder> holder_;
};
} // namespace framework

@ -16,27 +16,28 @@
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
TEST(Variable, GetMutable) {
using paddle::framework::Variable;
struct Tensor {
int content_;
};
namespace paddle {
namespace framework {
TEST(Variable, GetMutable) {
std::unique_ptr<Variable> v(new Variable());
Tensor* t = v->GetMutable<Tensor>();
t->content_ = 1234;
auto* t = v->GetMutable<std::string>();
*t = "1234";
const Tensor& tt = v->Get<Tensor>();
EXPECT_EQ(1234, tt.content_);
const auto& tt = v->Get<std::string>();
EXPECT_EQ("1234", tt);
try {
v->GetMutable<std::string>();
v->GetMutable<Tensor>();
} catch (std::exception& e) {
return;
}
EXPECT_TRUE(false);
}
} // namespace framework
} // namespace paddle

@ -25,7 +25,7 @@ void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) {
// TODO(Superjomn) should avoid the case when a TensorArray is a
// parameter.
if (var_name == "feed" || var_name == "fetch") continue;
if (var->Type() == typeid(framework::LoDTensorArray)) {
if (var->IsType<framework::LoDTensorArray>()) {
VLOG(4) << "collect " << var_name;
arrays_.push_back(var->GetMutable<framework::LoDTensorArray>());
}

@ -27,8 +27,11 @@ namespace details {
// training phase.
struct TensorArrayBatchCleaner {
TensorArrayBatchCleaner() {
valid_types_.insert(typeid(framework::Tensor));
valid_types_.insert(typeid(framework::LoDTensor));
constexpr auto kTensorId = framework::VarTypeTrait<framework::Tensor>::kId;
constexpr auto kLoDTensorId =
framework::VarTypeTrait<framework::LoDTensor>::kId;
valid_types_.insert(kTensorId);
valid_types_.insert(kLoDTensorId);
}
// Collect the variables that are not Tensor or LoDTensor, and reset them to a
// bool(trick), because some of them are containers, and some operators just
@ -46,7 +49,7 @@ struct TensorArrayBatchCleaner {
bool no_tensor_flag_{true};
std::vector<framework::LoDTensorArray *> arrays_;
std::unordered_set<std::type_index> valid_types_;
std::unordered_set<int> valid_types_;
std::unordered_set<framework::Variable *> no_tensor_vars_;
};

@ -64,7 +64,7 @@ class ClipByNormKernel : public framework::OpKernel<T> {
output->mutable_data<T>(context.GetPlace());
} else {
PADDLE_THROW("Unexpected branch, input variable type is %s",
in_var->Type().name());
framework::ToTypeName(in_var->Type()));
}
PADDLE_ENFORCE_NOT_NULL(input);

@ -175,14 +175,13 @@ class WhileGradOp : public framework::OperatorBase {
auto &og_inside =
detail::Ref(cur_scope.Var(inside_og_name),
"Cannot find inside gradient %s", inside_og_name);
if (framework::IsType<framework::LoDTensor>(og_outside.Type())) {
if (og_outside.IsType<framework::LoDTensor>()) {
auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
auto &inside_tensor =
detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
inside_tensor.set_lod(outside_tensor.lod());
inside_tensor.ShareDataWith(outside_tensor);
} else if (framework::IsType<framework::LoDTensorArray>(
og_outside.Type())) {
} else if (og_outside.IsType<framework::LoDTensorArray>()) {
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
@ -256,7 +255,7 @@ class WhileGradOp : public framework::OperatorBase {
var->IsType<LoDTensor>(),
"Currently the type of var only can be LoDTensorArray, "
"or LoDTensor, but the received var[%s] is %s.",
inside_grad_name, var->Type().name());
inside_grad_name, framework::ToTypeName(var->Type()));
if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -116,7 +116,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
} else {
PADDLE_THROW(
"% should be LoDTensor or SelectedRows, but the received type is %s",
ctx.Inputs("Ids")[0], ids_var->Type().name());
ctx.Inputs("Ids")[0], framework::ToTypeName(ids_var->Type()));
}
}
};

@ -83,7 +83,7 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
z = ctx.Output<framework::LoDTensor>("Out");
} else {
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
x_var->Type().name());
framework::ToTypeName(x_var->Type()));
}
z->mutable_data<T>(ctx.GetPlace());

@ -27,12 +27,14 @@ class AdadeltaOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
ctx.Inputs("Param").front(),
framework::ToTypeName(param_var->Type()));
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
ctx.Inputs("Grad").front(),
framework::ToTypeName(grad_var->Type()));
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor =

@ -50,7 +50,8 @@ class AdagradOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
ctx.Inputs("Param").front(),
framework::ToTypeName(param_var->Type()));
auto *param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto *moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");

@ -347,7 +347,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
ctx.Inputs("Param").front(),
framework::ToTypeName(param_var->Type()));
using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save