|
|
|
@ -33,7 +33,7 @@ class Tensor {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* data() const {
|
|
|
|
|
CheckDimsValidity<T>();
|
|
|
|
|
CheckDims<T>();
|
|
|
|
|
return reinterpret_cast<const T*>(
|
|
|
|
|
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
|
|
|
|
|
}
|
|
|
|
@ -62,7 +62,7 @@ class Tensor {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ShareDataFrom(const Tensor& src) {
|
|
|
|
|
src.CheckDimsValidity<T>();
|
|
|
|
|
src.CheckDims<T>();
|
|
|
|
|
holder_ = src.holder_;
|
|
|
|
|
set_dims(src.dims());
|
|
|
|
|
offset_ = src.offset_;
|
|
|
|
@ -73,7 +73,7 @@ class Tensor {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
|
|
|
|
|
platform::is_cpu_place(dst_place),
|
|
|
|
|
"Tensor::CopyFrom only support CPU now.");
|
|
|
|
|
src.CheckDimsValidity<T>();
|
|
|
|
|
src.CheckDims<T>();
|
|
|
|
|
size_t size = src.numel_ * sizeof(T);
|
|
|
|
|
set_dims(src.dims());
|
|
|
|
|
const void* src_ptr = static_cast<const void*>(src.data<T>());
|
|
|
|
@ -83,7 +83,7 @@ class Tensor {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
Tensor Slice(const int& begin_idx, const int& end_idx) const {
|
|
|
|
|
CheckDimsValidity<T>();
|
|
|
|
|
CheckDims<T>();
|
|
|
|
|
PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
|
|
|
|
|
"Slice index is less than zero or out of bound.");
|
|
|
|
|
PADDLE_ENFORCE(begin_idx < end_idx,
|
|
|
|
@ -109,7 +109,6 @@ class Tensor {
|
|
|
|
|
}
|
|
|
|
|
dims_ = dims;
|
|
|
|
|
numel_ = product(dims_);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DDim dims() const { return dims_; }
|
|
|
|
@ -155,10 +154,10 @@ class Tensor {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void CheckDimsValidity() const {
|
|
|
|
|
inline void CheckDims() const {
|
|
|
|
|
PADDLE_ENFORCE(holder_ != nullptr,
|
|
|
|
|
"Tenosr holds no memory. Call Tensor::mutable_data first.");
|
|
|
|
|
PADDLE_ENFORCE(holder_->size() > numel_ * sizeof(T) + offset_,
|
|
|
|
|
PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_,
|
|
|
|
|
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
|
|
|
|
|
"first to re-allocate memory.");
|
|
|
|
|
}
|
|
|
|
|