|
|
|
@ -106,9 +106,10 @@ DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int8_t);
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t);
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int16_t);
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int);
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t);
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t);
|
|
|
|
|
|
|
|
|
|
inline std::string TensorDTypeToPyDTypeStr(
|
|
|
|
|
framework::proto::VarType::Type type) {
|
|
|
|
@ -218,13 +219,16 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
|
|
|
|
|
SetTensorFromPyArrayT<double, P>(self, array, place, zero_copy);
|
|
|
|
|
} else if (py::isinstance<py::array_t<int8_t>>(array)) {
|
|
|
|
|
SetTensorFromPyArrayT<int8_t, P>(self, array, place, zero_copy);
|
|
|
|
|
} else if (py::isinstance<py::array_t<int16_t>>(array)) {
|
|
|
|
|
SetTensorFromPyArrayT<int16_t, P>(self, array, place, zero_copy);
|
|
|
|
|
} else if (py::isinstance<py::array_t<uint8_t>>(array)) {
|
|
|
|
|
SetTensorFromPyArrayT<uint8_t, P>(self, array, place, zero_copy);
|
|
|
|
|
} else if (py::isinstance<py::array_t<paddle::platform::float16>>(array)) {
|
|
|
|
|
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
|
|
|
|
|
zero_copy);
|
|
|
|
|
} else if (py::isinstance<py::array_t<uint16_t>>(array)) {
|
|
|
|
|
// TODO(cql): temporary keeping uint16, should be deprecated later
|
|
|
|
|
// TODO(cql): temporary keeping uint16, which is used for casting float16
|
|
|
|
|
// before. It should be deprecated later.
|
|
|
|
|
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
|
|
|
|
|
zero_copy);
|
|
|
|
|
} else if (py::isinstance<py::array_t<bool>>(array)) {
|
|
|
|
@ -234,7 +238,7 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
|
|
|
|
|
"Incompatible data or style type: tensor.set() supports bool, float16, "
|
|
|
|
|
"float32, "
|
|
|
|
|
"float64, "
|
|
|
|
|
"int8, int32, int64 and uint8, uint16, but got %s!",
|
|
|
|
|
"int8, int16, int32, int64 and uint8, uint16, but got %s!",
|
|
|
|
|
array.dtype());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -435,16 +439,18 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self,
|
|
|
|
|
return _sliceAndConcat<float>(self, obj, dim);
|
|
|
|
|
case framework::proto::VarType::FP64:
|
|
|
|
|
return _sliceAndConcat<double>(self, obj, dim);
|
|
|
|
|
case framework::proto::VarType::INT8:
|
|
|
|
|
return _sliceAndConcat<int8_t>(self, obj, dim);
|
|
|
|
|
case framework::proto::VarType::INT16:
|
|
|
|
|
return _sliceAndConcat<int16_t>(self, obj, dim);
|
|
|
|
|
case framework::proto::VarType::INT32:
|
|
|
|
|
return _sliceAndConcat<int>(self, obj, dim);
|
|
|
|
|
case framework::proto::VarType::INT64:
|
|
|
|
|
return _sliceAndConcat<int64_t>(self, obj, dim);
|
|
|
|
|
case framework::proto::VarType::BOOL:
|
|
|
|
|
return _sliceAndConcat<bool>(self, obj, dim);
|
|
|
|
|
case framework::proto::VarType::INT16:
|
|
|
|
|
return _sliceAndConcat<bool>(self, obj, dim);
|
|
|
|
|
case framework::proto::VarType::UINT8:
|
|
|
|
|
return _sliceAndConcat<bool>(self, obj, dim);
|
|
|
|
|
return _sliceAndConcat<uint8_t>(self, obj, dim);
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW("Not support type %d", src_type);
|
|
|
|
|
}
|
|
|
|
|