Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into imperative_mnist
test=developrevert-15207-remove_op_handle_lock_and_fix_var
commit
9e3155e01d
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,222 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h"
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/operators/math/cpu_vec.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Binds every node matched by `pattern_name` to a local variable of the same
// name (conv, affine_channel, conv_weight, conv_out, ac_scale, ac_bias,
// ac_out) by expanding GET_IR_NODE_FROM_SUBGRAPH once per node.
// NOTE(review): the macro name mentions BN, yet the nodes it fetches belong to
// the conv + affine_channel pattern; the name looks inherited from the conv+BN
// pass.
#define GET_CONV_BN_NODES(pattern_name)                                    \
  /* operator nodes */                                                     \
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name);                     \
  GET_IR_NODE_FROM_SUBGRAPH(affine_channel, affine_channel, pattern_name); \
  /* filter consumed by conv */                                            \
  GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name);       \
  /* value produced by conv */                                             \
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name);             \
  /* scale and bias consumed by affine_channel */                          \
  GET_IR_NODE_FROM_SUBGRAPH(ac_scale, ac_scale, pattern_name);             \
  GET_IR_NODE_FROM_SUBGRAPH(ac_bias, ac_bias, pattern_name);               \
  /* value produced by affine_channel */                                   \
  GET_IR_NODE_FROM_SUBGRAPH(ac_out, ac_out, pattern_name); /* Out */
