|
|
|
@ -226,8 +226,7 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tenso
|
|
|
|
|
}
|
|
|
|
|
auto value_tuple = input_value->cast<ValueTuplePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value_tuple);
|
|
|
|
|
tensor::TensorPtr tensor_ptr = nullptr;
|
|
|
|
|
tensor_ptr = opt::CreateTupleTensor(value_tuple);
|
|
|
|
|
tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
|
|
|
|
input_tensor->push_back(tensor_ptr);
|
|
|
|
|
}
|
|
|
|
@ -583,12 +582,9 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
auto input_nodes = kernel_graph->inputs();
|
|
|
|
|
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "tensor input size:" << inputs.size()
|
|
|
|
|
<< " is not equal graph inputs size:" << input_nodes.size()
|
|
|
|
|
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
|
|
|
|
<< ", input_ctrl_size:" << input_ctrl_size;
|
|
|
|
|
}
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
|
|
|
auto tensor = inputs[i];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor);
|
|
|
|
@ -598,7 +594,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|
|
|
|
auto pk_node = input_node->cast<ParameterPtr>();
|
|
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
|
|
|
|
bool need_sync = false;
|
|
|
|
|
if (ms_context->enable_pynative_infer()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
|
|
|
|
if (MsContext::GetInstance()->enable_pynative_infer()) {
|
|
|
|
|
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
|
|
|
|
|
need_sync = true;
|
|
|
|
|
}
|
|
|
|
|