From 3ba207421362cffe502abd4fbecb40cfe8e4fe03 Mon Sep 17 00:00:00 2001
From: Kang
Date: Thu, 11 Jun 2020 21:11:11 +0800
Subject: [PATCH] Add MixedPrecisionCast for Dict

---
 .../ccsrc/pipeline/static_analysis/prim.cc    | 16 +++++++++
 .../static_analysis/static_analysis.cc        |  2 +-
 tests/ut/python/model/test_mix_precision.py   | 34 +++++++++++++++++++
 3 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
index f6c78f0cd2..cda62cbe4b 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -286,6 +286,22 @@ AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node
       ++idx;
     }
     target_node = func_graph->NewCNode(nodes);
+  } else if (node_type->isa<AbstractDictionary>()) {
+    auto x = node_type->cast<AbstractDictionaryPtr>();
+    auto &items = x->elements();
+    std::vector<AnfNodePtr> dict_key_nodes;
+    std::vector<AnfNodePtr> dict_value_nodes;
+    dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    for (const auto &item : items) {
+      AnfNodePtr dict_value_node =
+        func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)});
+      AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
+      dict_key_nodes.emplace_back(NewValueNode(item.first));
+      dict_value_nodes.emplace_back(node);
+    }
+    target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes),
+                                        func_graph->NewCNode(dict_value_nodes)});
   }
   return target_node;
 }
diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
index dc79655348..9299a02002 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
@@ -308,7 +308,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
     evaluator = std::make_shared(prim);
     return evaluator;
   }
-  if (prim->name() == prim::kPrimMixedPrecisionCast->name()) {
+  if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
     evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
     return evaluator;
   }
diff --git a/tests/ut/python/model/test_mix_precision.py b/tests/ut/python/model/test_mix_precision.py
index 30c6002be8..d0e77f901a 100644
--- a/tests/ut/python/model/test_mix_precision.py
+++ b/tests/ut/python/model/test_mix_precision.py
@@ -25,6 +25,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.train.parallel_utils import ParallelMode
 from tests.ops_common import convert
 from ....train_step_wrap import train_step_with_loss_warp
@@ -185,3 +186,36 @@ def test_grad_conv_prelu():
     net = GetParamGrad(net)
     net.set_train()
     net(*all_inputs)
+
+
+def test_dict_cast():
+    class FirstNet(nn.Cell):
+        def __init__(self):
+            super(FirstNet, self).__init__()
+            self.net = SecondNet()
+            self.sub = P.Sub()
+
+        def construct(self, tensor_a, tensor_b):
+            a = F.mixed_precision_cast(mstype.float16, tensor_a)
+            b = F.mixed_precision_cast(mstype.float16, tensor_b)
+            c = self.sub(a, b)
+            dictionary = {"key": a}
+            result = self.net(c, key1=a, key2=dictionary)
+            return result
+
+    class SecondNet(nn.Cell):
+        def __init__(self):
+            super(SecondNet, self).__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, tensor_c, **kwargs):
+            d = F.mixed_precision_cast(mstype.float16, tensor_c)
+            dict_cast = F.mixed_precision_cast(mstype.float16, kwargs)
+            e = self.add(d, dict_cast["key1"])
+            f = self.add(e, dict_cast["key2"]["key"])
+            return f
+
+    x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32)
+    y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
+    net = FirstNet()
+    net(x, y)
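
Note for reviewers: the new AbstractDictionary branch in MixedPrecisionCastHelper rewrites a cast over a dict into per-value casts that are reassembled with kPrimMakeDict. The following is a rough Python analogy of that expansion, a sketch only; the names cast_dict_values and cast_value are illustrative and are not part of the MindSpore API.

# Illustrative sketch of the AbstractDictionary branch added to prim.cc.
# The real code emits kPrimDictGetItem / kPrimMakeDict CNodes in the graph;
# here cast_value stands in for the recursive MixedPrecisionCastHelper call.
def cast_dict_values(source_dict, target_type, cast_value):
    keys = []    # mirrors dict_key_nodes (minus the leading MakeTuple primitive)
    values = []  # mirrors dict_value_nodes
    for key, value in source_dict.items():             # iterate x->elements()
        keys.append(key)                               # NewValueNode(item.first)
        values.append(cast_value(value, target_type))  # recursive cast of each value
    return dict(zip(keys, values))                     # kPrimMakeDict over key/value tuples

Keeping keys and values in two parallel lists matches how the graph-level MakeDict primitive consumes one key tuple and one value tuple, which is why the C++ code builds two MakeTuple CNodes before combining them.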