diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 99dc085989..0c9764af93 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -321,6 +321,13 @@ AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node } target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes), func_graph->NewCNode(dict_value_nodes)}); + } else if (node_type->isa()) { + auto x = node_type->cast(); + std::string kwarg_key = x->get_key(); + AnfNodePtr kwarg_value_node = + func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node}); + AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph); + target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node}); } return target_node; } diff --git a/tests/ut/python/model/test_mix_precision.py b/tests/ut/python/model/test_mix_precision.py index d0e77f901a..f1fc2cc2f7 100644 --- a/tests/ut/python/model/test_mix_precision.py +++ b/tests/ut/python/model/test_mix_precision.py @@ -219,3 +219,31 @@ def test_dict_cast(): y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32) net = FirstNet() net(x, y) + + +def test_kwarg_cast(): + class FirstNet(nn.Cell): + def __init__(self): + super(FirstNet, self).__init__() + self.net = SecondNet().add_flags_recursive(fp16=True) + self.add = P.TensorAdd() + + def construct(self, tensor_a, tensor_b): + tensor_c = self.add(tensor_a, tensor_b) + dictionary = {"key": tensor_a} + result = self.net(key1=tensor_c, key2=dictionary) + return result + + class SecondNet(nn.Cell): + def __init__(self): + super(SecondNet, self).__init__() + self.add = P.TensorAdd() + + def construct(self, key1=1, key2=2): + tensor_d = self.add(key1, key2["key"]) + return tensor_d + + x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32) + y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32) + net = FirstNet() + net(x, y)