|
|
|
@ -572,7 +572,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|
|
|
|
// run graph steps
|
|
|
|
|
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
|
|
|
|
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
|
|
|
|
size_t input_ctrl_size = 1;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_);
|
|
|
|
@ -585,6 +584,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|
|
|
|
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
|
|
|
|
<< ", input_ctrl_size:" << input_ctrl_size;
|
|
|
|
|
}
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
|
|
|
auto tensor = inputs[i];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor);
|
|
|
|
@ -594,8 +595,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|
|
|
|
auto pk_node = input_node->cast<ParameterPtr>();
|
|
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
|
|
|
|
bool need_sync = false;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
|
|
|
|
if (MsContext::GetInstance()->enable_pynative_infer()) {
|
|
|
|
|
if (ms_context->enable_pynative_infer()) {
|
|
|
|
|
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
|
|
|
|
|
need_sync = true;
|
|
|
|
|
}
|
|
|
|
|