change mixed precision of pynative

pull/8522/head
LianLiguang 4 years ago
parent 00b41244ac
commit bb6148661f

@@ -124,7 +124,7 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
SetKernelType(kernel_build_info->kernel_type());
SetFusionType(kernel_build_info->fusion_type());
SetProcessor(kernel_build_info->processor());
OpPattern(kernel_build_info->op_pattern());
SetOpPattern(kernel_build_info->op_pattern());
for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));

@@ -1396,6 +1396,10 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
cnode->set_abstract(op_run_info.abstract);
// get output dynamic shape info
AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info.is_dynamic_shape), cnode);
if (op_run_info.is_auto_mixed_precision) {
AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info.next_op_name), cnode);
AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info.next_input_index), cnode);
}
// set execution order
std::vector<CNodePtr> exe_order = {cnode};
graph->set_execution_order(exe_order);

@@ -51,6 +51,9 @@ struct OpRunInfo {
AbstractBasePtr abstract;
ValuePtr value = nullptr;
bool is_dynamic_shape = false;
bool is_auto_mixed_precision = false;
std::string next_op_name = "";
size_t next_input_index = 0;
};
using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class Executor;

@@ -60,6 +60,9 @@ struct OpExecInfo {
py::list op_inputs;
py::dict op_attrs;
std::vector<bool> inputs_mask;
std::string next_op_name = "";
bool is_mixed_precision_cast = false;
size_t next_input_index = 0;
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
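Taken together, the two hunks above add matching bookkeeping fields on the front-end OpExecInfo and the backend OpRunInfo. Below is a minimal sketch of the hand-off, using the field names from the diff; the exact copy site (somewhere on the RunOpInner path) is an assumption, not shown in this commit.

// Sketch only: hypothetical spot where the mixed-precision cast bookkeeping
// recorded on OpExecInfo is forwarded to OpRunInfo, so that
// ConstructSingleOpGraph can attach kAttrPynativeNextOpName / kAttrPynativeNextIndex.
OpRunInfo op_run_info;
op_run_info.is_auto_mixed_precision = op_exec_info->is_mixed_precision_cast;
op_run_info.next_op_name = op_exec_info->next_op_name;          // e.g. "Conv2D"
op_run_info.next_input_index = op_exec_info->next_input_index;  // which input of that op the cast feeds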

File diff suppressed because it is too large.

@@ -81,7 +81,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool grad_flag() { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
py::tuple RunOpInner(const py::args &args);
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
void NewGraph(const py::object &cell, const py::args &args);
py::object Run(const py::tuple &args, const py::object &phase);
py::object CheckGraph(const py::object &cell, const py::args &args);
@@ -108,6 +108,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
parse::AstMainType type);
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name, size_t index);
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple, const std::string &op_name,
size_t index);
py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
// run op
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
@@ -129,8 +135,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
// construct grad graph
void Pushp();
void Popp();
void PushCurrentGraphToStack();
void PopGraphStack();
void NewGraphInner(const py::object &cell, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
@@ -148,19 +154,17 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
// hold graph(forward and grad) info
void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
void set_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, bool is_param = false);
void set_node_map(const FuncGraphPtr &g, const std::string &obj, AnfNodePtr node) {
graph_info_map_[g].node_map[obj] = std::make_pair(node, std::vector<int64_t>{-1});
void SetPyObjInGraphInfoMap(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
bool is_param = false);
void SetNodeMapInGraphInfoMap(FuncGraphPtr g, const std::string id, AnfNodePtr node, int64_t index = -1) {
graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
}
void set_node_map(const FuncGraphPtr &g, const std::string &obj, AnfNodePtr node, int index) {
graph_info_map_[g].node_map[obj] = std::make_pair(node, std::vector<int64_t>{index});
void SetNodeMapInGraphInfoMap(FuncGraphPtr g, const std::string id, AnfNodePtr node, std::vector<int64_t> index) {
graph_info_map_[g].node_map[id] = std::make_pair(node, index);
}
void set_node_map(const FuncGraphPtr &g, const std::string &obj, AnfNodePtr node, std::vector<int64_t> index) {
graph_info_map_[g].node_map[obj] = std::make_pair(node, index);
}
void set_tuple_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode,
const std::vector<int64_t> &idx, bool is_param = false);
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
const std::vector<int64_t> &index_sequence, bool is_param = false);
static std::shared_ptr<PynativeExecutor> executor_;
static std::mutex instance_lock_;
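For orientation, here is a hedged sketch of the per-graph node map the renamed helpers above fill: each object id maps to an ANF node plus an index path into that node's output, with {-1} standing for the whole output. The exact map type is an assumption reconstructed from the make_pair calls in this header.

// Sketch only: assumed shape of graph_info_map_[g].node_map.
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
node_map["id_of_out"] = std::make_pair(cnode, std::vector<int64_t>{-1});        // whole output of cnode
node_map["id_of_out_elem"] = std::make_pair(cnode, std::vector<int64_t>{0, 1});  // out[0][1] of a nested tuple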
@@ -176,7 +180,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr df_builder_{nullptr};
ResourcePtr resource_{nullptr};
// Records forward graphs; the bottom is the top graph
std::stack<FuncGraphPtr> graph_context_;
std::stack<FuncGraphPtr> graph_stack_;
std::unordered_set<std::string> top_graph_cells_;
// record all info of a graph
@@ -195,6 +199,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::unordered_map<std::string, std::string> obj_to_forward_id_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
const inline static std::string kOpsFunctionModelName = "mindspore.ops.functional";
const inline static std::string kMSDtypeModelName = "mindspore.common.dtype";
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
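The two module-name constants added above suggest the executor resolves the cast primitive and the target dtype from Python at run time. A hedged pybind11 sketch of how DoAutoCast might use them follows; whether `cast` and `float16` are the exact attributes looked up is an assumption.

// Sketch only (assumed usage of the new constants inside DoAutoCast, where they are accessible):
py::object functional = py::module::import(kOpsFunctionModelName.c_str());  // mindspore.ops.functional
py::object dtype_mod = py::module::import(kMSDtypeModelName.c_str());       // mindspore.common.dtype
py::object cast_fn = functional.attr("cast");   // assumed attribute name
py::object fp16 = dtype_mod.attr("float16");    // assumed target dtype
py::object casted_arg = cast_fn(arg, fp16);     // arg: the input that needs casting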

@@ -50,8 +50,10 @@ enum MatchCountPriority : int {
MATCH_OUTPUT_DTYPE_COUNT,
MATCH_COUNT_PRIORITY_END
};
const int kUnSupportMixedDataTypeIndex = -1;
const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
{prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}},
{prim::kPrimFusedBatchNorm->name(),
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}}};
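kNextOpFormatList records, per consumer op, the device format expected at each input position. Below is a standalone sketch of the lookup SetCastAndWeightFormat performs further down; the literal format strings stand in for kOpFormat_NC1HWC0 / kOpFormat_FRAC_Z and their exact values are an assumption.

// Standalone sketch: pick the format the next op expects at the input the Cast feeds.
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  const std::map<std::string, std::vector<std::string>> next_op_format_list = {
      {"Conv2D", {"NC1HWC0", "FracZ"}},  // data input, weight input
      {"FusedBatchNorm", {"NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0"}}};
  const std::string next_op_name = "Conv2D";
  const size_t next_index = 1;  // the Cast feeds Conv2D's weight input
  auto iter = next_op_format_list.find(next_op_name);
  if (iter != next_op_format_list.end() && next_index < iter->second.size()) {
    std::cout << "Cast should produce format " << iter->second[next_index] << std::endl;
  }
  return 0;
}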
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(cnode);
@@ -313,10 +315,41 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
}
return filtered_kernel_info_list;
}
} // namespace
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
if (!AnfAlgo::HasNodeAttr(kAttrPynativeNextIndex, kernel_node) ||
!AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString() << "] attr of " << kAttrPynativeNextIndex << " or "
<< kAttrPynativeNextOpName << " has been not setted yet!";
}
auto next_index = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPynativeNextIndex);
auto next_op_name = AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrPynativeNextOpName);
auto iter = kNextOpFormatList.find(next_op_name);
if (iter == kNextOpFormatList.end()) {
MS_LOG(WARNING) << "The op name " << next_op_name << "has been not setted in the next op map ";
return;
}
if (iter->second.size() <= next_index) {
MS_LOG(EXCEPTION) << "Next input index " << next_index << "is out of range in the next op map max size is "
<< iter->second.size();
}
if (AnfAlgo::GetCNodeName(kernel_node) != prim::kPrimCast->name()) {
MS_LOG(INFO) << "Only supported to change the node Cast's build info!!!";
return;
}
auto format = iter->second[next_index];
auto info_builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(kernel_node));
info_builder->SetInputsFormat({format});
info_builder->SetOutputsFormat({format});
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), kernel_node.get());
}
} // namespace
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
MS_EXCEPTION_IF_NULL(selected_kernel_info);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
@@ -329,7 +362,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
if (selected_kernel_info->GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
continue;
}
// We set special device info of an input tensor.
@@ -344,17 +377,17 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
@@ -388,7 +421,10 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
// Set kernel info to the anfnode
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
if (AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
SetCastAndWeightFormat(kernel_node);
}
SetTensorDeviceInfo(kernel_node);
return select_status;
}
@@ -428,7 +464,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern
auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
SetTensorDeviceInfo(kernel_node);
} else {
MS_LOG(WARNING) << " <<<";
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()

@@ -29,7 +29,7 @@ enum KernelSelectStatus {
};
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
void SetTensorDeviceInfo(const CNodePtr &kernel_node);
void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
} // namespace ascend
} // namespace device

@@ -95,7 +95,7 @@ void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
auto selected_kernel_info_ptr = kernel_info_list[index];
ResetKernelBuildInfo(cnode);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode);
SetTensorDeviceInfo(cnode);
break;
}
}
@@ -477,7 +477,7 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair
auto graph_selected_info = graph_info_builder.Build();
MS_EXCEPTION_IF_NULL(graph_selected_info);
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
SetTensorDeviceInfo(*graph_selected_info, kernel_node);
SetTensorDeviceInfo(kernel_node);
}
void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {

@@ -327,6 +327,8 @@ constexpr auto kAttrSize = "size";
constexpr auto kAttrIsDynamicShape = "is_dynamic_shape";
constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape";
constexpr auto kAttrOutputIsDynamicShape = "output_is_dynamic_shape";
constexpr auto kAttrPynativeNextOpName = "next_op";
constexpr auto kAttrPynativeNextIndex = "next_index";
constexpr auto kAttrCompileInfo = "compile_info";
constexpr auto kAttrFusionType = "fusion_type";
