|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/var_type_traits.h"
|
|
|
|
|
#include "paddle/fluid/platform/macros.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -23,54 +24,83 @@ namespace detail {
|
|
|
|
|
|
|
|
|
|
template <int kStart, int kEnd, bool kStop>
|
|
|
|
|
struct VarIdToTypeIndexMapInitializerImpl {
|
|
|
|
|
static void Init(std::unordered_map<int, std::type_index> *m) {
|
|
|
|
|
template <typename MapType1, typename MapType2>
|
|
|
|
|
static void Init(MapType1 *id_to_type, MapType2 *type_to_id) {
|
|
|
|
|
using Type =
|
|
|
|
|
typename std::tuple_element<kStart, VarTypeRegistry::ArgTuple>::type;
|
|
|
|
|
static_assert(!std::is_same<Type, void>::value, "Type cannot be void");
|
|
|
|
|
constexpr int kId = VarTypeTrait<Type>::kId;
|
|
|
|
|
if (!std::is_same<Type, void>::value) {
|
|
|
|
|
m->emplace(kId, std::type_index(typeid(Type)));
|
|
|
|
|
}
|
|
|
|
|
auto type = std::type_index(typeid(Type));
|
|
|
|
|
PADDLE_ENFORCE(id_to_type->count(kId) == 0,
|
|
|
|
|
"Registered duplicate type id %d for type %s", kId,
|
|
|
|
|
type.name());
|
|
|
|
|
PADDLE_ENFORCE(type_to_id->count(type) == 0,
|
|
|
|
|
"Registered duplicate type_index %s for id %d", type.name(),
|
|
|
|
|
kId);
|
|
|
|
|
id_to_type->emplace(kId, type);
|
|
|
|
|
type_to_id->emplace(type, kId);
|
|
|
|
|
VarIdToTypeIndexMapInitializerImpl<kStart + 1, kEnd,
|
|
|
|
|
kStart + 1 == kEnd>::Init(m);
|
|
|
|
|
kStart + 1 == kEnd>::Init(id_to_type,
|
|
|
|
|
type_to_id);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <int kStart, int kEnd>
|
|
|
|
|
struct VarIdToTypeIndexMapInitializerImpl<kStart, kEnd, true> {
|
|
|
|
|
static void Init(std::unordered_map<int, std::type_index> *m) {}
|
|
|
|
|
template <typename MapType1, typename MapType2>
|
|
|
|
|
static void Init(MapType1 *, MapType2 *) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// VarIdToTypeIndexMapInitializer is designed to initialize var_id ->
|
|
|
|
|
// std::type_index map
|
|
|
|
|
// std::type_index map and std::type_index -> var_id map
|
|
|
|
|
using VarIdToTypeIndexMapInitializer =
|
|
|
|
|
VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum,
|
|
|
|
|
VarTypeRegistry::kRegisteredTypeNum ==
|
|
|
|
|
0>;
|
|
|
|
|
|
|
|
|
|
struct VarIdToTypeIndexMapHolder {
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(VarIdToTypeIndexMapHolder);
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
static const std::type_index &ToTypeIndex(int var_id) {
|
|
|
|
|
static const VarIdToTypeIndexMapHolder instance;
|
|
|
|
|
auto it = instance.var_type_map_.find(var_id);
|
|
|
|
|
PADDLE_ENFORCE(it != instance.var_type_map_.end(),
|
|
|
|
|
auto it = Instance().id_to_type_map_.find(var_id);
|
|
|
|
|
PADDLE_ENFORCE(it != Instance().id_to_type_map_.end(),
|
|
|
|
|
"VarId %d is not registered.", var_id);
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static int ToTypeId(const std::type_index &type) {
|
|
|
|
|
auto it = Instance().type_to_id_map_.find(type);
|
|
|
|
|
PADDLE_ENFORCE(it != Instance().type_to_id_map_.end(),
|
|
|
|
|
"VarType %s is not registered.", type.name());
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
VarIdToTypeIndexMapHolder() {
|
|
|
|
|
VarIdToTypeIndexMapInitializer::Init(&var_type_map_);
|
|
|
|
|
VarIdToTypeIndexMapInitializer::Init(&id_to_type_map_, &type_to_id_map_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static const VarIdToTypeIndexMapHolder &Instance() {
|
|
|
|
|
static const VarIdToTypeIndexMapHolder instance;
|
|
|
|
|
return instance;
|
|
|
|
|
}
|
|
|
|
|
std::unordered_map<int, std::type_index> var_type_map_;
|
|
|
|
|
|
|
|
|
|
std::unordered_map<int, std::type_index> id_to_type_map_;
|
|
|
|
|
std::unordered_map<std::type_index, int> type_to_id_map_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace detail
|
|
|
|
|
|
|
|
|
|
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
|
|
|
|
|
|
|
|
|
|
const std::type_index &ToTypeIndex(int var_id) {
|
|
|
|
|
return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
|
|
|
|
|
|
|
|
|
|
int ToTypeId(const std::type_index &type) {
|
|
|
|
|
return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|