rewrite variable type

test=develop
revert-15207-remove_op_handle_lock_and_fix_var
sneaxiy 6 years ago
parent 16c244bc3f
commit ae6f46a1a9

@ -78,17 +78,25 @@ cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memor
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_library(var_type_traits SRCS var_type_traits DEPS lod_tensor selected_rows framework_proto)
if (WITH_GPU)
target_link_libraries(var_type_traits cudnn)
if (NOT WIN32)
target_link_libraries(var_type_traits nccl)
endif()
endif()
cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)
cc_library(scope SRCS scope.cc DEPS glog threadpool var_type_traits)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_test(variable_test SRCS variable_test.cc DEPS tensor var_type_traits)
cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
DEPS operator op_registry device_context math_function)
DEPS operator op_registry device_context math_function scope)
if(WITH_GPU)
if (WIN32)

@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"

@ -88,7 +88,7 @@ void EagerDeletionOpHandle::RunImpl() {
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
var->Type().name(), name);
framework::ToTypeName(var->Type()), name);
}
}

@ -24,7 +24,7 @@ static void VisitVariable(Variable* var, Func* func) {
} else if (var->IsType<SelectedRows>()) {
(*func)(var->GetMutable<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var->Type().name());
PADDLE_THROW("Not supported type %s", ToTypeName(var->Type()));
}
}
@ -35,7 +35,7 @@ static void VisitVariable(const Variable& var, Func* func) {
} else if (var.IsType<SelectedRows>()) {
(*func)(var.Get<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var.Type().name());
PADDLE_THROW("Not supported type %s", ToTypeName(var.Type()));
}
}

@ -119,7 +119,7 @@ static void DeleteUnusedTensors(
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
var->Type().name(), name);
framework::ToTypeName(var->Type()), name);
}
}
}

@ -365,7 +365,7 @@ const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
return &(var.Get<SelectedRows>().value());
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var.Type().name());
ToTypeName(var.Type()));
}
}
@ -376,7 +376,7 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
return var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
ToTypeName(var->Type()));
}
}
@ -430,7 +430,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"%s should be LoDTensor, but the received type is %s",
sub_name, var->Type().name());
sub_name, ToTypeName(var->Type()));
return &(var->Get<LoDTensor>());
});
return res;
@ -454,7 +454,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"%s should be LoDTensor, but the received type is %s",
sub_name, var->Type().name());
sub_name, ToTypeName(var->Type()));
return var->GetMutable<LoDTensor>();
});
return res;
@ -641,7 +641,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
PADDLE_THROW(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
name, ToTypeName(var->Type()));
}
}
@ -657,7 +657,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
name, var->Type().name());
name, ToTypeName(var->Type()));
}
}

@ -288,6 +288,18 @@ class ExecutionContext {
const platform::DeviceContext& device_context_;
};
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (use_cudnn) {
auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
return use_cudnn;
}
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

