|
|
|
@ -75,7 +75,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
|
|
|
|
|
temp_shape.emplace_back(1);
|
|
|
|
|
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
|
|
|
|
|
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
|
|
|
|
|
tensor->set_dirty(false);
|
|
|
|
|
tensor->set_sync_status(kNoNeedSync);
|
|
|
|
|
tensor->SetNeedWait(true);
|
|
|
|
|
return tensor;
|
|
|
|
|
}
|
|
|
|
@ -96,12 +96,13 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
|
|
|
|
|
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
|
|
|
|
tensor->set_need_sync(true);
|
|
|
|
|
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
|
|
|
|
|
} else {
|
|
|
|
|
tensor->set_sync_status(kNeedSyncDeviceToHost);
|
|
|
|
|
}
|
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
|
|
|
|
tensor->SetNeedWait(true);
|
|
|
|
|
}
|
|
|
|
|
tensor->set_dirty(false);
|
|
|
|
|
return tensor;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -198,7 +199,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
|
|
|
|
|
auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cur_val);
|
|
|
|
|
*cur_val = 0;
|
|
|
|
|
cur_loop_tensor->set_dirty(true);
|
|
|
|
|
cur_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
|
|
|
|
|
// set loop_count to zero
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inputs);
|
|
|
|
|
inputs->push_back(cur_loop_tensor);
|
|
|
|
@ -209,7 +210,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
|
|
|
|
|
auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(next_val);
|
|
|
|
|
*next_val = 0;
|
|
|
|
|
next_loop_tensor->set_dirty(true);
|
|
|
|
|
next_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
|
|
|
|
|
// set loop_count to zero
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inputs);
|
|
|
|
|
inputs->push_back(next_loop_tensor);
|
|
|
|
@ -219,7 +220,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
|
|
|
|
|
auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(epoch_val);
|
|
|
|
|
*epoch_val = graph->current_epoch();
|
|
|
|
|
epoch_tensor->set_dirty(true);
|
|
|
|
|
epoch_tensor->set_sync_status(kNeedSyncHostToDevice);
|
|
|
|
|
inputs->push_back(epoch_tensor);
|
|
|
|
|
MS_LOG(INFO) << "Load epoch_val:" << *epoch_val;
|
|
|
|
|
|
|
|
|
@ -927,7 +928,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor
|
|
|
|
|
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
|
|
|
|
|
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
|
|
|
|
|
}
|
|
|
|
|
if (tensor->is_dirty()) {
|
|
|
|
|
if (tensor->NeedSyncHostToDevice()) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (tensor->device_address() != device_address) {
|
|
|
|
@ -976,7 +977,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
tensor->set_dirty(false);
|
|
|
|
|
tensor->set_sync_status(kNoNeedSync);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1124,7 +1125,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
|
|
|
|
|
tensor->data_type(), tensor->data_c())) {
|
|
|
|
|
MS_LOG(ERROR) << "Failed to sync output from device to host.";
|
|
|
|
|
}
|
|
|
|
|
tensor->set_dirty(false);
|
|
|
|
|
tensor->set_sync_status(kNoNeedSync);
|
|
|
|
|
params_list[output_item.first] = tensor;
|
|
|
|
|
}
|
|
|
|
|
// call callback function here
|
|
|
|
|