Using Cache and Code Optimization to Improve Performance of AMP

pull/13033/head
chenyijie6 4 years ago
parent f733d8a746
commit 03fad862dd
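
In PyNative mode, automatic mixed precision (AMP) runs a Cast in front of every mixed-precision parameter, and the parameter mask of every input used to be recomputed for every executed operator. This commit adds two per-object caches to ForwardExecutor, cast_struct_map_ for the cast OpExecInfo and op_mask_map_ for the op mask, both keyed by the input object's id; it also reuses the iterators returned by find() instead of performing a second map lookup, and builds the single-op graph cache key without creating temporary strings.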

@@ -77,6 +77,7 @@ PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
ForwardExecutorPtr PynativeExecutor::forward_executor_ = nullptr;
GradExecutorPtr PynativeExecutor::grad_executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
constexpr auto implcast = "implcast";
template <typename T, typename... Args>
void PynativeExecutorTry(std::function<void(T *ret, const Args &...)> method, T *ret, const Args &... args) {
@@ -272,33 +273,42 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
for (size_t index = 0; index < input_tensors.size(); ++index) {
MS_EXCEPTION_IF_NULL(input_tensors[index]);
auto tensor_shape = input_tensors[index]->shape();
(void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
[&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
(void)graph_info.append(std::to_string(input_tensors[index]->data_type()) + "_");
(void)std::for_each(tensor_shape.begin(), tensor_shape.end(), [&](const auto &dim) {
(void)graph_info.append(std::to_string(dim));
graph_info += "_";
});
(void)graph_info.append(std::to_string(input_tensors[index]->data_type()));
graph_info += "_";
auto tensor_addr = input_tensors[index]->device_address();
if (tensor_addr != nullptr) {
(void)graph_info.append(std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->type_id()) +
"_");
(void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->format() + "_");
(void)graph_info.append(std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->type_id()));
graph_info += "_";
(void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->format());
graph_info += "_";
}
if (static_cast<int64_t>(op_exec_info->inputs_mask[index]) == kValueNodeTensorMask) {
if (input_tensors[index]->Dtype()->type_id() == kNumberTypeInt64) {
(void)graph_info.append(std::to_string(*reinterpret_cast<int *>(input_tensors[index]->data_c())) + "_");
(void)graph_info.append(std::to_string(*reinterpret_cast<int *>(input_tensors[index]->data_c())));
graph_info += "_";
} else if (input_tensors[index]->Dtype()->type_id() == kNumberTypeFloat32) {
(void)graph_info.append(std::to_string(*reinterpret_cast<float *>(input_tensors[index]->data_c())) + "_");
(void)graph_info.append(std::to_string(*reinterpret_cast<float *>(input_tensors[index]->data_c())));
graph_info += "_";
} else {
MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64 or float32!";
}
}
}
// get prim and abstract info
(void)graph_info.append(op_exec_info->op_name + "_");
graph_info += (op_exec_info->op_name);
graph_info += "_";
// get attr info
const auto &op_prim = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(op_prim);
const auto &attr_map = op_prim->attrs();
(void)std::for_each(attr_map.begin(), attr_map.end(),
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
(void)std::for_each(attr_map.begin(), attr_map.end(), [&](const auto &element) {
graph_info += (element.second->ToString());
graph_info += "_";
});
// Add output information(shape, type id) of the operator to graph_info to solve the problem of cache missing
// caused by operators like DropoutGenMask whose output is related to values of input when input shapes are
@@ -307,10 +317,12 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL(abstr);
auto build_shape = abstr->BuildShape();
MS_EXCEPTION_IF_NULL(build_shape);
(void)graph_info.append(build_shape->ToString() + "_");
graph_info += (build_shape->ToString());
graph_info += "_";
auto build_type = abstr->BuildType();
MS_EXCEPTION_IF_NULL(build_type);
(void)graph_info.append(std::to_string(build_type->type_id()) + "_");
graph_info += std::to_string(build_type->type_id());
graph_info += "_";
return graph_info;
}
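
The key construction rewritten above drops the append(std::to_string(x) + "_") pattern, which materializes a temporary std::string for every shape dimension, dtype, attribute and output, and instead appends the value and the "_" separator directly into graph_info. A minimal standalone sketch of the two variants (function names are illustrative, not from the source):

#include <cstdint>
#include <string>
#include <vector>

// Old style: builds a temporary std::string ("<value>_") and then copies it
// into the key on every call.
void AppendWithTemporary(std::string *key, int64_t value) {
  key->append(std::to_string(value) + "_");
}

// New style: the value and the separator go straight into the key's buffer,
// so no intermediate string is allocated.
void AppendInPlace(std::string *key, int64_t value) {
  key->append(std::to_string(value));
  *key += "_";
}

// Example: building the shape part of a graph cache key.
std::string BuildShapeKey(const std::vector<int64_t> &shape) {
  std::string key;
  for (int64_t dim : shape) {
    AppendInPlace(&key, dim);  // produces the same key as AppendWithTemporary
  }
  return key;
}
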
@@ -681,6 +693,26 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
return op_exec_info;
}
bool ForwardExecutor::FindOpMask(py::object obj, std::vector<int64_t> *op_masks, std::string id) {
bool op_mask = false;
auto temp = op_mask_map_.find(id);
if (temp != op_mask_map_.end()) {
op_mask = temp->second;
(*op_masks).emplace_back(op_mask);
} else {
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor) {
op_mask = meta_tensor->is_parameter();
}
}
MS_LOG(DEBUG) << "Gen args op_mask " << op_mask;
op_mask_map_[id] = op_mask;
(*op_masks).emplace_back(op_mask);
}
return op_mask;
}
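
FindOpMask looks the input object's id up in op_mask_map_ and only falls back to the MetaTensor cast and is_parameter() call on a miss, storing the result so later operators that reuse the same object skip the computation. A stripped-down sketch of that look-up-or-compute idiom, using a plain bool cache and a hypothetical ComputeMask() in place of the executor types:

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for "is this object a Parameter?"; in the real code
// this is the MetaTensor cast plus is_parameter().
bool ComputeMask(const std::string &id) { return !id.empty() && id[0] == 'p'; }

bool FindOpMaskCached(std::unordered_map<std::string, bool> *cache,
                      std::vector<int64_t> *op_masks, const std::string &id) {
  bool op_mask = false;
  auto it = cache->find(id);
  if (it != cache->end()) {
    op_mask = it->second;       // cache hit: no recomputation
  } else {
    op_mask = ComputeMask(id);  // cache miss: compute once and remember
    (*cache)[id] = op_mask;
  }
  op_masks->emplace_back(op_mask);
  return op_mask;
}
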
void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
auto prim = op_exec_info->py_primitive;
@@ -692,15 +724,8 @@ void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector
if (it != node_abs_map_.end()) {
abs = it->second;
}
bool op_mask = false;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor) {
op_mask = meta_tensor->is_parameter();
}
}
MS_LOG(DEBUG) << "Gen args i " << i << " op_mask " << op_mask;
(*op_masks).emplace_back(op_mask);
// Find the op_mask of the input obj
bool op_mask = FindOpMask(obj, op_masks, id);
// Construct grad graph
if (grad()->need_construct_graph()) {
@@ -794,16 +819,19 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
auto op_name = op_exec_info->op_name;
auto prim = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(prim);
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
auto abs_list = prim_abs_list_[prim->id()];
auto temp = prim_abs_list_.find(prim->id());
if (temp != prim_abs_list_.end()) {
MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
if (abs_list.find(args_spec_list) != abs_list.end()) {
auto iter = temp->second.find(args_spec_list);
if (iter != temp->second.end()) {
MS_LOG(DEBUG) << "Match prim ok " << op_name;
op_exec_info->abstract = abs_list[args_spec_list].abs;
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
op_exec_info->abstract = iter->second.abs;
prim->set_evaluate_added_attrs(iter->second.attrs);
*is_find = true;
}
}
if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
// use python infer method
if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
@@ -822,7 +850,7 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
}
py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
size_t index) {
size_t index, const std::string &obj_id) {
py::tuple cast_args(3);
cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
cast_args[PY_NAME] = prim::kPrimCast->name();
@@ -836,6 +864,10 @@ py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type
op_exec->is_mixed_precision_cast = true;
op_exec->next_op_name = op_name;
op_exec->next_input_index = index;
// Cache the cast struct
if (obj_id != implcast) {
cast_struct_map_[obj_id] = op_exec;
}
py::object ret = py::none();
RunOpInner(&ret, op_exec);
return ret;
@@ -852,7 +884,20 @@ py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::obj
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
*is_cast = true;
return DoAutoCast(obj, cast_type->type_id(), op_name, index);
// Get obj id
auto id = GetId(obj);
// Find obj id in the unordered map
auto cast_struct_pair = cast_struct_map_.find(id);
if (cast_struct_pair != cast_struct_map_.end()) {
// Update input for cast struct
auto cast_struct = cast_struct_pair->second;
cast_struct->op_inputs[0] = obj;
py::object ret = py::none();
RunOpInner(&ret, cast_struct);
return ret;
} else {
return DoAutoCast(obj, cast_type->type_id(), op_name, index, id);
}
}
}
return cast_output;
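
With the cast struct cached, DoParamMixPrecisionCast only has to refresh op_inputs[0] and rerun the op when the same parameter is cast again; the full OpExecInfo for the Cast is built once by DoAutoCast and stored in cast_struct_map_ under the object's id. Roughly the shape of that reuse, with simplified stand-in types (CastInfo, BuildCastInfo and RunCast are illustrative, not the real executor API):

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct CastInfo {                      // stand-in for OpExecInfo
  std::vector<std::string> op_inputs;  // op_inputs[0] is the value to cast
  int dst_type_id = 0;
};
using CastInfoPtr = std::shared_ptr<CastInfo>;

CastInfoPtr BuildCastInfo(const std::string &obj, int type_id) {
  auto info = std::make_shared<CastInfo>();
  info->op_inputs = {obj, "dst_type"};
  info->dst_type_id = type_id;
  return info;
}

void RunCast(const CastInfoPtr &info) { (void)info; /* dispatch the Cast operator here */ }

void CastWithCache(std::unordered_map<std::string, CastInfoPtr> *cache,
                   const std::string &obj, const std::string &obj_id, int type_id) {
  auto it = cache->find(obj_id);
  if (it != cache->end()) {
    it->second->op_inputs[0] = obj;         // hit: only the input value changes
    RunCast(it->second);
    return;
  }
  auto info = BuildCastInfo(obj, type_id);  // miss: build once and remember
  (*cache)[obj_id] = info;
  RunCast(info);
}
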
@@ -933,7 +978,7 @@ void ForwardExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map
<< py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
<< py::cast<py::str>(obj) << ".";
}
py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i);
py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i, implcast);
out_args[i] = cast_output;
}
}
@@ -1474,6 +1519,8 @@ void ForwardExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear forward res";
prim_abs_list_.clear();
node_abs_map_.clear();
cast_struct_map_.clear();
op_mask_map_.clear();
cell_op_index_with_tensor_id_.clear();
cell_tensor_id_with_tensor_.clear();
}

@@ -406,6 +406,7 @@ class ForwardExecutor {
PynativeStatusCode *status);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
bool FindOpMask(py::object obj, std::vector<int64_t> *op_masks, std::string id);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
@@ -420,7 +421,8 @@ class ForwardExecutor {
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index);
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, const std::string &op_name,
size_t index);
py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index,
const std::string &obj_id);
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
@@ -431,6 +433,10 @@ class ForwardExecutor {
// Used for runop and replace forward result of grad graph
std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_;
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
// Used to cache cast struct
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
// Used to cache op_mask
std::unordered_map<std::string, int64_t> op_mask_map_;
};
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {

@@ -184,9 +184,8 @@ AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op
if (op_attr_map_it == PrimAttrConvertMap.end()) {
return attr_pair;
}
auto op_attr_map = op_attr_map_it->second;
auto attr_pair_it = op_attr_map.find(attr_name);
if (attr_pair_it == op_attr_map.end()) {
auto attr_pair_it = op_attr_map_it->second.find(attr_name);
if (attr_pair_it == op_attr_map_it->second.end()) {
return attr_pair;
}
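
This change and the GetOpOutputAbstract change above apply the same fix: keep the iterator returned by find() and read through it, rather than copying the mapped value (auto op_attr_map = op_attr_map_it->second; copied a whole inner map) or indexing the outer map again with operator[]. A compact illustration of the difference, with illustrative names:

#include <map>
#include <string>

int LookupTwice(const std::map<std::string, std::map<std::string, int>> &table,
                const std::string &op, const std::string &attr) {
  if (table.find(op) == table.end()) {
    return 0;
  }
  auto inner = table.at(op);       // second lookup, and copies the whole inner map
  auto it = inner.find(attr);
  return it != inner.end() ? it->second : 0;
}

int LookupOnce(const std::map<std::string, std::map<std::string, int>> &table,
               const std::string &op, const std::string &attr) {
  auto outer_it = table.find(op);  // single lookup, no copy
  if (outer_it == table.end()) {
    return 0;
  }
  auto inner_it = outer_it->second.find(attr);
  return inner_it != outer_it->second.end() ? inner_it->second : 0;
}
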
