|
|
|
@ -266,23 +266,12 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input,
|
|
|
|
|
return make_tuple;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool NeedInsertSwitch() {
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
|
|
|
|
|
ConfigManager::GetInstance().iter_num() > 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t LoadCtrlInputTensor(const std::shared_ptr<Context> &context, std::vector<tensor::TensorPtr> *inputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
if (!NeedInsertSwitch()) {
|
|
|
|
|
(void)context->results_.erase(kInputCtrlTensors);
|
|
|
|
|
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
|
|
|
|
|
MS_LOG(INFO) << "Load kInputCtrlTensors";
|
|
|
|
|
auto inputs_params = graph->input_ctrl_tensors();
|
|
|
|
|
if (inputs_params == nullptr) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Load kInputCtrlTensors";
|
|
|
|
|
auto inputs_params =
|
|
|
|
|
context->GetResult(kInputCtrlTensors).cast<const std::shared_ptr<std::vector<tensor::TensorPtr>>>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inputs_params);
|
|
|
|
|
if (inputs_params->empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
|
|
|
|
|
}
|
|
|
|
@ -686,11 +675,10 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|
|
|
|
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
|
|
|
|
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
|
|
|
|
size_t input_ctrl_size = 1;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_);
|
|
|
|
|
if (context_->HasResult(kInputCtrlTensors)) {
|
|
|
|
|
input_ctrl_size = LoadCtrlInputTensor(context_, &inputs);
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
if (kernel_graph->input_ctrl_tensors()) {
|
|
|
|
|
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
|
|
|
|
|
}
|
|
|
|
|
auto input_nodes = kernel_graph->inputs();
|
|
|
|
|
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
|
|
|
|