!1103 Bugfix: support read_variable_op for unknown shape

From: @wan_xuelei
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
pull/1103/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1cb979c4e2

@ -36,7 +36,7 @@ const std::map<std::string, std::vector<uint32_t>>
{BROADCASTGRADIENTARGS, {}} {BROADCASTGRADIENTARGS, {}}
}; };
const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE}; const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE, NOOP};
Status RefInputTask::UpdateArgs(TaskContext &) { Status RefInputTask::UpdateArgs(TaskContext &) {
// no need update args // no need update args

@ -17,6 +17,7 @@
#include "rts_node_executor.h" #include "rts_node_executor.h"
#include "common/debug/log.h" #include "common/debug/log.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "common/types.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "hybrid/model/hybrid_model.h" #include "hybrid/model/hybrid_model.h"
#include "runtime/rt.h" #include "runtime/rt.h"
@ -50,6 +51,20 @@ Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) {
return SUCCESS; return SUCCESS;
} }
Status ReadVariableOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", context.GetNodeName());
for (int i = 0; i < context.NumInputs(); ++i) {
GE_CHK_STATUS_RET(DoCopyTensor(context, i));
}
if (done_callback) {
GE_CHK_STATUS_RET(context.RegisterCallback(done_callback));
}
GELOGD("[%s] Done executing successfully.", context.GetNodeName());
return SUCCESS;
}
Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", context.GetNodeName()); GELOGD("[%s] Start to execute.", context.GetNodeName());
GE_CHK_STATUS_RET(DoCopyTensor(context, 0)); GE_CHK_STATUS_RET(DoCopyTensor(context, 0));
@ -111,6 +126,8 @@ Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node,
task = MakeShared<IdentityNodeTask>(); task = MakeShared<IdentityNodeTask>();
} else if (op_type == IDENTITYN) { } else if (op_type == IDENTITYN) {
task = MakeShared<IdentityNNodeTask>(); task = MakeShared<IdentityNNodeTask>();
} else if (op_type == READVARIABLEOP) {
task = MakeShared<ReadVariableOpNodeTask>();
} else if (op_type == PROFILINGTRAININGTRACE) { } else if (op_type == PROFILINGTRAININGTRACE) {
auto *task_defs = model.GetTaskDefs(node); auto *task_defs = model.GetTaskDefs(node);
if (task_defs == nullptr || task_defs->empty()) { if (task_defs == nullptr || task_defs->empty()) {

@ -36,6 +36,11 @@ class IdentityNNodeTask : public IdentityNodeTask {
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
}; };
class ReadVariableOpNodeTask : public IdentityNodeTask {
public:
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
};
class ProfilingTraceNodeTask : public NodeTask { class ProfilingTraceNodeTask : public NodeTask {
public: public:
explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {} explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {}

Loading…
Cancel
Save