From 1b98e91a7ec3220a679d88475c5a4a182e158fc2 Mon Sep 17 00:00:00 2001 From: lvliang Date: Thu, 15 Oct 2020 16:17:21 +0800 Subject: [PATCH] checkout-the-consistency-of-grad-out-in-pynative-cell-hook --- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 49 +++++++++++++++++-- mindspore/ccsrc/pybind_api/ir/primitive_py.h | 15 +++--- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index f9c64d4257..08804dc305 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -15,18 +15,19 @@ */ #include "pybind_api/ir/primitive_py.h" + #include #include "ir/signature.h" -#include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/python_adapter.h" #include "pybind11/pytypes.h" -#include "utils/convert_utils_base.h" -#include "utils/convert_utils_py.h" -#include "utils/primitive_utils.h" -#include "utils/ms_context.h" #include "pybind_api/api_register.h" #include "pybind_api/export_flags.h" #include "pybind_api/ir/base_ref_py.h" +#include "utils/convert_utils_base.h" +#include "utils/convert_utils_py.h" +#include "utils/ms_context.h" +#include "utils/primitive_utils.h" namespace mindspore { namespace { @@ -107,6 +108,42 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) return grads; } +void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const { + if (py::isinstance(expected_grad_out)) { + if (!py::isinstance(grad_out)) { + hook_grad_.clear(); + MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!"; + } + auto actual_out_tuple = py::cast(grad_out); + auto expected_out_tuple = py::cast(expected_grad_out); + if (actual_out_tuple.size() != expected_out_tuple.size()) { + hook_grad_.clear(); + MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size() + << ", but it is " << actual_out_tuple.size(); + } + for (size_t i = 0; i < expected_out_tuple.size(); ++i) { + CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i]); + } + } + + if (py::isinstance(expected_grad_out)) { + if (!py::isinstance(grad_out)) { + hook_grad_.clear(); + MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!"; + } + auto actual_out_tensor = py::cast(grad_out); + auto expected_out_tensor = py::cast(expected_grad_out); + MS_EXCEPTION_IF_NULL(actual_out_tensor); + MS_EXCEPTION_IF_NULL(expected_out_tensor); + if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) { + hook_grad_.clear(); + MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be " + << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is " + << actual_out_tensor->GetShapeAndDataTypeInfo(); + } + } +} + BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { py::tuple py_args = ConvertDatatoPyTuple(args); bool is_bprop = this->HasAttr(kBpropAttrName); @@ -138,6 +175,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { if (py::isinstance(obj)) { obj = py_args[2]; } + CheckHookConsistency(obj, py_args[2]); hook_grad_.erase(cell_id); } else { hook_grad_[cell_id] = py_args[2]; @@ -149,6 +187,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { if (py::isinstance(obj)) { obj = py_args[2]; } + CheckHookConsistency(obj, py_args[2]); } obj = py::make_tuple(obj); return std::make_shared(obj); diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.h b/mindspore/ccsrc/pybind_api/ir/primitive_py.h index 02ac810361..efd27abccc 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.h +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.h @@ -17,20 +17,20 @@ #ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ #define MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ -#include -#include +#include #include #include #include -#include +#include +#include #include "abstract/abstract_value.h" -#include "utils/misc.h" -#include "pybind11/pybind11.h" -#include "utils/log_adapter.h" +#include "frontend/parallel/ops_info/operator_info.h" #include "ir/primitive.h" #include "ir/signature.h" -#include "frontend/parallel/ops_info/operator_info.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" +#include "utils/misc.h" namespace py = pybind11; namespace mindspore { @@ -69,6 +69,7 @@ class PrimitivePy : public Primitive { private: py::function GetComputeFunction() const; + void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const; py::object python_obj_; py::function hook_; std::vector signatures_;