|
|
|
@ -108,6 +108,20 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
|
|
|
|
return grads;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrimitivePy::ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(convert_args);
|
|
|
|
|
if (input_args.size() != (*convert_args).size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size()
|
|
|
|
|
<< " should be equal to the size of convert_args: " << (*convert_args).size();
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < input_args.size(); ++i) {
|
|
|
|
|
(*convert_args)[i] = py::isinstance<tensor::Tensor>(input_args[i])
|
|
|
|
|
? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
|
|
|
|
|
parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i])
|
|
|
|
|
: input_args[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
|
|
|
|
|
if (py::isinstance<py::tuple>(expected_grad_out)) {
|
|
|
|
|
if (!py::isinstance<py::tuple>(grad_out)) {
|
|
|
|
@ -150,12 +164,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|
|
|
|
if (is_bprop) {
|
|
|
|
|
SyncData(py_args);
|
|
|
|
|
py::tuple convert_args(py_args.size());
|
|
|
|
|
for (size_t i = 0; i < py_args.size(); i++) {
|
|
|
|
|
convert_args[i] = py::isinstance<tensor::Tensor>(py_args[i])
|
|
|
|
|
? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
|
|
|
|
|
parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, py_args[i])
|
|
|
|
|
: py_args[i];
|
|
|
|
|
}
|
|
|
|
|
ConvertCTensorToPyTensor(py_args, &convert_args);
|
|
|
|
|
py::object grads_obj = hook_(*convert_args);
|
|
|
|
|
py::tuple grads = check_bprop_out(grads_obj, py_args);
|
|
|
|
|
return std::make_shared<PyObjectRef>(grads);
|
|
|
|
@ -167,10 +176,15 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|
|
|
|
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
|
|
|
|
|
auto iter = hook_grad_.find(cell_id);
|
|
|
|
|
if (iter != hook_grad_.end()) {
|
|
|
|
|
py::tuple convert_args(2);
|
|
|
|
|
py::tuple input_args(2);
|
|
|
|
|
input_args[0] = iter->second;
|
|
|
|
|
input_args[1] = py_args[2];
|
|
|
|
|
ConvertCTensorToPyTensor(input_args, &convert_args);
|
|
|
|
|
auto hook_args = py::tuple(3);
|
|
|
|
|
hook_args[0] = cell_id;
|
|
|
|
|
hook_args[1] = py::make_tuple(iter->second);
|
|
|
|
|
hook_args[2] = py::make_tuple(py_args[2]);
|
|
|
|
|
hook_args[1] = py::make_tuple(convert_args[0]);
|
|
|
|
|
hook_args[2] = py::make_tuple(convert_args[1]);
|
|
|
|
|
obj = hook_(*hook_args);
|
|
|
|
|
if (py::isinstance<py::none>(obj)) {
|
|
|
|
|
obj = py_args[2];
|
|
|
|
|