|
|
|
@ -14,11 +14,14 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/pybind/imperative.h"
|
|
|
|
|
|
|
|
|
|
#include <Python.h>
|
|
|
|
|
#include <pybind11/chrono.h>
|
|
|
|
|
#include <pybind11/complex.h>
|
|
|
|
|
#include <pybind11/functional.h>
|
|
|
|
|
#include <pybind11/stl.h>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/block_desc.h"
|
|
|
|
|
#include "paddle/fluid/imperative/layer.h"
|
|
|
|
@ -31,6 +34,8 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace pybind {
|
|
|
|
|
|
|
|
|
|
namespace py = ::pybind11;
|
|
|
|
|
|
|
|
|
|
class Layer : public imperative::Layer {
|
|
|
|
|
public:
|
|
|
|
|
using imperative::Layer::Layer; // Inherit constructors
|
|
|
|
@ -51,10 +56,102 @@ class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
|
|
|
|
|
PyOpBase(const std::string &name) : OpBase(name) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Function like obj.attr_name in Python.
|
|
|
|
|
static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
|
|
|
|
|
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
|
|
|
|
|
// is not inside obj, but it would also set the error flag of Python.
|
|
|
|
|
// If the error flag is set in C++, C++ code would not raise Exception,
|
|
|
|
|
// but Python would raise Exception once C++ call ends.
|
|
|
|
|
// To avoid unexpected Exception raised in Python, we check whether
|
|
|
|
|
// attribute exists before calling PyObject_GetAttrString.
|
|
|
|
|
//
|
|
|
|
|
// Caution: PyObject_GetAttrString would increase reference count of PyObject.
|
|
|
|
|
// Developer should call Py_DECREF manually after the attribute is not used.
|
|
|
|
|
if (PyObject_HasAttrString(obj, attr_name)) {
|
|
|
|
|
return PyObject_GetAttrString(obj, attr_name);
|
|
|
|
|
} else {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static T PyObjectCast(PyObject *obj) {
|
|
|
|
|
try {
|
|
|
|
|
return py::cast<T>(py::handle(obj));
|
|
|
|
|
} catch (py::cast_error &) {
|
|
|
|
|
PADDLE_THROW("Python object is not type of %s", typeid(T).name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
|
|
|
|
|
// Unlike py::object, py::handle does not change reference count of PyObject *.
|
|
|
|
|
static std::vector<std::shared_ptr<imperative::VarBase>>
|
|
|
|
|
GetVarBaseListFromPyHandle(const py::handle &handle) {
|
|
|
|
|
PyObject *py_obj = handle.ptr(); // get underlying PyObject
|
|
|
|
|
// Python None is not nullptr in C++!
|
|
|
|
|
if (!py_obj || py_obj == Py_None) {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const char *kIVarField = "_ivar";
|
|
|
|
|
PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField);
|
|
|
|
|
std::vector<std::shared_ptr<imperative::VarBase>> result;
|
|
|
|
|
|
|
|
|
|
if (py_ivar) { // Variable
|
|
|
|
|
result.emplace_back(
|
|
|
|
|
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
|
|
|
|
|
Py_DECREF(py_ivar);
|
|
|
|
|
} else if (PyList_Check(py_obj)) { // List of Variable
|
|
|
|
|
size_t len = PyList_GET_SIZE(py_obj);
|
|
|
|
|
result.reserve(len);
|
|
|
|
|
for (size_t i = 0; i < len; ++i) {
|
|
|
|
|
PyObject *py_ivar =
|
|
|
|
|
PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kIVarField);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(py_ivar);
|
|
|
|
|
result.emplace_back(
|
|
|
|
|
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
|
|
|
|
|
Py_DECREF(py_ivar);
|
|
|
|
|
}
|
|
|
|
|
} else if (PyTuple_Check(py_obj)) { // Tuple of Variable
|
|
|
|
|
size_t len = PyTuple_GET_SIZE(py_obj);
|
|
|
|
|
result.reserve(len);
|
|
|
|
|
for (size_t i = 0; i < len; ++i) {
|
|
|
|
|
PyObject *py_ivar =
|
|
|
|
|
PyObject_GetAttrString(PyTuple_GET_ITEM(py_obj, i), kIVarField);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(py_ivar);
|
|
|
|
|
result.emplace_back(
|
|
|
|
|
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
|
|
|
|
|
Py_DECREF(py_ivar);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"unsupported type %s, must be Variable, List[Variable] or "
|
|
|
|
|
"tuple[Variable]",
|
|
|
|
|
py::str(handle));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(PyErr_Occurred() == nullptr,
|
|
|
|
|
py::str(py::handle(PyErr_Occurred())));
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using PyVarBaseMap = std::unordered_map<std::string, py::handle>;
|
|
|
|
|
|
|
|
|
|
static imperative::VarBasePtrMap ConvertToVarBasePtrMap(
|
|
|
|
|
const PyVarBaseMap &map) {
|
|
|
|
|
imperative::VarBasePtrMap result;
|
|
|
|
|
for (auto &pair : map) {
|
|
|
|
|
auto var_vec = GetVarBaseListFromPyHandle(pair.second);
|
|
|
|
|
if (!var_vec.empty()) {
|
|
|
|
|
result.emplace(pair.first, std::move(var_vec));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Bind Methods
|
|
|
|
|
void BindImperative(pybind11::module *m_ptr) {
|
|
|
|
|
namespace py = ::pybind11;
|
|
|
|
|
|
|
|
|
|
auto &m = *m_ptr;
|
|
|
|
|
|
|
|
|
|
py::class_<imperative::detail::BackwardStrategy> backward_strategy(
|
|
|
|
@ -145,31 +242,41 @@ void BindImperative(pybind11::module *m_ptr) {
|
|
|
|
|
return self.Forward(inputs);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
py::class_<imperative::Tracer>(*m, "Tracer", "")
|
|
|
|
|
// NOTE(zjl): Tracer use PyVarBaseMap as its parameter but not VarBasePtrMap.
|
|
|
|
|
// We call Python C-API to convert PyVarBaseMap to VarBasePtrMap, instead
|
|
|
|
|
// making conversion in Python code. This speed up Tracer.trace() about 6%
|
|
|
|
|
// in ptb model and make time cost in Python to be nearly zero.
|
|
|
|
|
py::class_<imperative::Tracer>(m, "Tracer", "")
|
|
|
|
|
.def("__init__",
|
|
|
|
|
[](imperative::Tracer &self, framework::BlockDesc *root_block) {
|
|
|
|
|
new (&self) imperative::Tracer(root_block);
|
|
|
|
|
})
|
|
|
|
|
.def("trace",
|
|
|
|
|
[](imperative::Tracer &self, imperative::OpBase *op,
|
|
|
|
|
const imperative::VarBasePtrMap &inputs,
|
|
|
|
|
imperative::VarBasePtrMap *outputs,
|
|
|
|
|
const PyVarBaseMap &inputs, const PyVarBaseMap &outputs,
|
|
|
|
|
framework::AttributeMap attrs_map,
|
|
|
|
|
const platform::CPUPlace expected_place,
|
|
|
|
|
const bool stop_gradient = false) {
|
|
|
|
|
py::gil_scoped_release release;
|
|
|
|
|
self.Trace(op, inputs, outputs, attrs_map, expected_place,
|
|
|
|
|
stop_gradient);
|
|
|
|
|
auto ins = ConvertToVarBasePtrMap(inputs);
|
|
|
|
|
auto outs = ConvertToVarBasePtrMap(outputs);
|
|
|
|
|
{
|
|
|
|
|
py::gil_scoped_release release;
|
|
|
|
|
self.Trace(op, std::move(ins), &outs, attrs_map, expected_place,
|
|
|
|
|
stop_gradient);
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
.def("trace", [](imperative::Tracer &self, imperative::OpBase *op,
|
|
|
|
|
const imperative::VarBasePtrMap &inputs,
|
|
|
|
|
imperative::VarBasePtrMap *outputs,
|
|
|
|
|
const PyVarBaseMap &inputs, const PyVarBaseMap &outputs,
|
|
|
|
|
framework::AttributeMap attrs_map,
|
|
|
|
|
const platform::CUDAPlace expected_place,
|
|
|
|
|
const bool stop_gradient = false) {
|
|
|
|
|
py::gil_scoped_release release;
|
|
|
|
|
self.Trace(op, inputs, outputs, attrs_map, expected_place,
|
|
|
|
|
stop_gradient);
|
|
|
|
|
auto ins = ConvertToVarBasePtrMap(inputs);
|
|
|
|
|
auto outs = ConvertToVarBasePtrMap(outputs);
|
|
|
|
|
{
|
|
|
|
|
py::gil_scoped_release release;
|
|
|
|
|
self.Trace(op, std::move(ins), &outs, attrs_map, expected_place,
|
|
|
|
|
stop_gradient);
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// define parallel context
|
|
|
|
|