!12431 move dynamic_shape_depends to backend

From: @zhupuxu
Reviewed-by: @jjfeing
Signed-off-by:
pull/12431/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e80e220f0a

@ -23,11 +23,17 @@
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
constexpr auto kUnsortedSegmentSum = "UnsortedSegmentSum"; const auto kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
constexpr auto kUnsortedSegmentMin = "UnsortedSegmentMin"; const auto kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
constexpr auto kUnsortedSegmentMax = "UnsortedSegmentMax"; const auto kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
const auto kGather = prim::kPrimGather->name();
const auto kGatherV2 = prim::kPrimGatherV2->name();
const auto kDynamicShape = prim::kPrimDynamicShape->name();
const auto kRange = prim::kPrimRange->name();
static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = { static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = {
{kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}}; {kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}, {kGather, {2}},
{kGatherV2, {2}}, {kDynamicShape, {0}}, {kRange, {0, 1, 2}},
};
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) { if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs"; MS_LOG(EXCEPTION) << "Invalid inputs";

@ -593,7 +593,6 @@ class DynamicShape(Primitive):
"""init Shape""" """init Shape"""
self.init_prim_io_names(inputs=['tensor'], outputs=['output']) self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
self.add_prim_attr('is_dynamic_shape', True) self.add_prim_attr('is_dynamic_shape', True)
self.add_prim_attr("dynamic_shape_depends", [0])
class Squeeze(PrimitiveWithInfer): class Squeeze(PrimitiveWithInfer):
@ -811,7 +810,6 @@ class Gather(PrimitiveWithCheck):
def __init__(self): def __init__(self):
"""Initialize index_select""" """Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, params, indices, axis): def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
@ -836,7 +834,6 @@ class GatherV2(PrimitiveWithCheck):
def __init__(self): def __init__(self):
"""Initialize index_select""" """Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, params, indices, axis): def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
@ -1987,7 +1984,6 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
def __init__(self): def __init__(self):
"""Initialize UnsortedSegmentMin""" """Initialize UnsortedSegmentMin"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, x, segment_ids, num_segments): def __check__(self, x, segment_ids, num_segments):
x_shape = x['shape'] x_shape = x['shape']
@ -2043,7 +2039,6 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
def __init__(self): def __init__(self):
"""Initialize UnsortedSegmentMax""" """Initialize UnsortedSegmentMax"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, x, segment_ids, num_segments): def __check__(self, x, segment_ids, num_segments):
x_shape = x['shape'] x_shape = x['shape']
@ -4980,10 +4975,6 @@ class Range(PrimitiveWithCheck):
self.maxlen = maxlen self.maxlen = maxlen
self.add_prim_attr('maxlen', maxlen) self.add_prim_attr('maxlen', maxlen)
self.add_prim_attr("dynamic_shape_depends", [0])
self.add_prim_attr("dynamic_shape_depends", [1])
self.add_prim_attr("dynamic_shape_depends", [2])
def check_shape(self, start_shape, limit_shape, delta_shape): def check_shape(self, start_shape, limit_shape, delta_shape):
validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name) validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name) validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)

Loading…
Cancel
Save