|
|
|
@ -40,7 +40,10 @@ template <size_t I, typename... ARGS>
|
|
|
|
|
struct CastToPyBufferImpl<true, I, ARGS...> {
|
|
|
|
|
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
|
|
|
|
|
py::buffer_info operator()(framework::Tensor &tensor) {
|
|
|
|
|
if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) {
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
|
|
|
|
|
"Only CPU tensor can cast to numpy array");
|
|
|
|
|
|
|
|
|
|
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
|
|
|
|
|
auto dim_vec = framework::vectorize(tensor.dims());
|
|
|
|
|
std::vector<size_t> dims_outside;
|
|
|
|
|
std::vector<size_t> strides;
|
|
|
|
@ -54,12 +57,13 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
|
|
|
|
|
prod *= dims_outside[i - 1];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return py::buffer_info(tensor.mutable_data<CUR_TYPE>(tensor.place()),
|
|
|
|
|
sizeof(CUR_TYPE),
|
|
|
|
|
py::format_descriptor<CUR_TYPE>::format(),
|
|
|
|
|
(size_t)framework::arity(tensor.dims()),
|
|
|
|
|
dims_outside,
|
|
|
|
|
strides);
|
|
|
|
|
return py::buffer_info(
|
|
|
|
|
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
|
|
|
|
|
sizeof(CUR_TYPE),
|
|
|
|
|
py::format_descriptor<CUR_TYPE>::format(),
|
|
|
|
|
(size_t)framework::arity(tensor.dims()),
|
|
|
|
|
dims_outside,
|
|
|
|
|
strides);
|
|
|
|
|
} else {
|
|
|
|
|
constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
|
|
|
|
|
return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
|