|
||||
|
||||
void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
|
||||
const ir::Node& ac_scale,
|
||||
const LoDTensor& ac_bias_tensor,
|
||||
LoDTensor* eltwise_y_in_tensor) {
|
||||
using EigenVectorArrayMap =
|
||||
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
|
||||
using ConstEigenVectorArrayMap =
|
||||
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
|
||||
using EigenMatrixArrayMap = Eigen::Map<
|
||||
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
|
||||
|
||||
// Re-compute bias of conv2d from AffineChannel
|
||||
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), ac_bias_tensor.dims());
|
||||
|
||||
auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable<LoDTensor>();
|
||||
|
||||
ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
|
||||
scale_tensor->numel(), 1);
|
||||
ConstEigenVectorArrayMap ac_bias_array(ac_bias_tensor.data<float>(),
|
||||
ac_bias_tensor.numel(), 1);
|
||||
|
||||
EigenVectorArrayMap eltwise_y_in_array(
|
||||
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
|
||||
eltwise_y_in_tensor->numel(), 1);
|
||||
|
||||
eltwise_y_in_array = (eltwise_y_in_array * scale_array) + ac_bias_array;
|
||||
|
||||
// Re-compute weight of conv2d from AffineChannel
|
||||
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
|
||||
auto weights_shape = weights->dims();
|
||||
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
|
||||
|
||||
EigenMatrixArrayMap weights_array_2d(
|
||||
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
|
||||
weights_shape_2d[1]);
|
||||
|
||||
weights_array_2d.colwise() *= scale_array;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
PADDLE_ENFORCE(graph.get());
|
||||
FusePassBase::Init(name_scope_, graph.get());
|
||||
|
||||
auto* scope = param_scope();
|
||||
PADDLE_ENFORCE(scope);
|
||||
|
||||
GraphPatternDetector gpd;
|
||||
auto* conv_input =
|
||||
gpd.mutable_pattern()
|
||||
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
|
||||
->AsInput()
|
||||
->assert_is_op_input("conv2d", "Input");
|
||||
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
|
||||
name_scope_);
|
||||
conv_ac_pattern(conv_input, false /*with_eltwise_add*/);
|
||||
|
||||
int found_conv_ac_count = 0;
|
||||
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
||||
Graph* g) {
|
||||
VLOG(4) << "handle ConvAffineChannel fuse";
|
||||
|
||||
GET_CONV_BN_NODES(conv_ac_pattern);
|
||||
|
||||
// check if fuse can be done and if MKL-DNN should be used
|
||||
FuseOptions fuse_option = FindFuseOption(*conv, *affine_channel);
|
||||
if (fuse_option == DO_NOT_FUSE) {
|
||||
VLOG(3) << "do not perform conv+affinechannel fuse";
|
||||
return;
|
||||
}
|
||||
|
||||
// Create eltwise_y (conv bias) variable
|
||||
VarDesc eltwise_y_in_desc(
|
||||
patterns::PDNodeName(name_scope_, "eltwise_y_in"));
|
||||
eltwise_y_in_desc.SetPersistable(true);
|
||||
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
|
||||
auto* eltwise_y_in_tensor =
|
||||
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
|
||||
|
||||
// Get affine_channel bias
|
||||
auto* ac_bias_tensor =
|
||||
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
|
||||
|
||||
// Initialize eltwise_y
|
||||
eltwise_y_in_tensor->Resize(ac_bias_tensor->dims());
|
||||
std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
|
||||
eltwise_y_in_tensor->numel(), 0.0f);
|
||||
|
||||
// update weights and biases
|
||||
recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor,
|
||||
eltwise_y_in_tensor);
|
||||
|
||||
// create an elementwise add node.
|
||||
OpDesc desc;
|
||||
desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
|
||||
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
|
||||
desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
|
||||
desc.SetType("elementwise_add");
|
||||
desc.SetAttr("axis", 1);
|
||||
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
|
||||
|
||||
GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel});
|
||||
|
||||
IR_NODE_LINK_TO(conv_out, eltwise_op);
|
||||
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
|
||||
IR_NODE_LINK_TO(eltwise_op, ac_out);
|
||||
found_conv_ac_count++;
|
||||
};
|
||||
|
||||
gpd(graph.get(), handler);
|
||||
|
||||
AddStatis(found_conv_ac_count);
|
||||
return graph;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
PADDLE_ENFORCE(graph.get());
|
||||
FusePassBase::Init(name_scope_, graph.get());
|
||||
|
||||
auto* scope = param_scope();
|
||||
PADDLE_ENFORCE(scope);
|
||||
|
||||
GraphPatternDetector gpd;
|
||||
auto* conv_input =
|
||||
gpd.mutable_pattern()
|
||||
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
|
||||
->AsInput()
|
||||
->assert_is_op_input("conv2d", "Input");
|
||||
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
|
||||
name_scope_);
|
||||
conv_ac_pattern(conv_input, true /*with_eltwise_add*/);
|
||||
|
||||
int found_conv_ac_count = 0;
|
||||
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
||||
Graph* g) {
|
||||
VLOG(4) << "handle ConvBN fuse";
|
||||
|
||||
GET_CONV_BN_NODES(conv_ac_pattern);
|
||||
// OPERATORS
|
||||
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern);
|
||||
// BIAS inputs
|
||||
GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_ac_pattern);
|
||||
// BIAS outputs
|
||||
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_ac_pattern);
|
||||
|
||||
// Get eltwise_y (conv bias) variable
|
||||
auto* eltwise_y_in_tensor =
|
||||
scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();
|
||||
|
||||
// Get batch norm bias
|
||||
auto* ac_bias_tensor =
|
||||
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
|
||||
|
||||
recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor,
|
||||
eltwise_y_in_tensor);
|
||||
|
||||
// Update the elementwise_add node
|
||||
eltwise->Op()->SetAttr("axis", 1);
|
||||
eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
|
||||
|
||||
GraphSafeRemoveNodes(graph.get(),
|
||||
{ac_scale, ac_bias, affine_channel, eltwise_out});
|
||||
|
||||
IR_NODE_LINK_TO(eltwise, ac_out);
|
||||
|
||||
found_conv_ac_count++;
|
||||
};
|
||||
|
||||
gpd(graph.get(), handler);
|
||||
AddStatis(found_conv_ac_count);
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Register both fuse passes under the names given as the first argument.
REGISTER_PASS(conv_affine_channel_fuse_pass,
              paddle::framework::ir::ConvAffineChannelFusePass);
REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
              paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass);
|
@ -0,0 +1,49 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
/*
|
||||
 * Fuse a conv2d op with the affine_channel op that consumes its output.
|
||||
*/
|
||||
class ConvAffineChannelFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~ConvAffineChannelFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
const std::string name_scope_{"conv_affine_channel_fuse"};
|
||||
};
|
||||
|
||||
class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~ConvEltwiseAddAffineChannelFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,119 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/var_type_traits.h"
|
||||
#include "paddle/fluid/framework/lod_rank_table.h"
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
|
||||
#include "paddle/fluid/platform/macros.h"
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#ifndef _WIN32
|
||||
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
|
||||
#endif
|
||||
#include <cudnn.h>
|
||||
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
|
||||
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Besides registering variable type id, it is helpful to register a
|
||||
// var_id -> std::type_index map (for example, get type names according to id)
|
||||
namespace detail {
|
||||
|
||||
template <int kStart, int kEnd, bool kStop>
|
||||
struct VarIdToTypeIndexMapInitializerImpl {
|
||||
template <typename MapType1, typename MapType2>
|
||||
static void Init(MapType1 *id_to_type, MapType2 *type_to_id) {
|
||||
using Type =
|
||||
typename std::tuple_element<kStart, VarTypeRegistry::ArgTuple>::type;
|
||||
static_assert(!std::is_same<Type, void>::value, "Type cannot be void");
|
||||
constexpr int kId = VarTypeTrait<Type>::kId;
|
||||
auto type = std::type_index(typeid(Type));
|
||||
PADDLE_ENFORCE(id_to_type->count(kId) == 0,
|
||||
"Registered duplicate type id %d for type %s", kId,
|
||||
type.name());
|
||||
PADDLE_ENFORCE(type_to_id->count(type) == 0,
|
||||
"Registered duplicate type_index %s for id %d", type.name(),
|
||||
kId);
|
||||
id_to_type->emplace(kId, type);
|
||||
type_to_id->emplace(type, kId);
|
||||
VarIdToTypeIndexMapInitializerImpl<kStart + 1, kEnd,
|
||||
kStart + 1 == kEnd>::Init(id_to_type,
|
||||
type_to_id);
|
||||
}
|
||||
};
|
||||
|
||||
template <int kStart, int kEnd>
|
||||
struct VarIdToTypeIndexMapInitializerImpl<kStart, kEnd, true> {
|
||||
template <typename MapType1, typename MapType2>
|
||||
static void Init(MapType1 *, MapType2 *) {}
|
||||
};
|
||||
|
||||
// VarIdToTypeIndexMapInitializer is designed to initialize var_id ->
|
||||
// std::type_index map and std::type_index -> var_id map
|
||||
using VarIdToTypeIndexMapInitializer =
|
||||
VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum,
|
||||
VarTypeRegistry::kRegisteredTypeNum ==
|
||||
0>;
|
||||
|
||||
struct VarIdToTypeIndexMapHolder {
|
||||
DISABLE_COPY_AND_ASSIGN(VarIdToTypeIndexMapHolder);
|
||||
|
||||
public:
|
||||
static const std::type_index &ToTypeIndex(int var_id) {
|
||||
auto it = Instance().id_to_type_map_.find(var_id);
|
||||
PADDLE_ENFORCE(it != Instance().id_to_type_map_.end(),
|
||||
"VarId %d is not registered.", var_id);
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static int ToTypeId(const std::type_index &type) {
|
||||
auto it = Instance().type_to_id_map_.find(type);
|
||||
PADDLE_ENFORCE(it != Instance().type_to_id_map_.end(),
|
||||
"VarType %s is not registered.", type.name());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
private:
|
||||
VarIdToTypeIndexMapHolder() {
|
||||
VarIdToTypeIndexMapInitializer::Init(&id_to_type_map_, &type_to_id_map_);
|
||||
}
|
||||
|
||||
static const VarIdToTypeIndexMapHolder &Instance() {
|
||||
static const VarIdToTypeIndexMapHolder instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
std::unordered_map<int, std::type_index> id_to_type_map_;
|
||||
std::unordered_map<std::type_index, int> type_to_id_map_;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
const std::type_index &ToTypeIndex(int var_id) {
|
||||
return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
|
||||
}
|
||||
|
||||
// Returns std::type_index::name() of the type registered under var_id.
const char *ToTypeName(int var_id) {
  const std::type_index &index = ToTypeIndex(var_id);
  return index.name();
}
|
||||
|
||||
int ToTypeId(const std::type_index &type) {
|
||||
return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,195 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <typeindex>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/framework.pb.h"
|
||||
#include "paddle/fluid/framework/lod_tensor_array.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include <cudnn.h>
|
||||
#ifndef _WIN32
|
||||
#include <nccl.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Users should add forward declarations here
|
||||
namespace paddle {
|
||||
|
||||
namespace platform {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#ifndef _WIN32
|
||||
class Communicator;
|
||||
#endif
|
||||
#endif
|
||||
} // namespace platform
|
||||
|
||||
namespace framework {
|
||||
class Tensor;
|
||||
class LoDTensor;
|
||||
class SelectedRows;
|
||||
class LoDRankTable;
|
||||
class ReaderHolder;
|
||||
class Scope;
|
||||
} // namespace framework
|
||||
|
||||
namespace operators {
|
||||
template <typename T>
|
||||
class AlgorithmsCache;
|
||||
|
||||
class CudnnRNNCache;
|
||||
|
||||
namespace reader {
|
||||
class LoDTensorBlockingQueueHolder;
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
|
||||
} // namespace paddle
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
const char *ToTypeName(int var_id);
|
||||
const std::type_index &ToTypeIndex(int var_id);
|
||||
int ToTypeId(const std::type_index &type);
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Recursive step of the linear search for T1 inside the pack {T2, Args...}.
//   kStop  - true when T2 is the last candidate left (selects the
//            specialization below).
//   kStart - zero-based position of T2 in the original list.
//   kEnd   - length of the original list.
// kPos is the position of the first candidate equal to T1, or -1 if there is
// none.
template <bool kStop, int kStart, int kEnd, typename T1, typename T2,
          typename... Args>
struct TypePosFinderImpl {
  static constexpr int kPos =
      std::is_same<T1, T2>::value
          ? kStart
          : TypePosFinderImpl<kStart + 2 == kEnd, kStart + 1, kEnd, T1,
                              Args...>::kPos;
};

// Last candidate: either it matches T1 or T1 is not in the list at all.
template <int kStart, int kEnd, typename T1, typename T2>
struct TypePosFinderImpl<true, kStart, kEnd, T1, T2> {
  static constexpr int kPos = std::is_same<T1, T2>::value ? kStart : -1;
};

// TypePosFinder helps to find the position in which T is inside Args...
// If T is not inside Args..., kPos would be -1
template <typename T, typename... Args>
struct TypePosFinder {
  static constexpr int kPos =
      TypePosFinderImpl<sizeof...(Args) == 1, 0, sizeof...(Args), T,
                        Args...>::kPos;
};

// Empty list: nothing can match, so report -1 instead of failing to compile
// (the generic definition needs at least one candidate type).
template <typename T>
struct TypePosFinder<T> {
  static constexpr int kPos = -1;
};
|
||||
|
||||
template <typename... Args>
|
||||
struct VarTypeRegistryImpl {
|
||||
static constexpr size_t kRegisteredTypeNum = sizeof...(Args);
|
||||
using ArgTuple = std::tuple<Args...>;
|
||||
|
||||
// TypePos() returns the position in which T is inside Args...
|
||||
// If T is not inside Args..., return -1
|
||||
template <typename T>
|
||||
static constexpr int TypePos() {
|
||||
return TypePosFinder<T, Args...>::kPos;
|
||||
}
|
||||
|
||||
// IsRegistered() returns whether T is registered inside RegistryImpl
|
||||
template <typename T>
|
||||
static constexpr bool IsRegistered() {
|
||||
return TypePos<T>() >= 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Specializes VarTypeTrait for `type` so that its kId is the given protobuf
// enum value instead of an auto-generated one. `type` must already be listed
// in VarTypeRegistry. The caller supplies the terminating semicolon.
#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id)           \
  template <>                                              \
  struct VarTypeTrait<type> {                              \
    using Type = type;                                     \
    static constexpr int kId = static_cast<int>(proto_id); \
    static_assert(VarTypeRegistry::IsRegistered<type>(),   \
                  "Must be registered type");              \
  }
|
||||
|
||||
/**
|
||||
* The following codes are designed to register variable types.
|
||||
* Only registered types can be stored in Variable.
|
||||
* This registry mechanism is designed to speed up Variable.
|
||||
*
|
||||
* Caution: If you want to add more var types, please consider carefully
|
||||
* whether you really need to add it.
|
||||
*/
|
||||
|
||||
// Users should add other variable types below.
|
||||
// Paddle would generate unique Ids for each registered variable types.
|
||||
// The full list of types a Variable may hold.
// The position of a type in this list feeds VarTypeTrait<T>::kId for every
// type without an explicit protobuf id, so inserting or reordering entries
// changes those ids. The CUDA-only entries are present only when
// PADDLE_WITH_CUDA is defined, and the NCCL ones additionally require a
// non-Windows build.
using VarTypeRegistry = detail::VarTypeRegistryImpl<
    Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
    LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
    std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder,
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
    ncclUniqueId, platform::Communicator,
#endif
    operators::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>,
    operators::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>,
    operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
    operators::CudnnRNNCache,
#endif
    int, float>;
|
||||
|
||||
template <typename T>
|
||||
struct VarTypeTrait {
|
||||
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
|
||||
using Type = T;
|
||||
/**
|
||||
* Unique VarType Id generation.
|
||||
*
|
||||
* The auto-generated id should not be the same as any protobuf id defined in
|
||||
* framework.proto. Therefore, we generate id by adding the type pos and
|
||||
* maximum protobuf id (i.e., proto::VarType::TUPLE).
|
||||
*
|
||||
* However, we may need more protobuf id in the future.
|
||||
* To avoid changing this auto id generation algorithm frequently, we
|
||||
* generate id by adding the type pos and twice of maximum protobuf id (i.e.,
|
||||
* proto::VarType::TUPLE).
|
||||
*/
|
||||
static constexpr int kId = VarTypeRegistry::TypePos<T>() +
|
||||
static_cast<int>(proto::VarType::TUPLE) * 2;
|
||||
};
|
||||
|
||||
// Users should set some of variable type ids to be what is defined in
|
||||
// framework.proto below
|
||||
REG_PROTO_VAR_TYPE_TRAIT(LoDTensor, proto::VarType::LOD_TENSOR);
|
||||
REG_PROTO_VAR_TYPE_TRAIT(SelectedRows, proto::VarType::SELECTED_ROWS);
|
||||
REG_PROTO_VAR_TYPE_TRAIT(std::vector<Scope *>, proto::VarType::STEP_SCOPES);
|
||||
REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
|
||||
REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
|
||||
REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST);
|
||||
REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER);
|
||||
REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32);
|
||||
REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32);
|
||||
|
||||
/** End of variable type registration */
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool IsRegisteredVarType() {
|
||||
return VarTypeRegistry::IsRegistered<T>();
|
||||
}
|
||||
|
||||
#undef REG_PROTO_VAR_TYPE_TRAIT
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue