|
|
|
@ -25,8 +25,9 @@ inline const T* Tensor::data() const {
|
|
|
|
|
check_memory_size();
|
|
|
|
|
bool valid =
|
|
|
|
|
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
|
|
|
|
|
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d",
|
|
|
|
|
DataTypeToString(type_));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
valid, "Tensor holds the wrong type, it holds %s, but desires to be %s",
|
|
|
|
|
DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType));
|
|
|
|
|
|
|
|
|
|
return reinterpret_cast<const T*>(
|
|
|
|
|
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
|
|
|
|
@ -39,7 +40,9 @@ inline T* Tensor::data() {
|
|
|
|
|
check_memory_size();
|
|
|
|
|
bool valid =
|
|
|
|
|
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
|
|
|
|
|
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
valid, "Tensor holds the wrong type, it holds %s, but desires to be %s",
|
|
|
|
|
DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType));
|
|
|
|
|
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
|
|
|
|
|
offset_);
|
|
|
|
|
}
|
|
|
|
|