|
|
@ -15,18 +15,19 @@
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
#include "pybind_api/ir/primitive_py.h"
|
|
|
|
#include "pybind_api/ir/primitive_py.h"
|
|
|
|
|
|
|
|
|
|
|
|
#include <mutex>
|
|
|
|
#include <mutex>
|
|
|
|
#include "ir/signature.h"
|
|
|
|
#include "ir/signature.h"
|
|
|
|
#include "pipeline/jit/parse/python_adapter.h"
|
|
|
|
|
|
|
|
#include "pipeline/jit/parse/data_converter.h"
|
|
|
|
#include "pipeline/jit/parse/data_converter.h"
|
|
|
|
|
|
|
|
#include "pipeline/jit/parse/python_adapter.h"
|
|
|
|
#include "pybind11/pytypes.h"
|
|
|
|
#include "pybind11/pytypes.h"
|
|
|
|
#include "utils/convert_utils_base.h"
|
|
|
|
|
|
|
|
#include "utils/convert_utils_py.h"
|
|
|
|
|
|
|
|
#include "utils/primitive_utils.h"
|
|
|
|
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
|
|
|
|
#include "pybind_api/api_register.h"
|
|
|
|
#include "pybind_api/api_register.h"
|
|
|
|
#include "pybind_api/export_flags.h"
|
|
|
|
#include "pybind_api/export_flags.h"
|
|
|
|
#include "pybind_api/ir/base_ref_py.h"
|
|
|
|
#include "pybind_api/ir/base_ref_py.h"
|
|
|
|
|
|
|
|
#include "utils/convert_utils_base.h"
|
|
|
|
|
|
|
|
#include "utils/convert_utils_py.h"
|
|
|
|
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
|
|
|
|
#include "utils/primitive_utils.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace {
|
|
|
|
namespace {
|
|
|
@ -107,6 +108,42 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
|
|
|
return grads;
|
|
|
|
return grads;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
|
|
|
|
|
|
|
|
if (py::isinstance<py::tuple>(expected_grad_out)) {
|
|
|
|
|
|
|
|
if (!py::isinstance<py::tuple>(grad_out)) {
|
|
|
|
|
|
|
|
hook_grad_.clear();
|
|
|
|
|
|
|
|
MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto actual_out_tuple = py::cast<py::tuple>(grad_out);
|
|
|
|
|
|
|
|
auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
|
|
|
|
|
|
|
|
if (actual_out_tuple.size() != expected_out_tuple.size()) {
|
|
|
|
|
|
|
|
hook_grad_.clear();
|
|
|
|
|
|
|
|
MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
|
|
|
|
|
|
|
|
<< ", but it is " << actual_out_tuple.size();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
|
|
|
|
|
|
|
|
CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
|
|
|
|
|
|
|
|
if (!py::isinstance<tensor::Tensor>(grad_out)) {
|
|
|
|
|
|
|
|
hook_grad_.clear();
|
|
|
|
|
|
|
|
MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
|
|
|
|
|
|
|
|
auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(actual_out_tensor);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(expected_out_tensor);
|
|
|
|
|
|
|
|
if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
|
|
|
|
|
|
|
|
hook_grad_.clear();
|
|
|
|
|
|
|
|
MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be "
|
|
|
|
|
|
|
|
<< expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is "
|
|
|
|
|
|
|
|
<< actual_out_tensor->GetShapeAndDataTypeInfo();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|
|
|
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|
|
|
py::tuple py_args = ConvertDatatoPyTuple(args);
|
|
|
|
py::tuple py_args = ConvertDatatoPyTuple(args);
|
|
|
|
bool is_bprop = this->HasAttr(kBpropAttrName);
|
|
|
|
bool is_bprop = this->HasAttr(kBpropAttrName);
|
|
|
@ -138,6 +175,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|
|
|
if (py::isinstance<py::none>(obj)) {
|
|
|
|
if (py::isinstance<py::none>(obj)) {
|
|
|
|
obj = py_args[2];
|
|
|
|
obj = py_args[2];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
CheckHookConsistency(obj, py_args[2]);
|
|
|
|
hook_grad_.erase(cell_id);
|
|
|
|
hook_grad_.erase(cell_id);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
hook_grad_[cell_id] = py_args[2];
|
|
|
|
hook_grad_[cell_id] = py_args[2];
|
|
|
@ -149,6 +187,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|
|
|
if (py::isinstance<py::none>(obj)) {
|
|
|
|
if (py::isinstance<py::none>(obj)) {
|
|
|
|
obj = py_args[2];
|
|
|
|
obj = py_args[2];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
CheckHookConsistency(obj, py_args[2]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
obj = py::make_tuple(obj);
|
|
|
|
obj = py::make_tuple(obj);
|
|
|
|
return std::make_shared<PyObjectRef>(obj);
|
|
|
|
return std::make_shared<PyObjectRef>(obj);
|
|
|
|