|
|
|
@ -256,41 +256,84 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
|
|
|
|
|
return RunOp(args)[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) {
|
|
|
|
|
auto &out_args = op_exec_info->op_inputs;
|
|
|
|
|
auto signature = prim->signatures();
|
|
|
|
|
std::vector<SignatureEnumDType> dtypes;
|
|
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
|
|
|
|
[](const Signature &sig) { return sig.dtype; });
|
|
|
|
|
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
|
|
|
|
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
|
|
|
|
return;
|
|
|
|
|
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
|
|
|
|
|
auto tensor = py::cast<tensor::TensorPtr>(obj);
|
|
|
|
|
auto cast_type = tensor->cast_dtype();
|
|
|
|
|
py::object cast_output;
|
|
|
|
|
if (cast_type != nullptr) {
|
|
|
|
|
auto source_element = tensor->Dtype();
|
|
|
|
|
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
|
|
|
|
|
MS_LOG(DEBUG) << "cast to " << cast_type->ToString();
|
|
|
|
|
cast_output = DoAutoCast(obj, cast_type->type_id());
|
|
|
|
|
*is_cast = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return cast_output;
|
|
|
|
|
}
|
|
|
|
|
auto type_indexes = GetTypeIndex(dtypes);
|
|
|
|
|
auto dst_type = GetDstType(out_args, type_indexes);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < dtypes.size(); ++i) {
|
|
|
|
|
if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
|
|
|
|
|
continue;
|
|
|
|
|
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
|
|
|
|
|
auto tuple_size = static_cast<int>(tuple.size());
|
|
|
|
|
py::tuple result(tuple_size);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < tuple_size; i++) {
|
|
|
|
|
if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
|
|
|
|
|
MS_LOG(DEBUG) << "call cast for item " << i;
|
|
|
|
|
result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
|
|
|
|
|
} else if (py::isinstance<py::tuple>(tuple[i])) {
|
|
|
|
|
result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
|
|
|
|
|
}
|
|
|
|
|
auto it = dst_type.find(dtypes[i]);
|
|
|
|
|
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
|
|
|
|
|
auto signature = prim->signatures();
|
|
|
|
|
bool has_sig_dtype = false;
|
|
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
|
|
|
|
|
[&has_sig_dtype](const Signature &sig) {
|
|
|
|
|
auto dtype = sig.dtype;
|
|
|
|
|
if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
|
|
|
|
|
has_sig_dtype = true;
|
|
|
|
|
}
|
|
|
|
|
return dtype;
|
|
|
|
|
});
|
|
|
|
|
return has_sig_dtype;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
|
|
|
|
|
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) {
|
|
|
|
|
const auto &signature = prim->signatures();
|
|
|
|
|
auto &out_args = op_exec_info->op_inputs;
|
|
|
|
|
bool has_dtype_sig = (dtypes.size() > 0);
|
|
|
|
|
for (size_t i = 0; i < out_args.size(); ++i) {
|
|
|
|
|
MS_LOG(DEBUG) << "check inputs " << i;
|
|
|
|
|
auto obj = out_args[i];
|
|
|
|
|
auto sig = signature[i].rw;
|
|
|
|
|
auto sig = SignatureEnumRW::kRWDefault;
|
|
|
|
|
if (signature.size() > 0) {
|
|
|
|
|
sig = signature[i].rw;
|
|
|
|
|
}
|
|
|
|
|
bool is_parameter = false;
|
|
|
|
|
bool is_same_type = false;
|
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
|
bool is_sig_write = (sig == SignatureEnumRW::kRWWrite);
|
|
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
|
|
|
|
auto arg = py::cast<tensor::MetaTensorPtr>(obj);
|
|
|
|
|
if (arg->is_parameter()) {
|
|
|
|
|
is_parameter = true;
|
|
|
|
|
MS_LOG(DEBUG) << "parameter is read " << i;
|
|
|
|
|
}
|
|
|
|
|
arg_type_id = arg->data_type();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// No need to implicit cast if no dtype.
|
|
|
|
|
if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto it = dst_type.find(dtypes[i]);
|
|
|
|
|
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// implicit cast
|
|
|
|
|
bool is_same_type = false;
|
|
|
|
|
bool is_sig_write = (sig == SignatureEnumRW::kRWWrite);
|
|
|
|
|
if (arg_type_id != 0) {
|
|
|
|
|
is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
|
|
|
|
|
}
|
|
|
|
@ -317,7 +360,6 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe
|
|
|
|
|
}
|
|
|
|
|
py::object cast_output = DoAutoCast(out_args[i], it->second);
|
|
|
|
|
out_args[i] = cast_output;
|
|
|
|
|
ValuePtr input_value = PyAttrValue(cast_output);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -346,7 +388,6 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
|
|
|
|
op_exec_info->py_primitive = prim;
|
|
|
|
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
|
|
|
|
op_exec_info->op_inputs = args[PY_INPUTS];
|
|
|
|
|
ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
|
|
|
|
|
return op_exec_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -697,11 +738,53 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|
|
|
|
inputs.push_back(NewValueNode(prim));
|
|
|
|
|
|
|
|
|
|
size_t size = op_exec_info->op_inputs.size();
|
|
|
|
|
auto const_input_index = prim->get_const_input_indexes();
|
|
|
|
|
bool have_const_input = !const_input_index.empty();
|
|
|
|
|
bool is_const_prim = prim->is_const_prim();
|
|
|
|
|
// ignore signature for cast op
|
|
|
|
|
bool is_cast_op = (op_exec_info->op_name == "Cast");
|
|
|
|
|
if (!is_cast_op) {
|
|
|
|
|
const auto &signature = prim->signatures();
|
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
|
auto obj = op_exec_info->op_inputs[i];
|
|
|
|
|
auto sig = SignatureEnumRW::kRWDefault;
|
|
|
|
|
if (signature.size() > 0) {
|
|
|
|
|
sig = signature[i].rw;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
|
|
|
|
|
<< std::string(py::repr(obj));
|
|
|
|
|
// mix precision for non param
|
|
|
|
|
bool is_cast = false;
|
|
|
|
|
py::object cast_output;
|
|
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
|
|
|
|
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
|
|
|
|
if (meta_tensor && meta_tensor->is_parameter()) {
|
|
|
|
|
if (sig != SignatureEnumRW::kRWRead) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// redundant cast call if the tensor is a const Tensor.
|
|
|
|
|
cast_output = DoParamMixPrecisionCast(&is_cast, obj);
|
|
|
|
|
} else if (py::isinstance<py::tuple>(obj)) {
|
|
|
|
|
// mix precision for tuple inputs
|
|
|
|
|
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
|
|
|
|
|
}
|
|
|
|
|
if (is_cast) {
|
|
|
|
|
op_exec_info->op_inputs[i] = cast_output;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<SignatureEnumDType> dtypes;
|
|
|
|
|
|
|
|
|
|
bool has_dtype_sig = GetSignatureType(prim, &dtypes);
|
|
|
|
|
std::map<SignatureEnumDType, TypeId> dst_types;
|
|
|
|
|
if (has_dtype_sig) {
|
|
|
|
|
// fetch info for implicit cast
|
|
|
|
|
auto type_indexes = GetTypeIndex(dtypes);
|
|
|
|
|
dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
|
|
|
|
|
DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name;
|
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
|
const auto &obj = op_exec_info->op_inputs[i];
|
|
|
|
|
bool op_mask = false;
|
|
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
|
|
|
|
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
|
|
|
@ -709,9 +792,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|
|
|
|
op_mask = meta_tensor->is_parameter();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
(*op_masks).push_back(op_mask);
|
|
|
|
|
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
|
|
|
|
|
MS_LOG(DEBUG) << "gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ "
|
|
|
|
|
<< grad_flag_;
|
|
|
|
|
|
|
|
|
|
AnfNodePtr node = nullptr;
|
|
|
|
@ -726,6 +808,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|
|
|
|
if (node != nullptr && node->abstract() != nullptr) {
|
|
|
|
|
abs = node->abstract();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto const_input_index = prim->get_const_input_indexes();
|
|
|
|
|
bool have_const_input = !const_input_index.empty();
|
|
|
|
|
bool is_const_prim = prim->is_const_prim();
|
|
|
|
|
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
|
|
|
|
|
<< prim->is_const_prim();
|
|
|
|
|
bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i);
|
|
|
|
|