!11928 move dynamic_shape_depends to backend

From: @zhupuxu
Reviewed-by: 
Signed-off-by:
pull/11928/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit db3fe6a461

@ -100,13 +100,13 @@ void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode,
void FeedTeOpConstTensor(const NotNull<CNodePtr> &cnode, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
NotNull<std::map<std::string, optiling::TeConstTensorData> *> const_inputs) {
MS_LOG(INFO) << "FeedTeOpConstTensor start, node:" << cnode->fullname_with_scope();
if (!AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode.get())) {
auto depends_list_me = abstract::GetDependsFormMap(cnode);
if (depends_list_me.empty()) {
MS_LOG(INFO) << "No input depend found, " << cnode->fullname_with_scope();
return;
}
std::vector<int> depends_list;
std::vector<int64_t> depends_list_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode.get(), kDynamicShapeDepends);
(void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depends_list),
[](const int64_t &value) { return static_cast<int>(value); });
for (auto index : depends_list) {

@ -25,6 +25,7 @@
#include "ir/anf.h"
#include "ir/tensor.h"
#include "register/op_tiling.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace device {

@ -39,16 +39,14 @@ void DynamicKernel::Initialize() {
is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrInputIsDynamicShape);
is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrOutputIsDynamicShape);
auto have_depends = AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode_ptr_);
if (!have_depends) {
auto ret = abstract::GetDependsFormMap(cnode_ptr_);
if (ret.empty()) {
MS_LOG(DEBUG) << "No dynamic_shape_depends found";
return;
}
MS_LOG(INFO) << "Have depends";
std::vector<int64_t> depends_list_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode_ptr_, kDynamicShapeDepends);
(void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depend_list_),
(void)std::transform(ret.begin(), ret.end(), std::back_inserter(depend_list_),
[](const int64_t &value) { return static_cast<int>(value); });
MS_LOG(INFO) << "Init End";
}

@ -23,6 +23,7 @@
#include <map>
#include "ir/anf.h"
#include "ir/tensor.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace device {

@ -22,6 +22,25 @@
namespace mindspore {
namespace abstract {
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
constexpr auto kUnsortedSegmentSum = "UnsortedSegmentSum";
constexpr auto kUnsortedSegmentMin = "UnsortedSegmentMin";
constexpr auto kUnsortedSegmentMax = "UnsortedSegmentMax";
static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = {
{kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}};
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs";
}
auto primitive = GetValueNode<PrimitivePtr>(cnode->inputs()[0]);
MS_EXCEPTION_IF_NULL(primitive);
auto iter = dynamic_shape_depends.find(primitive->ToString());
if (iter != dynamic_shape_depends.end()) {
return iter->second;
}
return {};
}
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements

@ -21,6 +21,8 @@
#include "ir/primitive.h"
#include "base/core_ops.h"
#include "abstract/abstract_value.h"
#include "ir/anf.h"
namespace mindspore {
namespace abstract {
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
@ -35,6 +37,8 @@ using PrimitiveEvalImplMap =
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap();
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);
class RegisterStandardPrimitiveEvalHelper {

@ -1892,7 +1892,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
def __init__(self):
"""Initialize UnsortedSegmentSum"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']

Loading…
Cancel
Save