|
|
|
@ -36,9 +36,15 @@ namespace pybind {
|
|
|
|
|
|
|
|
|
|
static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
|
|
|
|
|
const std::string& op_type, const std::string& arg_name, int arg_idx,
|
|
|
|
|
const py::handle& handle) {
|
|
|
|
|
const py::handle& handle, bool dispensable = false) {
|
|
|
|
|
PyObject* py_obj = handle.ptr(); // get underlying PyObject
|
|
|
|
|
if (!py_obj || py_obj == Py_None) {
|
|
|
|
|
if (!dispensable) {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"%s(): argument '%s' (position %d) must be Tensor, but got "
|
|
|
|
|
"%s",
|
|
|
|
|
op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
try {
|
|
|
|
@ -54,9 +60,15 @@ static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
|
|
|
|
|
static inline std::vector<std::shared_ptr<imperative::VarBase>>
|
|
|
|
|
CastPyHandleToVarBaseList(const std::string& op_type,
|
|
|
|
|
const std::string& arg_name, int arg_idx,
|
|
|
|
|
const py::handle& handle) {
|
|
|
|
|
const py::handle& handle, bool dispensable = false) {
|
|
|
|
|
PyObject* py_obj = handle.ptr(); // get underlying PyObject
|
|
|
|
|
if (!py_obj || py_obj == Py_None) {
|
|
|
|
|
if (!dispensable) {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"%s(): argument '%s' (position %d) must be Tensor, but got "
|
|
|
|
|
"%s",
|
|
|
|
|
op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
|
|
|
|
|
}
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::shared_ptr<imperative::VarBase>> result;
|
|
|
|
|