memcpy_async infershape

Branch: pull/8222/head
Author: liubuyu, 5 years ago
Parent: 813e4624ab
Commit: 0b79a94e22

@@ -78,6 +78,12 @@ AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &gr
     // when input is also a hccl op and just part outputs of it linking with cur_hccl_op
     if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
       auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
+      if (memcpy_async == nullptr) {
+        MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
+      }
+      if (AnfAlgo::IsNodeDynamicShape(input)) {
+        AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
+      }
       auto kernel_info = std::make_shared<device::KernelInfo>();
       memcpy_async->set_kernel_info(kernel_info);
       MS_EXCEPTION_IF_NULL(kernel_select_);

@@ -43,6 +43,9 @@ AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, co
     if (new_node == nullptr) {
       MS_LOG(EXCEPTION) << "Create memcpy_async op failed!";
     }
+    if (AnfAlgo::IsNodeDynamicShape(tuple_get_item)) {
+      AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), new_node);
+    }
     AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node);
     make_tuple_inputs.push_back(new_node);
   }

@@ -158,6 +158,12 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
     auto input = hccl_node->input(i);
     if (NeedInsertMemcpy(graph, input, hccl_node)) {
       auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
+      if (memcpy_async == nullptr) {
+        MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
+      }
+      if (AnfAlgo::IsNodeDynamicShape(input)) {
+        AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
+      }
       new_inputs.push_back(memcpy_async);
       memcpy_async_list.push_back(memcpy_async);
     } else {
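All three passes above follow the same pattern: right after CreateMemcpyAsyncOp they verify the node was created, then tag it with kAttrIsDynamicShape whenever the input being copied is dynamic-shaped, so later compilation stages treat the inserted copy as a dynamic-shape kernel. Below is a minimal standalone sketch of that pattern; the types and attribute key are simplified stand-ins, not the real AnfNode/AnfAlgo API.

    // Standalone sketch (not MindSpore code): simplified stand-ins for the node and attribute types.
    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>
    #include <vector>

    struct Node {
      std::vector<int> shape;             // a non-positive dim (e.g. -1) means "unknown until runtime"
      std::map<std::string, bool> attrs;  // stand-in for AnfNode attributes
    };

    // Mirrors the intent of AnfAlgo::IsNodeDynamicShape: dynamic if any dim is not a positive constant.
    bool IsNodeDynamicShape(const Node &node) {
      for (int d : node.shape) {
        if (d <= 0) return true;
      }
      return false;
    }

    // Mirrors the pass logic: create the copy node, then propagate the dynamic-shape tag from its input.
    std::shared_ptr<Node> CreateMemcpyAsync(const Node &input) {
      auto memcpy_async = std::make_shared<Node>();
      memcpy_async->shape = input.shape;  // the copy is an identity: output shape equals input shape
      if (IsNodeDynamicShape(input)) {
        memcpy_async->attrs["is_dynamic_shape"] = true;  // analogous to kAttrIsDynamicShape
      }
      return memcpy_async;
    }

    int main() {
      Node static_input{{2, 3}, {}};
      Node dynamic_input{{-1, 3}, {}};
      std::cout << CreateMemcpyAsync(static_input)->attrs.count("is_dynamic_shape") << '\n';   // 0
      std::cout << CreateMemcpyAsync(dynamic_input)->attrs.count("is_dynamic_shape") << '\n';  // 1
      return 0;
    }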

@@ -29,6 +29,7 @@
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "common/trans.h"
 #include "abstract/param_validator.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
 namespace mindspore {
 namespace session {
@@ -1279,7 +1280,8 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri
 }
 bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
-  return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape);
+  return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) ||
+         GetBooleanAttr(node, kAttrIsDynamicShape);
 }
 void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector<size_t> &shape,
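The change above extends the existing IsDynamicShape query so the per-node kAttrIsDynamicShape tag set by the passes is honored alongside the input/output attributes. A tiny standalone sketch of the three-way check, with a plain attribute map standing in for node attributes and simplified attribute names:

    #include <cassert>
    #include <map>
    #include <string>

    // Stand-in for AnfRuntimeAlgorithm::GetBooleanAttr: a missing attribute reads as false.
    bool GetBooleanAttr(const std::map<std::string, bool> &attrs, const std::string &key) {
      auto it = attrs.find(key);
      return it != attrs.end() && it->second;
    }

    // Dynamic if any of the three flags is set, including the new per-node tag.
    bool IsDynamicShape(const std::map<std::string, bool> &attrs) {
      return GetBooleanAttr(attrs, "input_is_dynamic_shape") ||
             GetBooleanAttr(attrs, "output_is_dynamic_shape") ||
             GetBooleanAttr(attrs, "is_dynamic_shape");
    }

    int main() {
      assert(!IsDynamicShape({}));
      assert(IsDynamicShape({{"is_dynamic_shape", true}}));  // the tag added by the memcpy_async passes
      return 0;
    }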
@@ -1358,5 +1360,36 @@ std::vector<int> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_no
     MS_LOG(EXCEPTION) << "Invalid Shape Type";
   }
 }
+bool CheckDynamic(const NotNull<abstract::ShapePtr> &shape) {
+  return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; });
+}
+bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto base_shape = node->Shape();
+  if (base_shape == nullptr) {
+    MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
+    return false;
+  }
+  if (base_shape->isa<abstract::Shape>()) {
+    if (CheckDynamic(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
+      return true;
+    }
+  } else if (base_shape->isa<abstract::TupleShape>()) {
+    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_shape);
+    for (size_t i = 0; i < tuple_shape->size(); i++) {
+      auto b_shape = (*tuple_shape)[i];
+      if (!b_shape->isa<abstract::Shape>()) {
+        continue;
+      }
+      if (CheckDynamic(NOT_NULL(b_shape->cast<abstract::ShapePtr>()))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
 }  // namespace session
 }  // namespace mindspore
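The new IsNodeDynamicShape helper treats a single shape as dynamic unless every dimension is a positive constant, and treats a tuple shape as dynamic if any of its element shapes is. A standalone sketch of that predicate, with plain std::vector shapes standing in for abstract::Shape:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    using ShapeVec = std::vector<int>;

    // Same form as the CheckDynamic lambda above: not all dims strictly positive -> dynamic.
    bool CheckDynamic(const ShapeVec &shape) {
      return !std::all_of(shape.begin(), shape.end(), [](int s) { return s > 0; });
    }

    // Tuple counterpart: one dynamic element shape makes the whole tuple dynamic.
    bool TupleIsDynamic(const std::vector<ShapeVec> &tuple_shape) {
      return std::any_of(tuple_shape.begin(), tuple_shape.end(),
                         [](const ShapeVec &s) { return CheckDynamic(s); });
    }

    int main() {
      assert(!CheckDynamic({2, 3, 4}));           // fully static
      assert(CheckDynamic({-1, 3, 4}));           // a -1 placeholder makes it dynamic
      assert(TupleIsDynamic({{2, 2}, {-1, 8}}));  // one dynamic element is enough
      return 0;
    }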

@@ -229,6 +229,7 @@ class AnfRuntimeAlgorithm {
   static std::vector<int> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);
   static std::vector<int> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
   static std::vector<int> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
+  static bool IsNodeDynamicShape(const AnfNodePtr &node);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;

@@ -221,6 +221,9 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr
 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.

@@ -430,5 +430,15 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv
   tmp_shape[0] = IntMulWithOverflowCheck(tmp_shape[0], rank_size);
   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
 }
+AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
+}
 }  // namespace abstract
 }  // namespace mindspore
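InferImplMemCpyAsync is effectively an identity infer: it checks that exactly one tensor argument is present and builds the output abstract from that input's element type and shape, so dynamic dimensions pass straight through to the output. A standalone sketch of the same idea, with simplified stand-in types rather than the real AbstractTensor/Shape classes:

    #include <cassert>
    #include <stdexcept>
    #include <string>
    #include <vector>

    struct AbstractTensorLite {
      std::string dtype;
      std::vector<int> shape;  // -1 marks a dimension unknown until runtime
    };

    AbstractTensorLite InferMemcpyAsync(const std::vector<AbstractTensorLite> &args) {
      if (args.size() != 1) {  // mirrors CheckArgsSize(op_name, args_spec_list, 1)
        throw std::runtime_error("memcpy_async expects exactly one input");
      }
      return {args[0].dtype, args[0].shape};  // element type and shape pass through unchanged
    }

    int main() {
      AbstractTensorLite x{"float32", {-1, 128}};
      auto out = InferMemcpyAsync({x});
      assert(out.dtype == "float32");
      assert(out.shape == (std::vector<int>{-1, 128}));
      return 0;
    }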

@@ -127,6 +127,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimBroadcast, {InferImplBroadcast, true}},
     {prim::kPrimAllGather, {InferImplAllGather, true}},
     {prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
+    {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
   };
   return prim_eval_implement_map;
 }

@@ -181,6 +181,7 @@ inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduc
 inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
 inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
 inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
+inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async");
 // RowTensor
 inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");
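The last two hunks wire the infer function up: a kPrimMemCpyAsync primitive handle is defined for the "memcpy_async" op and mapped to InferImplMemCpyAsync in the primitive eval-impl map, so static analysis can dispatch shape inference by primitive. A standalone sketch of that registration pattern, with a simple name-keyed map standing in for PrimitiveEvalImplMap:

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    using ShapeVec = std::vector<int>;
    using InferFn = std::function<ShapeVec(const ShapeVec &)>;

    // Identity inference, as in InferImplMemCpyAsync above.
    ShapeVec InferMemcpyAsync(const ShapeVec &in) { return in; }

    int main() {
      // Analogue of kPrimMemCpyAsync ("memcpy_async") plus its eval-impl map entry.
      std::map<std::string, InferFn> infer_map = {
          {"memcpy_async", InferMemcpyAsync},
      };
      ShapeVec in{-1, 16};
      ShapeVec out = infer_map.at("memcpy_async")(in);
      for (int d : out) std::cout << d << ' ';  // prints: -1 16
      std::cout << '\n';
      return 0;
    }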
