|
|
|
@ -20,23 +20,19 @@ class Tensor {
|
|
|
|
|
using paddle::platform::get_place;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {}
|
|
|
|
|
explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* data() const {
|
|
|
|
|
PADDLE_ASSERT(holder_ != nullptr);
|
|
|
|
|
PADDLE_ASSERT(holder_->Place() == place_);
|
|
|
|
|
PADDLE_ASSERT(holder_->Size() >= dims_.product() * sizeof(T));
|
|
|
|
|
PADDLE_ASSERT(holder_ != nullptr,
|
|
|
|
|
"Tensor::data must be called after Tensor::mutable_data");
|
|
|
|
|
return static_cast<const T*>(holder->Ptr());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, // must be POD types
|
|
|
|
|
typename = std::enable_if<std::is_pod<T>::value>::type>
|
|
|
|
|
T* mutable_data() {
|
|
|
|
|
if (holder_ == nullptr || holder_->Place() != place_ ||
|
|
|
|
|
holder_->Size() < dims_.product() * sizeof(T)) {
|
|
|
|
|
holder_.reset(new PlaceholderImpl(place_, dims.product() * sizeof(T)));
|
|
|
|
|
T* mutable_data(DDim dims, Place place) {
|
|
|
|
|
if (holder_ == nullptr || holder_->Place() != place ||
|
|
|
|
|
holder_->Size() < dims.product() * sizeof(T)) {
|
|
|
|
|
holder_.reset(new PlaceholderImpl(place, dims.product() * sizeof(T)));
|
|
|
|
|
}
|
|
|
|
|
return static_cast<T*>(holder_->Ptr());
|
|
|
|
|
}
|
|
|
|
@ -44,16 +40,7 @@ class Tensor {
|
|
|
|
|
template <typename T, // must be POD types
|
|
|
|
|
typename = std::enable_if<std::is_pod<T>::value>::type>
|
|
|
|
|
T* mutable_data(DDim dims) {
|
|
|
|
|
dims_ = dims;
|
|
|
|
|
return mutable_data<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, // must be POD types
|
|
|
|
|
typename = std::enable_if<std::is_pod<T>::value>::type>
|
|
|
|
|
T* mutable_data(DDim dims, Place place) {
|
|
|
|
|
dims_ = dims;
|
|
|
|
|
place_ = place;
|
|
|
|
|
return mutable_data<T>();
|
|
|
|
|
return mutable_data<T>(dims, paddle::platform::get_place());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -69,7 +56,7 @@ class Tensor {
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct PlaceholderImpl : public Placeholder {
|
|
|
|
|
PlaceholderImpl(Place pl, size_t size)
|
|
|
|
|
: ptr_(memory::Alloc(pl, size), paddle::memory::Deleter(pl)),
|
|
|
|
|
: ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)),
|
|
|
|
|
place_(pl),
|
|
|
|
|
size_(size) {}
|
|
|
|
|
|
|
|
|
@ -83,8 +70,6 @@ class Tensor {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Placeholder> holder_; // holds the memory block if allocated.
|
|
|
|
|
DDim dims_; // could be smallers than the holder_->Size().
|
|
|
|
|
paddle::platform::Place place_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|