|
|
|
@ -22,6 +22,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/platform/complex128.h"
|
|
|
|
|
#include "paddle/fluid/platform/complex64.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
#include "paddle/fluid/platform/float16.h"
|
|
|
|
|
#include "paddle/fluid/platform/transform.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -170,6 +171,8 @@ DataType Tensor::type() const {
|
|
|
|
|
return DataType::COMPLEX64;
|
|
|
|
|
} else if (type == framework::proto::VarType::COMPLEX128) {
|
|
|
|
|
return DataType::COMPLEX128;
|
|
|
|
|
} else if (type == framework::proto::VarType::FP16) {
|
|
|
|
|
return DataType::FLOAT16;
|
|
|
|
|
}
|
|
|
|
|
// TODO(JiabinYang) Support more dtype here
|
|
|
|
|
return DataType::FLOAT32;
|
|
|
|
@ -229,6 +232,8 @@ template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
|
|
|
|
|
const PlaceType &target_place) const;
|
|
|
|
|
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
|
|
|
|
|
const PlaceType &target_place) const;
|
|
|
|
|
template PD_DLL_DECL Tensor
|
|
|
|
|
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
|
|
|
|
|
|
|
|
|
|
template PD_DLL_DECL float *Tensor::data<float>() const;
|
|
|
|
|
template PD_DLL_DECL double *Tensor::data<double>() const;
|
|
|
|
@ -242,6 +247,8 @@ template PD_DLL_DECL paddle::platform::complex64 *
|
|
|
|
|
Tensor::data<paddle::platform::complex64>() const;
|
|
|
|
|
template PD_DLL_DECL paddle::platform::complex128 *
|
|
|
|
|
Tensor::data<paddle::platform::complex128>() const;
|
|
|
|
|
template PD_DLL_DECL paddle::platform::float16 *
|
|
|
|
|
Tensor::data<paddle::platform::float16>() const;
|
|
|
|
|
|
|
|
|
|
template PD_DLL_DECL float *Tensor::mutable_data<float>();
|
|
|
|
|
template PD_DLL_DECL double *Tensor::mutable_data<double>();
|
|
|
|
@ -255,6 +262,8 @@ template PD_DLL_DECL paddle::platform::complex64 *
|
|
|
|
|
Tensor::mutable_data<paddle::platform::complex64>();
|
|
|
|
|
template PD_DLL_DECL paddle::platform::complex128 *
|
|
|
|
|
Tensor::mutable_data<paddle::platform::complex128>();
|
|
|
|
|
template PD_DLL_DECL paddle::platform::float16 *
|
|
|
|
|
Tensor::mutable_data<paddle::platform::float16>();
|
|
|
|
|
|
|
|
|
|
template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place);
|
|
|
|
|
template PD_DLL_DECL double *Tensor::mutable_data<double>(
|
|
|
|
@ -274,6 +283,8 @@ template PD_DLL_DECL paddle::platform::complex64 *
|
|
|
|
|
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
|
|
|
|
|
template PD_DLL_DECL paddle::platform::complex128 *
|
|
|
|
|
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
|
|
|
|
|
template PD_DLL_DECL paddle::platform::float16 *
|
|
|
|
|
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> Tensor::shape() const {
|
|
|
|
|
GET_CASTED_TENSOR
|
|
|
|
@ -344,6 +355,11 @@ Tensor Tensor::cast(const DataType &target_type) const {
|
|
|
|
|
CastDataType<paddle::platform::complex128>(
|
|
|
|
|
*tensor, rlt_tensor_, ctx));
|
|
|
|
|
break;
|
|
|
|
|
case framework::proto::VarType::FP16:
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
dst_type,
|
|
|
|
|
CastDataType<paddle::platform::float16>(*tensor, rlt_tensor_, ctx));
|
|
|
|
|
break;
|
|
|
|
|
// TODO(JiabinYang) Support more dtype here
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(platform::errors::Unimplemented(
|
|
|
|
|