!7343 checkout the consistency of grad out in pynative cell hook

Merge pull request !7343 from JoyLvliang/checkout-the-consistency-of-grad-out-in-pynative-cell-hook
pull/7343/MERGE
Committed by mindspore-ci-bot via Gitee
commit 7ff0512210

@@ -15,18 +15,19 @@
  */
 #include "pybind_api/ir/primitive_py.h"
 #include <mutex>
 #include "ir/signature.h"
-#include "pipeline/jit/parse/python_adapter.h"
 #include "pipeline/jit/parse/data_converter.h"
+#include "pipeline/jit/parse/python_adapter.h"
 #include "pybind11/pytypes.h"
-#include "utils/convert_utils_base.h"
-#include "utils/convert_utils_py.h"
-#include "utils/primitive_utils.h"
-#include "utils/ms_context.h"
 #include "pybind_api/api_register.h"
 #include "pybind_api/export_flags.h"
 #include "pybind_api/ir/base_ref_py.h"
+#include "utils/convert_utils_base.h"
+#include "utils/convert_utils_py.h"
+#include "utils/ms_context.h"
+#include "utils/primitive_utils.h"
 namespace mindspore {
 namespace {
@@ -107,6 +108,42 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
   return grads;
 }
+
+void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
+  if (py::isinstance<py::tuple>(expected_grad_out)) {
+    if (!py::isinstance<py::tuple>(grad_out)) {
+      hook_grad_.clear();
+      MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
+    }
+    auto actual_out_tuple = py::cast<py::tuple>(grad_out);
+    auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
+    if (actual_out_tuple.size() != expected_out_tuple.size()) {
+      hook_grad_.clear();
+      MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
+                               << ", but it is " << actual_out_tuple.size();
+    }
+    for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
+      CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i]);
+    }
+  }
+  if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
+    if (!py::isinstance<tensor::Tensor>(grad_out)) {
+      hook_grad_.clear();
+      MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!";
+    }
+    auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
+    auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
+    MS_EXCEPTION_IF_NULL(actual_out_tensor);
+    MS_EXCEPTION_IF_NULL(expected_out_tensor);
+    if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
+      hook_grad_.clear();
+      MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be "
+                               << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is "
+                               << actual_out_tensor->GetShapeAndDataTypeInfo();
+    }
+  }
+}
+
 BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
   py::tuple py_args = ConvertDatatoPyTuple(args);
   bool is_bprop = this->HasAttr(kBpropAttrName);
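
Note: the new CheckHookConsistency walks the expected gradient. If it is a tuple, the hook's output must be a tuple of the same size and each element is checked recursively; if it is a tensor, the output must be a tensor with the same shape and dtype. On any mismatch the cached hook_grad_ is cleared before a TypeError/ValueError is raised, so a failed hook does not leave stale state behind for the next call. Below is a minimal standalone sketch of the same recursive check for readers who do not build MindSpore; Grad and CheckConsistency are illustrative stand-ins for py::object / tensor::Tensor and the method above, not MindSpore API.

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Toy stand-in for a gradient value passed through a cell hook: either a
// leaf "tensor" (shape + dtype) or a tuple of sub-gradients.
struct Grad {
  bool is_tuple = false;
  std::vector<int64_t> shape;   // used when is_tuple == false
  std::string dtype;            // used when is_tuple == false
  std::vector<Grad> elements;   // used when is_tuple == true
};

// Mirrors the logic of PrimitivePy::CheckHookConsistency: tuples must match
// element-wise, leaf tensors must agree on shape and dtype.
void CheckConsistency(const Grad &actual, const Grad &expected) {
  if (expected.is_tuple) {
    if (!actual.is_tuple) {
      throw std::invalid_argument("The output gradient should be a tuple!");
    }
    if (actual.elements.size() != expected.elements.size()) {
      throw std::invalid_argument("Tuple size mismatch in output gradient");
    }
    for (std::size_t i = 0; i < expected.elements.size(); ++i) {
      CheckConsistency(actual.elements[i], expected.elements[i]);
    }
    return;
  }
  if (actual.is_tuple) {
    throw std::invalid_argument("The output gradient should be a tensor!");
  }
  if (actual.shape != expected.shape || actual.dtype != expected.dtype) {
    throw std::invalid_argument("Output gradient shape/dtype mismatch");
  }
}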
@@ -138,6 +175,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
       if (py::isinstance<py::none>(obj)) {
         obj = py_args[2];
       }
+      CheckHookConsistency(obj, py_args[2]);
       hook_grad_.erase(cell_id);
     } else {
       hook_grad_[cell_id] = py_args[2];
@@ -149,6 +187,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
     if (py::isinstance<py::none>(obj)) {
       obj = py_args[2];
     }
+    CheckHookConsistency(obj, py_args[2]);
   }
   obj = py::make_tuple(obj);
   return std::make_shared<PyObjectRef>(obj);

@@ -17,20 +17,20 @@
 #ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
 #define MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
-#include <unordered_map>
-#include <vector>
+#include <map>
 #include <memory>
 #include <string>
 #include <tuple>
-#include <map>
+#include <unordered_map>
+#include <vector>
 #include "abstract/abstract_value.h"
-#include "utils/misc.h"
-#include "pybind11/pybind11.h"
-#include "utils/log_adapter.h"
+#include "frontend/parallel/ops_info/operator_info.h"
 #include "ir/primitive.h"
 #include "ir/signature.h"
-#include "frontend/parallel/ops_info/operator_info.h"
+#include "pybind11/pybind11.h"
+#include "utils/log_adapter.h"
+#include "utils/misc.h"
 namespace py = pybind11;
 namespace mindspore {
@@ -69,6 +69,7 @@ class PrimitivePy : public Primitive {
  private:
   py::function GetComputeFunction() const;
+  void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const;
   py::object python_obj_;
   py::function hook_;
   std::vector<Signature> signatures_;
