add cpu pynative mode

Branch: pull/8109/head
Author: chujinjin, 4 years ago
Parent: 51d885815a
Commit: 701ab0d05f

@@ -93,10 +93,20 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
  runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node);
}

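// Syncs the device addresses of the graph's value nodes with their host data. This is only
// needed in PyNative mode; in graph mode the function returns immediately.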
void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
    return;
  }
  runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
}

void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                              VectorRef *outputs) {
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  SyncValueNodeDeviceAddr(kernel_graph);
  MS_LOG(INFO) << "Bind input output address";
  runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
@@ -130,6 +140,65 @@ void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
  MS_LOG(INFO) << "Run graph end";
}

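// Builds the single-op kernel graph used by PyNative mode and caches it under graph_info,
// so repeated executions of the same op skip graph construction and kernel building.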
void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                             const std::vector<tensor::TensorPtr> &input_tensors,
                             const std::vector<int64_t> &tensors_mask) {
  // Check if the graph cache exists.
  if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
    return;
  }
  // Prepare the graph.
  auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  SetKernelInfo(kernel_graph.get());
  BuildKernel(kernel_graph.get());
  run_op_graphs_[graph_info] = kernel_graph;
}

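// Recursively walks the (possibly nested) output VectorRef, clears the wait flag on every output
// tensor so consumers can read it immediately, and collects the tensors into outputs_tensors.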
void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors) {
  for (size_t i = 0; i < base_ref.size(); ++i) {
    if (utils::isa<VectorRef>(base_ref[i])) {
      auto ref_iter = utils::cast<VectorRef>(base_ref[i]);
      SetOutputFlags(ref_iter, outputs_tensors);
    } else if (utils::isa<tensor::TensorPtr>(base_ref[i])) {
      auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref[i]);
      tensor_ptr->SetNeedWait(false);
      outputs_tensors->push_back(tensor_ptr);
    }
  }
}

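// Runs a single op in PyNative mode: builds (or fetches) the cached graph, assigns kernel
// addresses, binds inputs and outputs, then executes the kernels in reordered execution order.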
void CPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                           std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                           const std::vector<int64_t> &tensors_mask) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
  EraseValueNodeTensor(tensors_mask, input_tensors);
  auto kernel_graph = run_op_graphs_[graph_info];
  MS_EXCEPTION_IF_NULL(kernel_graph);
  runtime_.AssignKernelAddress(kernel_graph.get());
  std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
  runtime_.CreateOutputTensors(kernel_graph.get(), *input_tensors, outputs, &tensor_to_node);
  runtime_.BindInputOutput(kernel_graph.get(), *input_tensors, outputs);
  MS_LOG(INFO) << "Run Op start";
  auto execution_order = kernel_graph->execution_order();
  Reorder(&execution_order);
  kernel_graph->set_execution_order(execution_order);
  bool ret = runtime_.Run(kernel_graph.get(), false);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Run Op failed";
  }
  std::vector<tensor::TensorPtr> output_tensors;
  SetOutputFlags(*outputs, &output_tensors);
  MS_LOG(INFO) << "Run Op end";
}

void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto &kernel_nodes = kernel_graph->execution_order();

@@ -38,10 +38,18 @@ class CPUSession : public SessionBasic {
  void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
  ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
  void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
  void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                   const std::vector<tensor::TensorPtr> &input_tensors,
                   const std::vector<int64_t> &tensors_mask) override;
  void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                 std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                 const std::vector<int64_t> &tensors_mask) override;

 private:
  void SetKernelInfo(const KernelGraph *kernel_graph);
  void BuildKernel(const KernelGraph *kernel_graph);
  void SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors);
  void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph);
  device::cpu::CPUKernelRuntime runtime_;
};
MS_REG_SESSION(kCPUDevice, CPUSession);
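Not part of the diff, but for orientation: the MS_REG_SESSION(kCPUDevice, CPUSession) registration above is what lets SessionFactory (used in RunOpInMs below) resolve the "CPU" device target. A minimal sketch of that creation path, using only calls that appear elsewhere in this commit (device_id stands in for the MS_CTX_DEVICE_ID context value):

  auto session = session::SessionFactory::Get().Create(kCPUDevice);
  MS_EXCEPTION_IF_NULL(session);
  session->Init(device_id);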

@@ -994,6 +994,8 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
    });
    return;
  }
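  // Fetch the device target once; the cached-tensor update below differs between CPU and the
  // other backends.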
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const auto &tensor_id_list = op_index_with_tensor_id_[op_index];
  for (size_t i = 0; i < tensor_id_list.size(); ++i) {
    auto tensor_id = tensor_id_list[i];
@@ -1003,7 +1005,20 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
    std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
      tensor->set_shape(new_tensor->shape());
      tensor->set_data_type(new_tensor->data_type());
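      // Non-CPU backends can simply repoint the cached tensor at the op's new device address. On
      // CPU the new data is copied into the tensor's existing buffer instead, so addresses already
      // captured elsewhere (e.g. by value nodes) stay valid.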
      if (target != kCPUDevice) {
        tensor->set_device_address(new_tensor->device_address());
      } else {
        auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
        auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
        MS_EXCEPTION_IF_NULL(old_device_address);
        MS_EXCEPTION_IF_NULL(new_device_address);
        auto old_ptr = old_device_address->GetMutablePtr();
        auto new_ptr = new_device_address->GetPtr();
        MS_EXCEPTION_IF_NULL(old_ptr);
        MS_EXCEPTION_IF_NULL(new_ptr);
        auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
        if (ret != EOK) {
          MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
        }
      }
    });
  }
}
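The CPU branch above relies on memcpy_s (from the bundled securec library) to reject a copy that would overflow the destination buffer. For readers without securec, a standalone sketch of the same size-checked pattern in portable C++ (CopyIntoExistingBuffer is a hypothetical name, not part of the commit):

  #include <cstddef>
  #include <cstring>
  #include <stdexcept>

  // Copies new_size bytes over an existing buffer of old_size bytes, mirroring the memcpy_s
  // contract used above: fail rather than overflow the destination.
  void CopyIntoExistingBuffer(void *old_ptr, std::size_t old_size, const void *new_ptr, std::size_t new_size) {
    if (old_ptr == nullptr || new_ptr == nullptr || new_size > old_size) {
      throw std::runtime_error("Memory copy failed");
    }
    std::memcpy(old_ptr, new_ptr, new_size);
  }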
@@ -1264,12 +1279,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
   MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
   auto ms_context = MsContext::GetInstance();
   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
-  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  if (device_target != kAscendDevice && device_target != kGPUDevice) {
-    MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
-  }
   if (session == nullptr) {
+    std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
     session = session::SessionFactory::Get().Create(device_target);
     MS_EXCEPTION_IF_NULL(session);
     session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
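Removing the Ascend/GPU-only pre-check is the front-end half of the feature: device_target is now read only when a session is first created, and "CPU" resolves through SessionFactory to the CPUSession registered in cpu_session.h.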

@@ -56,6 +56,11 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
    }
    auto tensor = node_value->cast<TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
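    // If the tensor already carries a device address (e.g. assigned on an earlier PyNative step),
    // reuse it instead of allocating new memory, so the value node's address stays stable.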
    if (tensor->device_address() != nullptr) {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), 0,
                             item_node.get());
      continue;
    }
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0);
    if (output_type_id == kTypeUnknown) {
      output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
