!11375 Pre-build all ops in bprop graph in PyNative mode

From: @HulkTang
Reviewed-by: @chujinjin
Signed-off-by: @chujinjin
pull/11375/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit 5eaac67416

File diff suppressed because it is too large.

@@ -41,6 +41,11 @@ struct InputTensorInfo {
   std::set<KernelWithIndex> input_kernel;
 };
+struct OutputTensorInfo {
+  tensor::TensorPtr output_stub_tensor;
+  bool is_weight;
+};
 class AscendSession : public SessionBasic {
  public:
   AscendSession() { final_graph_id_ = kInvalidGraphId; }
@@ -79,6 +84,7 @@ class AscendSession : public SessionBasic {
   void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
   void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  void BuildKernel(const std::vector<CNodePtr> &kernels) const;
   void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void MemoryAlloc(KernelGraph *kernel_graph) const;
   void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
@@ -119,7 +125,11 @@ class AscendSession : public SessionBasic {
   void LoadGraphsToDbg(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
   void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
   void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
+  KernelGraphPtr PreBuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                            const std::vector<tensor::TensorPtr> &input_tensors,
+                            const std::vector<int64_t> &tensors_mask);
+  void BuildOpsInGraph(KernelGraph *graph, const std::map<AnfNodePtr, size_t> &parameter_index,
+                       const std::vector<tensor::TensorPtr> &graph_inputs);
   // key is final_graph_id, value is the child graph execution order of the final graph
   std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
   // key is final_graph_id, value is the graph types of child graphs
@@ -128,6 +138,8 @@ class AscendSession : public SessionBasic {
   std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
   // final_graph_id is used when every root graph has its own session
   GraphId final_graph_id_;
+  // records graph ids of bprop graphs that have been built in PyNative mode
+  std::set<GraphId> built_graph_id_;
 };
 MS_REG_SESSION(kAscendDevice, AscendSession);
 }  // namespace session
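The header above only declares the new PyNative pre-build interface; the matching implementation lives in the suppressed ascend_session.cc diff. As a rough sketch of how the pieces could fit together (the wrapper name CompileBpropGraphOnce and its call site are assumptions, not part of this change):

// Hypothetical sketch, not from this diff: shows how the members declared above
// (BuildOpsInGraph, built_graph_id_) might cooperate so that the ops of a bprop
// graph are pre-built exactly once in PyNative mode.
void AscendSession::CompileBpropGraphOnce(KernelGraph *graph,
                                          const std::map<AnfNodePtr, size_t> &parameter_index,
                                          const std::vector<tensor::TensorPtr> &graph_inputs) {
  MS_EXCEPTION_IF_NULL(graph);
  if (built_graph_id_.count(graph->graph_id()) != 0) {
    // All single ops of this bprop graph were already built in an earlier run.
    return;
  }
  BuildOpsInGraph(graph, parameter_index, graph_inputs);
  built_graph_id_.insert(graph->graph_id());
}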

@@ -67,12 +67,11 @@ static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
   return kernel_mod_ptr;
 }
-static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
-  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
+static bool KernelBuildParallelCompile(const std::vector<CNodePtr> &kernels) {
   std::vector<AnfNodePtr> tbe_nodes;
   std::vector<AnfNodePtr> akg_nodes;
   std::vector<AnfNodePtr> other_nodes;
-  for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
+  for (const auto &anf_node : kernels) {
     MS_EXCEPTION_IF_NULL(anf_node);
     if (!AnfAlgo::IsRealKernel(anf_node)) {
       continue;
@@ -217,12 +216,9 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) {
   return !(workspace_indexs.empty() && output_indexs.empty());
 }
-bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) {
-  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
+bool KernelBuild(const std::vector<CNodePtr> &kernels) {
   TbeUtils::LoadCache();
-  bool ret;
-  ret = device::ascend::KernelBuildParallelCompile(kernel_graph_ptr);
-  return ret;
+  return device::ascend::KernelBuildParallelCompile(kernels);
 }
 std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
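With this change, callers choose which kernels to compile instead of handing over the whole graph. A minimal usage sketch, assuming a kernel_graph variable of type KernelGraphPtr at the call site (names are illustrative, not taken from this diff):

// Before: device::ascend::KernelBuild(kernel_graph.get());
// After:  the caller passes an explicit kernel list, e.g. the graph's full
//         execution order, or only the ops of a bprop graph that still need building.
std::vector<CNodePtr> kernels = kernel_graph->execution_order();
if (!device::ascend::KernelBuild(kernels)) {
  MS_LOG(EXCEPTION) << "Kernel build failed, graph id: " << kernel_graph->graph_id();
}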

@@ -17,6 +17,8 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
+#include <vector>
 #include "backend/session/kernel_graph.h"
 namespace mindspore {
@@ -25,7 +27,7 @@ namespace ascend {
 /**
  * @brief kernel build for ascend.
  */
-bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr);
+bool KernelBuild(const std::vector<CNodePtr> &kernels);
 /**
  * @brief preprocess of kernel build for ascend, e.g. inserting clear_zero node for maxpool, bn.
  * Must DO these changes just before kernel build, and after all other optimizations on the AnfGraph
