|
|
|
@ -28,12 +28,14 @@
|
|
|
|
|
#include "utils/convert_utils_py.h"
|
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
|
#include "utils/primitive_utils.h"
|
|
|
|
|
#include "pipeline/jit/resource.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr auto kBpropAttrName = "bprop";
|
|
|
|
|
constexpr auto kCellHookAttrName = "cell_hook";
|
|
|
|
|
constexpr auto kCellIDAttrName = "cell_id";
|
|
|
|
|
|
|
|
|
|
void SyncData(const py::object &arg) {
|
|
|
|
|
if (py::isinstance<py::tuple>(arg)) {
|
|
|
|
|
py::tuple arg_list = py::cast<py::tuple>(arg);
|
|
|
|
@ -49,6 +51,12 @@ void SyncData(const py::object &arg) {
|
|
|
|
|
} // namespace
|
|
|
|
|
std::map<std::string, py::object> PrimitivePy::hook_grad_;
|
|
|
|
|
|
|
|
|
|
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj)
|
|
|
|
|
: Primitive(name, false), python_obj_(python_obj), signatures_() {
|
|
|
|
|
pipeline::Resource::RecordPrimitivePy(this);
|
|
|
|
|
}
|
|
|
|
|
PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); }
|
|
|
|
|
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
|
|
|
|
|
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
|
|
|
|
|
signatures_ = signatures;
|
|
|
|
|
set_has_signature(true);
|
|
|
|
|