diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc
index 11ef788fee..61b4bfbda1 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc
@@ -400,8 +400,18 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
     op = prim::kPrimListGetItem;
   }
 
+  if (tail_type_ == kGradFirst) {
+    if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && (*sequeue)[1]->isa<abstract::AbstractUndetermined>()) {
+      ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
+    } else {
+      ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
+    }
+
+    return ret;
+  }
+
   for (size_t i = 1; i < sequeue->size(); ++i) {
-    if (do_grad_) {
+    if (tail_type_ == kGradAll) {
       MS_EXCEPTION_IF_NULL((*sequeue)[i]);
       if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>()) {
         elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
@@ -581,8 +591,8 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
 
   CNodePtr inputs_bprop = nullptr;
   if (get_all_) {
-    TailPtr tail = std::make_shared<Tail>("tail", true);
-    inputs_bprop = k_child->NewCNode({NewValueNode(tail), b_app});
+    TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
+    inputs_bprop = k_child->NewCNode({NewValueNode(tail_grad_all), b_app});
   }
 
   // Gradients wrt inputs and parameters
@@ -602,11 +612,11 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
     k_child->set_output(inputs_bprop);
     return;
   }
 
-  // Gradients wrt first input.
-  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
-  k_child->set_output(
-    k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(1))}));
+  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
+  // so obtain the gradient of the first input by setting tail_type_ of Tail to kGradFirst.
+  TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
+  k_child->set_output(k_child->NewCNode({NewValueNode(tail_grad_first), b_app}));
 }
 
 // Generate the graph.
diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h
index e6cb182577..061534bf8a 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.h
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.h
@@ -97,9 +97,11 @@ using HyperMapPyPtr = std::shared_ptr<HyperMapPy>;
 
 extern ValuePtr kCompositeHyperMap;
 
+enum TailType { kGradAll, kGradFirst, kNotGrad };
+
 class Tail : public MetaFuncGraph {
  public:
-  explicit Tail(const std::string &name, bool do_grad = false) : MetaFuncGraph(name), do_grad_(do_grad) {}
+  explicit Tail(const std::string &name, TailType tail_type = kNotGrad) : MetaFuncGraph(name), tail_type_(tail_type) {}
   ~Tail() override = default;
   MS_DECLARE_PARENT(Tail, MetaFuncGraph)
 
@@ -109,7 +111,7 @@ class Tail : public MetaFuncGraph {
   friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
 
  private:
-  bool do_grad_;
+  TailType tail_type_;
 };
 using TailPtr = std::shared_ptr<Tail>;
 
diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py
index d1c2b8cebb..246370cc1e 100644
--- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py
+++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py
@@ -24,9 +24,9 @@ from mindspore.ops import composite as C
 context.set_context(mode=context.GRAPH_MODE)
 
 
-class Net(nn.Cell):
+class FirstInputTupleNet(nn.Cell):
     def __init__(self):
-        super(Net, self).__init__()
+        super(FirstInputTupleNet, self).__init__()
 
     def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag):
         if flag:
@@ -35,11 +35,11 @@ class Net(nn.Cell):
 
 
 class GradNet(nn.Cell):
-    def __init__(self, net):
+    def __init__(self, net, get_all):
         super(GradNet, self).__init__()
         self.forward_net = net
         self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
-        self.grad_all = C.GradOperation(get_all=True)
+        self.grad_all = C.GradOperation(get_all=get_all)
 
     def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag):
         return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag)
@@ -64,8 +64,8 @@ flag_1 = False
 p = Parameter(x, name="weight")
 a = np.ones((2, 2))
 
-forward_net = Net()
-grad_net = GradNet(forward_net)
+forward_net = FirstInputTupleNet()
+grad_all_inputs_net = GradNet(forward_net, get_all=True)
 
 
 def test_outermost_net_inputs_including_non_tensor():
@@ -74,13 +74,31 @@ def test_outermost_net_inputs_including_non_tensor():
 
 
 def test_grad_net_inputs_including_non_tensor():
-    grad_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
-    grad_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)
+    assert len(grad_all_inputs_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)) == 2
+    assert len(grad_all_inputs_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)) == 2
+
+
+def test_grad_first_input_net():
+    class FirstInputTensorNet(nn.Cell):
+        def __init__(self):
+            super(FirstInputTensorNet, self).__init__()
+
+        def construct(self, tensor_x, tuple_a, list_b, tensor_y, scalar, dict_c, flag):
+            if flag:
+                return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"]
+            return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"]
+
+    grad_first_input_tensor_net = GradNet(FirstInputTensorNet(), get_all=False)
+    ret = grad_first_input_tensor_net(z, arg_t0, arg_l0, w, sl, args_d0, flag_0)
+    assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))
+
+    grad_first_input_tuple_net = GradNet(forward_net, get_all=False)
+    assert not grad_first_input_tuple_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
 
 
 def test_net_inputs_including_str():
     with pytest.raises(TypeError) as err:
-        grad_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
+        grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
     assert "The inputs types of the outermost network support bool, int, float, tensor, " \
            "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
            "and tuple or list containing only these types, and dict whose values are these types, " \
@@ -117,7 +135,7 @@ def test_outermost_net_pass_list_including_parameter():
 
 def test_grad_net_pass_dict_including_parameter():
     with pytest.raises(TypeError) as err:
-        grad_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
+        grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
     assert "The inputs types of the outermost network support bool, int, float, tensor, " \
            "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
            "and tuple or list containing only these types, and dict whose values are these types, " \
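
With this patch, GradOperation(get_all=False) routes through Tail("tail_grad_first", kGradFirst)
instead of the hard-coded tuple_getitem(b_app, 1): when the first input's gradient is a tensor it
is returned directly, and otherwise an empty tuple is returned. Below is a minimal sketch of the
resulting user-facing behavior, assuming a MindSpore build that includes this change; AddNet and
the tensor shapes are illustrative only and are not part of the patch:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import composite as C

    context.set_context(mode=context.GRAPH_MODE)

    class AddNet(nn.Cell):
        def construct(self, x, y):
            return x + y

    # get_all=False now maps to the kGradFirst path in GradOperation::GradByParameter.
    grad_first = C.GradOperation(get_all=False)
    x = Tensor(np.ones((2, 2), np.float32))
    y = Tensor(np.ones((2, 2), np.float32))
    dx = grad_first(AddNet())(x, y)  # gradient wrt the first input only
    assert np.allclose(dx.asnumpy(), np.ones((2, 2), np.float32))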