|
|
@ -189,7 +189,7 @@ bool TbeKernelJsonCreator::GenInputList(const std::shared_ptr<AnfNode> &anf_node
|
|
|
|
input_list->emplace_back(input_desc_json);
|
|
|
|
input_list->emplace_back(input_desc_json);
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
MS_LOG(ERROR) << "input num: " << *real_input_index << " is not match op inputs";
|
|
|
|
MS_LOG(ERROR) << "Input num: " << *real_input_index << " is not match op inputs";
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (op_name == "BatchNorm") {
|
|
|
|
if (op_name == "BatchNorm") {
|
|
|
@ -197,7 +197,7 @@ bool TbeKernelJsonCreator::GenInputList(const std::shared_ptr<AnfNode> &anf_node
|
|
|
|
auto attr = primitive->GetAttr("is_training");
|
|
|
|
auto attr = primitive->GetAttr("is_training");
|
|
|
|
MS_EXCEPTION_IF_NULL(attr);
|
|
|
|
MS_EXCEPTION_IF_NULL(attr);
|
|
|
|
bool is_training = GetValue<bool>(attr);
|
|
|
|
bool is_training = GetValue<bool>(attr);
|
|
|
|
MS_LOG(INFO) << "op_name" << op_name << ", tensor_name " << input_ptr->name() << ", is_training "
|
|
|
|
MS_LOG(INFO) << "Op_name" << op_name << ", tensor_name " << input_ptr->name() << ", is_training "
|
|
|
|
<< is_training;
|
|
|
|
<< is_training;
|
|
|
|
if (is_training) {
|
|
|
|
if (is_training) {
|
|
|
|
(*real_input_index)++;
|
|
|
|
(*real_input_index)++;
|
|
|
@ -230,7 +230,7 @@ bool GetInputNameAndRealNum(const std::shared_ptr<AnfNode> &anf_node, const std:
|
|
|
|
|
|
|
|
|
|
|
|
if (input_ptr->param_type() == kParamDynamic) {
|
|
|
|
if (input_ptr->param_type() == kParamDynamic) {
|
|
|
|
if (*dyn_input_index >= dyn_input_sizes.size()) {
|
|
|
|
if (*dyn_input_index >= dyn_input_sizes.size()) {
|
|
|
|
MS_LOG(ERROR) << "dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size();
|
|
|
|
MS_LOG(ERROR) << "Dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size();
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
*input_num = IntToSize(dyn_input_sizes[*dyn_input_index]);
|
|
|
|
*input_num = IntToSize(dyn_input_sizes[*dyn_input_index]);
|
|
|
@ -314,7 +314,7 @@ bool TbeKernelJsonCreator::GenOutputDescJson(
|
|
|
|
output_obj_num = real_output_num;
|
|
|
|
output_obj_num = real_output_num;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
if (output_idx >= real_output_num) {
|
|
|
|
if (output_idx >= real_output_num) {
|
|
|
|
MS_LOG(INFO) << "op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none.";
|
|
|
|
MS_LOG(INFO) << "Op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none.";
|
|
|
|
std::vector<nlohmann::json> output_list;
|
|
|
|
std::vector<nlohmann::json> output_list;
|
|
|
|
nlohmann::json output_obj;
|
|
|
|
nlohmann::json output_obj;
|
|
|
|
output_obj[kJName] = output_ptr->name();
|
|
|
|
output_obj[kJName] = output_ptr->name();
|
|
|
@ -389,7 +389,7 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_no
|
|
|
|
attr_obj[kJValid] = false;
|
|
|
|
attr_obj[kJValid] = false;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
if (attr_ptr->param_type() == kParamRequred && creater_type_ == SINGLE_BUILD) {
|
|
|
|
if (attr_ptr->param_type() == kParamRequred && creater_type_ == SINGLE_BUILD) {
|
|
|
|
MS_LOG(EXCEPTION) << "op name: " << op_info->op_name() << " attr: " << attr_name
|
|
|
|
MS_LOG(EXCEPTION) << "Op name: " << op_info->op_name() << " attr: " << attr_name
|
|
|
|
<< " is required, but not set.";
|
|
|
|
<< " is required, but not set.";
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
attr_obj[kJValid] = false;
|
|
|
|
attr_obj[kJValid] = false;
|
|
|
@ -451,7 +451,7 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo
|
|
|
|
auto attr_value = GetValue<std::vector<std::vector<int>>>(value);
|
|
|
|
auto attr_value = GetValue<std::vector<std::vector<int>>>(value);
|
|
|
|
(*attr_obj)[kJValue] = attr_value;
|
|
|
|
(*attr_obj)[kJValue] = attr_value;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
MS_LOG(EXCEPTION) << "type: " << type << "not support";
|
|
|
|
MS_LOG(EXCEPTION) << "Type: " << type << "not support";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -536,7 +536,7 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no
|
|
|
|
bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
|
|
|
|
bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
|
|
|
|
std::vector<size_t> *output_size_list) {
|
|
|
|
std::vector<size_t> *output_size_list) {
|
|
|
|
if (input_size_list == nullptr || output_size_list == nullptr) {
|
|
|
|
if (input_size_list == nullptr || output_size_list == nullptr) {
|
|
|
|
MS_LOG(ERROR) << "input size or output size is nullptr";
|
|
|
|
MS_LOG(ERROR) << "Input size or output size is nullptr";
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
input_size_list->clear();
|
|
|
|
input_size_list->clear();
|
|
|
@ -750,7 +750,7 @@ bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr<mindspore::Anf
|
|
|
|
MS_EXCEPTION_IF_NULL(index);
|
|
|
|
MS_EXCEPTION_IF_NULL(index);
|
|
|
|
std::vector<nlohmann::json> output_desc_list;
|
|
|
|
std::vector<nlohmann::json> output_desc_list;
|
|
|
|
if (!data_input) {
|
|
|
|
if (!data_input) {
|
|
|
|
MS_LOG(INFO) << "data input is optional node";
|
|
|
|
MS_LOG(INFO) << "Data input is optional node";
|
|
|
|
auto name = std::string(kOptional) + std::to_string(*index);
|
|
|
|
auto name = std::string(kOptional) + std::to_string(*index);
|
|
|
|
(*data_str)[kJName] = name;
|
|
|
|
(*data_str)[kJName] = name;
|
|
|
|
nlohmann::json output_desc;
|
|
|
|
nlohmann::json output_desc;
|
|
|
@ -766,7 +766,7 @@ bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr<mindspore::Anf
|
|
|
|
auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0);
|
|
|
|
auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0);
|
|
|
|
auto real_node = kernel_idx.first;
|
|
|
|
auto real_node = kernel_idx.first;
|
|
|
|
size_t real_idx = kernel_idx.second;
|
|
|
|
size_t real_idx = kernel_idx.second;
|
|
|
|
MS_LOG(INFO) << "real name " << real_node->fullname_with_scope() << " index:" << real_idx;
|
|
|
|
MS_LOG(INFO) << "Real name " << real_node->fullname_with_scope() << " index:" << real_idx;
|
|
|
|
// kJOutputDesc
|
|
|
|
// kJOutputDesc
|
|
|
|
nlohmann::json output_desc;
|
|
|
|
nlohmann::json output_desc;
|
|
|
|
GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type);
|
|
|
|
GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type);
|
|
|
@ -842,18 +842,18 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
|
|
|
|
auto kernel_idx = AnfAlgo::VisitKernel(input, 0);
|
|
|
|
auto kernel_idx = AnfAlgo::VisitKernel(input, 0);
|
|
|
|
auto real_node = kernel_idx.first;
|
|
|
|
auto real_node = kernel_idx.first;
|
|
|
|
size_t real_idx = kernel_idx.second;
|
|
|
|
size_t real_idx = kernel_idx.second;
|
|
|
|
MS_LOG(INFO) << "real name" << real_node->fullname_with_scope() << "index:" << real_idx;
|
|
|
|
MS_LOG(INFO) << "Real name" << real_node->fullname_with_scope() << "index:" << real_idx;
|
|
|
|
nlohmann::json input_desc;
|
|
|
|
nlohmann::json input_desc;
|
|
|
|
GenDescJson(real_node, real_idx, real_idx, &input_desc);
|
|
|
|
GenDescJson(real_node, real_idx, real_idx, &input_desc);
|
|
|
|
if (is_dynamic_input) {
|
|
|
|
if (is_dynamic_input) {
|
|
|
|
MS_LOG(INFO) << "node has dynamic input.";
|
|
|
|
MS_LOG(INFO) << "Node has dynamic input.";
|
|
|
|
input_desc[kJDynIndex] = (i - 1);
|
|
|
|
input_desc[kJDynIndex] = (i - 1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
input_desc_list_tmp.emplace_back(input_desc);
|
|
|
|
input_desc_list_tmp.emplace_back(input_desc);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
size_t optional_num = GetOptionalInput(cnode, is_dynamic_input);
|
|
|
|
size_t optional_num = GetOptionalInput(cnode, is_dynamic_input);
|
|
|
|
if (optional_num > 0) {
|
|
|
|
if (optional_num > 0) {
|
|
|
|
MS_LOG(INFO) << "node has optional input.";
|
|
|
|
MS_LOG(INFO) << "Node has optional input.";
|
|
|
|
for (size_t i = 0; i < optional_num; ++i) {
|
|
|
|
for (size_t i = 0; i < optional_num; ++i) {
|
|
|
|
nlohmann::json optional_input_desc;
|
|
|
|
nlohmann::json optional_input_desc;
|
|
|
|
optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index);
|
|
|
|
optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index);
|
|
|
@ -871,7 +871,7 @@ std::vector<size_t> TbeKernelBuild::GetDescOutputIndex(const std::vector<int> &o
|
|
|
|
std::vector<size_t> desc_output_index = {};
|
|
|
|
std::vector<size_t> desc_output_index = {};
|
|
|
|
for (size_t idx = 0; idx < output_used_nums.size(); ++idx) {
|
|
|
|
for (size_t idx = 0; idx < output_used_nums.size(); ++idx) {
|
|
|
|
auto output_use_num_item = output_used_nums[idx];
|
|
|
|
auto output_use_num_item = output_used_nums[idx];
|
|
|
|
MS_LOG(INFO) << "output used num[" << idx << "] = " << output_use_num_item;
|
|
|
|
MS_LOG(INFO) << "Output used num[" << idx << "] = " << output_use_num_item;
|
|
|
|
desc_output_index.emplace_back(idx);
|
|
|
|
desc_output_index.emplace_back(idx);
|
|
|
|
if (output_use_num_item > 1) {
|
|
|
|
if (output_use_num_item > 1) {
|
|
|
|
desc_output_index.emplace_back(idx);
|
|
|
|
desc_output_index.emplace_back(idx);
|
|
|
@ -990,7 +990,7 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list,
|
|
|
|
auto op_output_desces = op[kJOutputDesc];
|
|
|
|
auto op_output_desces = op[kJOutputDesc];
|
|
|
|
if (output_node != real_node) {
|
|
|
|
if (output_node != real_node) {
|
|
|
|
// tuple_get item
|
|
|
|
// tuple_get item
|
|
|
|
MS_LOG(INFO) << "output is a tuple getitem node";
|
|
|
|
MS_LOG(INFO) << "Output is a tuple getitem node";
|
|
|
|
auto output_desc = op_output_desces[real_idx];
|
|
|
|
auto output_desc = op_output_desces[real_idx];
|
|
|
|
if (output_desc[kJShape].empty()) {
|
|
|
|
if (output_desc[kJShape].empty()) {
|
|
|
|
MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx;
|
|
|
|
MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx;
|
|
|
|