|
|
|
@ -29,8 +29,6 @@ class Tensor {
|
|
|
|
|
public:
|
|
|
|
|
Tensor() : numel_(0), offset_(0) {}
|
|
|
|
|
|
|
|
|
|
Tensor& operator=(const Tensor& src) = delete;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* data() const {
|
|
|
|
|
CheckDims<T>();
|
|
|
|
@ -39,13 +37,13 @@ class Tensor {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* mutable_data(DDim dims, paddle::platform::Place place) {
|
|
|
|
|
T* mutable_data(DDim dims, platform::Place place) {
|
|
|
|
|
set_dims(dims);
|
|
|
|
|
return mutable_data<T>(place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* mutable_data(paddle::platform::Place place) {
|
|
|
|
|
T* mutable_data(platform::Place place) {
|
|
|
|
|
PADDLE_ENFORCE(numel_ > 0,
|
|
|
|
|
"Tensor::numel_ must be larger than zero to call "
|
|
|
|
|
"Tensor::mutable_data. Call Tensor::set_dim first.");
|
|
|
|
@ -53,7 +51,18 @@ class Tensor {
|
|
|
|
|
!(holder_->place() ==
|
|
|
|
|
place) /* some versions of boost::variant don't have operator!= */
|
|
|
|
|
|| holder_->size() < numel_ * sizeof(T) + offset_) {
|
|
|
|
|
holder_.reset(new PlaceholderImpl<T>(place, numel_ * sizeof(T)));
|
|
|
|
|
switch (place.which()) {
|
|
|
|
|
case 0:
|
|
|
|
|
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
|
|
|
|
|
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case 1:
|
|
|
|
|
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
|
|
|
|
|
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
offset_ = 0;
|
|
|
|
|
}
|
|
|
|
|
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
|
|
|
|
@ -69,7 +78,7 @@ class Tensor {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) {
|
|
|
|
|
void CopyFrom(const Tensor& src, platform::Place dst_place) {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
|
|
|
|
|
platform::is_cpu_place(dst_place),
|
|
|
|
|
"Tensor::CopyFrom only support CPU now.");
|
|
|
|
@ -119,37 +128,36 @@ class Tensor {
|
|
|
|
|
struct Placeholder {
|
|
|
|
|
virtual ~Placeholder() {}
|
|
|
|
|
virtual void* ptr() const = 0;
|
|
|
|
|
virtual paddle::platform::Place place() const = 0;
|
|
|
|
|
virtual platform::Place place() const = 0;
|
|
|
|
|
virtual size_t size() const = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
template <typename T, typename PlaceType>
|
|
|
|
|
struct PlaceholderImpl : public Placeholder {
|
|
|
|
|
private:
|
|
|
|
|
template <typename PType>
|
|
|
|
|
class Deleter {
|
|
|
|
|
public:
|
|
|
|
|
Deleter(platform::Place place) : place_(place) {}
|
|
|
|
|
void operator()(T* ptr) {
|
|
|
|
|
paddle::memory::Free(place_, static_cast<void*>(ptr));
|
|
|
|
|
}
|
|
|
|
|
Deleter(PType place) : place_(place) {}
|
|
|
|
|
void operator()(T* ptr) { memory::Free(place_, static_cast<void*>(ptr)); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
paddle::platform::Place place_;
|
|
|
|
|
PType place_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
PlaceholderImpl(paddle::platform::Place place, size_t size)
|
|
|
|
|
: ptr_(static_cast<T*>(paddle::memory::Alloc(place, size)),
|
|
|
|
|
Deleter(place)),
|
|
|
|
|
PlaceholderImpl(PlaceType place, size_t size)
|
|
|
|
|
: ptr_(static_cast<T*>(memory::Alloc(place, size)),
|
|
|
|
|
Deleter<PlaceType>(place)),
|
|
|
|
|
place_(place),
|
|
|
|
|
size_(size) {}
|
|
|
|
|
|
|
|
|
|
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
|
|
|
|
|
virtual size_t size() const { return size_; }
|
|
|
|
|
virtual paddle::platform::Place place() const { return place_; }
|
|
|
|
|
virtual platform::Place place() const { return place_; }
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<T, Deleter> ptr_;
|
|
|
|
|
paddle::platform::Place place_; // record the place of ptr_.
|
|
|
|
|
std::unique_ptr<T, Deleter<PlaceType>> ptr_;
|
|
|
|
|
platform::Place place_; // record the place of ptr_.
|
|
|
|
|
size_t size_; // size of the memory block.
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -166,7 +174,7 @@ class Tensor {
|
|
|
|
|
DDim dims_;
|
|
|
|
|
size_t numel_; // cache of `product(dims_)`
|
|
|
|
|
size_t offset_; // marks the begin of tensor data area.
|
|
|
|
|
};
|
|
|
|
|
}; // namespace framework
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|