@ -165,11 +165,9 @@ std::string Scope::Rename(const std::string& origin_name) const {
Variable* Scope::VarInternal(const std::string& name) {
auto* v = FindVarLocally(name);
if (v != nullptr) return v;
v = new Variable();
vars_[name].reset(v);
vars_.emplace(name, std::unique_ptr<Variable>(v));
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
}

@ -19,35 +19,33 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace framework {
template <typename T>
inline bool IsType(const std::type_index& type_index) {
return type_index == std::type_index(typeid(T));
inline bool IsType(const std::type_index& type) {
return type == typeid(T);
}
inline proto::VarType::Type ToVarType(std::type_index type) {
if (IsType<LoDTensor>(type)) {
return proto::VarType_Type_LOD_TENSOR;
} else if (IsType<LoDRankTable>(type)) {
return proto::VarType_Type_LOD_RANK_TABLE;
} else if (IsType<LoDTensorArray>(type)) {
return proto::VarType_Type_LOD_TENSOR_ARRAY;
} else if (IsType<SelectedRows>(type)) {
return proto::VarType_Type_SELECTED_ROWS;
} else if (IsType<ReaderHolder>(type)) {
return proto::VarType_Type_READER;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
inline proto::VarType::Type ToVarType(int type) {
switch (type) {
case proto::VarType::LOD_TENSOR:
case proto::VarType::SELECTED_ROWS:
case proto::VarType::LOD_RANK_TABLE:
case proto::VarType::LOD_TENSOR_ARRAY:
case proto::VarType::READER:
return static_cast<proto::VarType::Type>(type);
default:
PADDLE_THROW("ToVarType:Unsupported type %d", type);
}
}
template <typename Visitor>
inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
switch (var.Type()) {
case proto::VarType_Type_LOD_TENSOR:
visitor(var.Get<LoDTensor>());
return;
@ -64,7 +62,7 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
visitor(var.Get<ReaderHolder>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
PADDLE_THROW("Not supported visit type, %s", ToTypeName(var.Type()));
}
}

@ -0,0 +1,27 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h"
namespace paddle {
namespace framework {
const char* ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
const std::type_index& ToTypeIndex(int var_id) {
return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
}
} // namespace framework
} // namespace paddle

@ -0,0 +1,207 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <tuple>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#include <nccl.h>
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#endif
#include <cudnn.h>
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#endif
namespace paddle {
namespace framework {
namespace detail {
template <bool kStop, int kStart, int kEnd, typename T1, typename T2,
typename... Args>
struct TypePosFinderImpl {
static constexpr int kPos =
std::is_same<T1, T2>::value
? kStart
: TypePosFinderImpl<kStart + 2 == kEnd, kStart + 1, kEnd, T1,
Args...>::kPos;
};
template <int kStart, int kEnd, typename T1, typename T2>
struct TypePosFinderImpl<true, kStart, kEnd, T1, T2> {
static constexpr int kPos = std::is_same<T1, T2>::value ? kStart : -1;
};
// TypePosFinder helps to find the position in which T is inside Args...
// If T is not inside Args..., kPos would be -1
template <typename T, typename... Args>
struct TypePosFinder {
static constexpr int kPos =
TypePosFinderImpl<sizeof...(Args) == 1, 0, sizeof...(Args), T,
Args...>::kPos;
};
template <typename... Args>
struct VarTypeRegistryImpl {
static constexpr size_t kRegisteredTypeNum = sizeof...(Args);
using ArgTuple = std::tuple<Args...>;
// TypePos() returns the position in which T is inside Args...
// If T is not inside Args... or T is void, return -1
template <typename T>
static constexpr int TypePos() {
return std::is_same<T, void>::value ? -1 : TypePosFinder<T, Args...>::kPos;
}
// IsRegistered() returns whether T is registered inside RegistryImpl
template <typename T>
static constexpr bool IsRegistered() {
return TypePos<T>() >= 0;
}
};
} // namespace detail
#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \
template <> \
struct VarTypeTrait<type> { \
static_assert(VarTypeRegistry::IsRegistered<type>(), \
"Must be registered type"); \
using Type = type; \
static constexpr int kId = proto_id; \
}
/**
* The following codes are designed to register variable types.
* Only registered types can be stored in Variable.
* This registry mechanism is designed to speed up Variable.
*/
// Users should add other variable types below.
// Paddle would generate unique Ids for each registered variable types.
class Scope;
using VarTypeRegistry = detail::VarTypeRegistryImpl<
LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable, LoDTensorArray,
platform::PlaceList, ReaderHolder, Tensor, std::string, Scope *,
std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder,
int, float,
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
ncclUniqueId, platform::Communicator,
#endif
operators::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
operators::CudnnRNNCache,
#endif
void>; // void indicates end of registration, add other types before void
template <typename T>
struct VarTypeTrait {
static_assert(std::is_same<T, void>::value ||
VarTypeRegistry::IsRegistered<T>(),
"Must be registered type");
using Type = T;
// Default id generation
static constexpr int kId = VarTypeRegistry::TypePos<T>() +
static_cast<int>(proto::VarType::TUPLE) * 2;
};
// Users should set some of variable type ids to be what is defined in
// framework.proto here
REG_PROTO_VAR_TYPE_TRAIT(LoDTensor, proto::VarType::LOD_TENSOR);
REG_PROTO_VAR_TYPE_TRAIT(SelectedRows, proto::VarType::SELECTED_ROWS);
REG_PROTO_VAR_TYPE_TRAIT(std::vector<Scope *>, proto::VarType::STEP_SCOPES);
REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST);
REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER);
/** End of variable type registration */
// Besides register variable id, it is helpful to register a
// var_id -> std::type_index (for example, get var names according to id)
namespace detail {
template <int kStart, int kEnd, bool kStop>
struct VarIdToTypeIndexMapInitializerImpl {
static void Init(std::unordered_map<int, std::type_index> *m) {
using Type =
typename std::tuple_element<kStart, VarTypeRegistry::ArgTuple>::type;
constexpr int kId = VarTypeTrait<Type>::kId;
if (!std::is_same<Type, void>::value) {
m->emplace(kId, std::type_index(typeid(Type)));
}
VarIdToTypeIndexMapInitializerImpl<kStart + 1, kEnd,
kStart + 1 == kEnd>::Init(m);
}
};
template <int kStart, int kEnd>
struct VarIdToTypeIndexMapInitializerImpl<kStart, kEnd, true> {
static void Init(std::unordered_map<int, std::type_index> *m) {}
};
// VarIdToTypeIndexMapInitializer is designed to initialize var_id ->
// std::type_index map
using VarIdToTypeIndexMapInitializer =
VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum,
VarTypeRegistry::kRegisteredTypeNum ==
0>;
struct VarIdToTypeIndexMapHolder {
public:
static const std::type_index &ToTypeIndex(int var_id) {
static const VarIdToTypeIndexMapHolder instance;
auto it = instance.var_type_map_.find(var_id);
PADDLE_ENFORCE(it != instance.var_type_map_.end(),
"VarId %d is not registered.", var_id);
return it->second;
}
private:
VarIdToTypeIndexMapHolder() {
VarIdToTypeIndexMapInitializer::Init(&var_type_map_);
}
std::unordered_map<int, std::type_index> var_type_map_;
};
} // namespace detail
const char *ToTypeName(int var_id);
const std::type_index &ToTypeIndex(int var_id);
template <typename T>
inline constexpr bool IsRegisteredVarType() {
return VarTypeRegistry::IsRegistered<T>();
}
#undef REG_PROTO_VAR_TYPE_TRAIT
} // namespace framework
} // namespace paddle

@ -0,0 +1,75 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h"
#include <gtest/gtest.h>
#include <cstdint>
namespace paddle {
namespace framework {
template <int kPos, int kEnd, bool kStop>
struct TypeIndexChecker {
static void Check() {
using Type =
typename std::tuple_element<kPos, VarTypeRegistry::ArgTuple>::type;
if (!std::is_same<Type, void>::value) {
EXPECT_TRUE(ToTypeIndex(VarTypeTrait<Type>::kId) == typeid(Type));
EXPECT_TRUE(std::string(ToTypeName(VarTypeTrait<Type>::kId)) ==
typeid(Type).name());
}
TypeIndexChecker<kPos + 1, kEnd, kPos + 1 == kEnd>::Check();
}
};
template <int kPos, int kEnd>
struct TypeIndexChecker<kPos, kEnd, true> {
static void Check() {}
};
TEST(var_type_traits, check_type_index) {
constexpr size_t kRegisteredNum = VarTypeRegistry::kRegisteredTypeNum;
TypeIndexChecker<0, kRegisteredNum, kRegisteredNum == 0>::Check();
}
template <typename T>
bool CheckVarId(int proto_id) {
static_assert(std::is_same<typename VarTypeTrait<T>::Type, T>::value,
"Type must be the same");
return VarTypeTrait<T>::kId == proto_id;
}
TEST(var_type_traits, check_proto_type_id) {
ASSERT_TRUE(CheckVarId<LoDTensor>(proto::VarType::LOD_TENSOR));
ASSERT_TRUE(CheckVarId<SelectedRows>(proto::VarType::SELECTED_ROWS));
ASSERT_TRUE(CheckVarId<std::vector<Scope *>>(proto::VarType::STEP_SCOPES));
ASSERT_TRUE(CheckVarId<LoDRankTable>(proto::VarType::LOD_RANK_TABLE));
ASSERT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY));
ASSERT_TRUE(CheckVarId<platform::PlaceList>(proto::VarType::PLACE_LIST));
ASSERT_TRUE(CheckVarId<ReaderHolder>(proto::VarType::READER));
}
TEST(var_type_traits, test_registry) {
using Registry =
detail::VarTypeRegistryImpl<int8_t, int32_t, size_t, double, void>;
ASSERT_TRUE(Registry::TypePos<int8_t>() == 0);
ASSERT_TRUE(Registry::TypePos<int32_t>() == 1);
ASSERT_TRUE(Registry::TypePos<size_t>() == 2);
ASSERT_TRUE(Registry::TypePos<double>() == 3);
ASSERT_TRUE(Registry::TypePos<void>() == -1);
ASSERT_TRUE(Registry::TypePos<float>() == -1);
}
} // namespace framework
} // namespace paddle

@ -18,7 +18,7 @@
#include <typeindex>
#include <typeinfo>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/framework/var_type_traits.h"
namespace paddle {
namespace framework {
@ -27,10 +27,14 @@ class Variable {
public:
template <typename T>
const T& Get() const {
static_assert(
IsRegisteredVarType<T>(),
"Not registered type. Please register T inside var_type_traits.h");
PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing");
PADDLE_ENFORCE(IsType<T>(),
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
"Variable must be type %s, the holding type is %s",
typeid(T).name(), holder_->Type().name());
ToTypeName(VarTypeTrait<T>::kId),
ToTypeName(holder_->Type()));
return *static_cast<const T*>(holder_->Ptr());
}
@ -39,61 +43,59 @@ class Variable {
template <typename T>
T* GetMutable() {
if (!holder_) {
holder_.reset(new PlaceholderImpl<T>(new T()));
holder_.reset(new PlaceholderImpl<T>());
} else {
PADDLE_ENFORCE(IsType<T>(),
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
"Variable must be type %s, the holding type is %s",
typeid(T).name(), holder_->Type().name());
ToTypeName(VarTypeTrait<T>::kId),
ToTypeName(holder_->Type()));
}
return static_cast<T*>(holder_->Ptr());
}
template <typename T>
bool IsType() const {
return holder_ != nullptr &&
std::type_index(typeid(T)) == std::type_index(holder_->Type());
return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
}
void Clear() { holder_.reset(); }
std::type_index Type() const {
int Type() const {
PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
return holder_->Type();
}
private:
struct Placeholder {
virtual ~Placeholder() {}
virtual const std::type_info& Type() const = 0;
virtual void* Ptr() const = 0;
explicit Placeholder(int type) : type_(type) {}
virtual ~Placeholder() = default;
inline int Type() const { return type_; }
inline const void* Ptr() const { return ptr_; }
inline void* Ptr() { return ptr_; }
protected:
void* ptr_;
int type_;
};
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
template <typename T>
struct PlaceholderImpl : public Placeholder {
explicit PlaceholderImpl(T* ptr) : ptr_(ptr), type_(typeid(T)) {}
virtual const std::type_info& Type() const { return type_; }
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
static_assert(
IsRegisteredVarType<T>(),
"Not registered type. Please register T inside var_type_traits.h");
PlaceholderImpl() : Placeholder(VarTypeTrait<T>::kId) {
this->ptr_ = &obj_;
}
std::unique_ptr<T> ptr_;
const std::type_info& type_;
private:
T obj_;
};
std::unique_ptr<Placeholder>
holder_; // pointers to a PlaceholderImpl object indeed.
// name_ is only meaningful with a Scope and accessible by it.
//
// NOTE: Please don't expose name_ by adding methods like
// Variable::Name or Scope::VarName! A variable could have a human
// readable name or an auto-generated scope-unique name. In the
// former case, the caller knows the name and doesn't need to access
// the name; in the latter case, the variable should be identified
// by its address but not the unreadable name.
friend class Scope;
const std::string* name_;
// pointers to a PlaceholderImpl object indeed.
std::unique_ptr<Placeholder> holder_;
};
} // namespace framework

@ -16,27 +16,28 @@
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
TEST(Variable, GetMutable) {
using paddle::framework::Variable;
struct Tensor {
int content_;
};
namespace paddle {
namespace framework {
TEST(Variable, GetMutable) {
std::unique_ptr<Variable> v(new Variable());
Tensor* t = v->GetMutable<Tensor>();
t->content_ = 1234;
auto* t = v->GetMutable<std::string>();
*t = "1234";
const Tensor& tt = v->Get<Tensor>();
EXPECT_EQ(1234, tt.content_);
const auto& tt = v->Get<std::string>();
EXPECT_EQ("1234", tt);
try {
v->GetMutable<std::string>();
v->GetMutable<Tensor>();
} catch (std::exception& e) {
return;
}
EXPECT_TRUE(false);
}
} // namespace framework
} // namespace paddle

@ -25,7 +25,7 @@ void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) {
// TODO(Superjomn) should avoid the case when a TensorArray is a
// parameter.
if (var_name == "feed" || var_name == "fetch") continue;
if (var->Type() == typeid(framework::LoDTensorArray)) {
if (var->IsType<framework::LoDTensorArray>()) {
VLOG(4) << "collect " << var_name;
arrays_.push_back(var->GetMutable<framework::LoDTensorArray>());
}

@ -27,8 +27,11 @@ namespace details {
// training phase.
struct TensorArrayBatchCleaner {
TensorArrayBatchCleaner() {
valid_types_.insert(typeid(framework::Tensor));
valid_types_.insert(typeid(framework::LoDTensor));
constexpr auto kTensorId = framework::VarTypeTrait<framework::Tensor>::kId;
constexpr auto kLoDTensorId =
framework::VarTypeTrait<framework::LoDTensor>::kId;
valid_types_.insert(kTensorId);
valid_types_.insert(kLoDTensorId);
}
// Collect the variables that are not Tensor or LoDTensor, and reset them to a
// bool(trick), because some of them are containers, and some operators just
@ -46,7 +49,7 @@ struct TensorArrayBatchCleaner {
bool no_tensor_flag_{true};
std::vector<framework::LoDTensorArray *> arrays_;
std::unordered_set<std::type_index> valid_types_;
std::unordered_set<int> valid_types_;
std::unordered_set<framework::Variable *> no_tensor_vars_;
};

@ -74,7 +74,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
if (framework::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
@ -184,7 +184,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
if (framework::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif

@ -64,7 +64,7 @@ class ClipByNormKernel : public framework::OpKernel<T> {
output->mutable_data<T>(context.GetPlace());
} else {
PADDLE_THROW("Unexpected branch, input variable type is %s",
in_var->Type().name());
framework::ToTypeName(in_var->Type()));
}
PADDLE_ENFORCE_NOT_NULL(input);

@ -92,7 +92,8 @@ inline void CopyOrShare(const framework::Variable &src,
TensorCopy(src_sr.value(), dst_place, dst_sr->mutable_value());
}
} else {
PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s",
framework::ToTypeName(src.Type()));
}
}

@ -175,14 +175,13 @@ class WhileGradOp : public framework::OperatorBase {
auto &og_inside =
detail::Ref(cur_scope.Var(inside_og_name),
"Cannot find inside gradient %s", inside_og_name);
if (framework::IsType<framework::LoDTensor>(og_outside.Type())) {
if (og_outside.IsType<framework::LoDTensor>()) {
auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
auto &inside_tensor =
detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
inside_tensor.set_lod(outside_tensor.lod());
inside_tensor.ShareDataWith(outside_tensor);
} else if (framework::IsType<framework::LoDTensorArray>(
og_outside.Type())) {
} else if (og_outside.IsType<framework::LoDTensorArray>()) {
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
@ -256,7 +255,7 @@ class WhileGradOp : public framework::OperatorBase {
var->IsType<LoDTensor>(),
"Currently the type of var only can be LoDTensorArray, "
"or LoDTensor, but the received var[%s] is %s.",
inside_grad_name, var->Type().name());
inside_grad_name, framework::ToTypeName(var->Type()));
if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();

@ -84,7 +84,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
framework::DataLayout layout = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
if (framework::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
@ -369,7 +369,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
if (framework::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -171,8 +171,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
reinterpret_cast<const char*>(slr->rows().data()),

@ -116,7 +116,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
} else {
PADDLE_THROW(
"% should be LoDTensor or SelectedRows, but the received type is %s",
ctx.Inputs("Ids")[0], ids_var->Type().name());
ctx.Inputs("Ids")[0], framework::ToTypeName(ids_var->Type()));
}
}
};

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save