|
|
|
@ -37,8 +37,10 @@ class Tensor {
|
|
|
|
|
template <typename T, // must be POD types
|
|
|
|
|
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
|
|
|
|
|
T* mutable_data(DDim dims, paddle::platform::Place place) {
|
|
|
|
|
if (holder_ == nullptr || holder_->Place() != place ||
|
|
|
|
|
holder_->Size() < product(dims) * sizeof(T)) {
|
|
|
|
|
if (holder_ == nullptr ||
|
|
|
|
|
!(holder_->Place() ==
|
|
|
|
|
place) /* some versions of boost::variant don't have operator!= */
|
|
|
|
|
|| holder_->Size() < product(dims) * sizeof(T)) {
|
|
|
|
|
holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T)));
|
|
|
|
|
}
|
|
|
|
|
return static_cast<T*>(holder_->Ptr());
|
|
|
|
|