Follow comments from Xu Wei

gangliao-patch-1
Yi Wang 8 years ago
parent 5a22d73651
commit 58efbf41b3

@ -12,6 +12,8 @@
*/
#pragma once
#include <memory>
#include <typeindex>
#include <typeinfo>
namespace paddle {
@ -26,24 +28,14 @@ class Variable {
template <typename T>
T* GetMutable() {
if (holder_ != nullptr && typeid(T) == holder_->Type()) {
return static_cast<T*>(holder_->Ptr());
} else {
return Reset<T>(new T(), DefaultDeleter<T>());
if (holder_ == nullptr ||
std::type_index(typeid(T)) != std::type_index(holder_->Type())) {
holder_.reset(new PlaceholderImpl<T>(new T()));
}
}
~Variable() {
if (holder_ != nullptr) delete holder_;
return static_cast<T*>(holder_->Ptr());
}
private:
// DefaultDeleter is functor which uses C++'s delete(T*).
template <typename T>
struct DefaultDeleter {
void operator()(T* ptr) { delete ptr; }
};
struct Placeholder {
virtual ~Placeholder() {}
virtual const std::type_info& Type() const = 0;
@ -54,34 +46,17 @@ class Variable {
// parameter of Variable.
template <typename T>
struct PlaceholderImpl : public Placeholder {
typedef std::function<void(T*)> Deleter;
PlaceholderImpl(T* ptr) : ptr_(ptr), type_(typeid(T)) {}
PlaceholderImpl(T* ptr, Deleter d)
: ptr_(ptr), type_(typeid(T)), deleter_(d) {}
virtual ~PlaceholderImpl() {
deleter_(ptr_);
ptr_ = nullptr;
}
virtual const std::type_info& Type() const { return type_; }
virtual void* Ptr() const { return ptr_; }
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
T* ptr_ = nullptr;
std::unique_ptr<T> ptr_;
const std::type_info& type_;
std::function<void(T*)> deleter_ = DefaultDeleter<T>();
};
template <typename T>
T* Reset(T* allocated, typename PlaceholderImpl<T>::Deleter deleter) {
if (holder_ != nullptr) {
delete holder_;
}
holder_ = new PlaceholderImpl<T>(allocated, deleter);
return allocated;
}
Placeholder* holder_; // pointers to a PlaceholderImpl object indeed.
std::unique_ptr<Placeholder>
holder_; // pointers to a PlaceholderImpl object indeed.
};
} // namespace framework

Loading…
Cancel
Save