|
|
|
@ -31,6 +31,8 @@ namespace session {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
|
|
|
|
|
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
|
|
|
|
|
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
|
|
|
|
|
prim::kPrimAssignSub->name()};
|
|
|
|
|
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
|
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
@ -417,21 +419,41 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const {
|
|
|
|
|
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map "
|
|
|
|
|
"flag but got the node :"
|
|
|
|
|
<< cnode->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto input_node = AnfAlgo::GetInputNode(cnode, 0);
|
|
|
|
|
auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1);
|
|
|
|
|
if (AnfAlgo::IsFeatureMapOutput(input_node)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) {
|
|
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(input_node->kernel_info());
|
|
|
|
|
kernel_info->set_feature_map_flag(true);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
|
|
|
node->set_kernel_info(kernel_info);
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
|
|
|
|
|
ResetAssignInputFeaatureMapFlag(node->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> feature_map_input_indexs;
|
|
|
|
|
kernel_info->SetFeatureMapFlag(false);
|
|
|
|
|
kernel_info->set_feature_map_flag(false);
|
|
|
|
|
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
|
|
|
|
|
if (AnfAlgo::IsFeatureMapInput(node, index)) {
|
|
|
|
|
kernel_info->SetFeatureMapFlag(true);
|
|
|
|
|
kernel_info->set_feature_map_flag(true);
|
|
|
|
|
feature_map_input_indexs.push_back(index);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::GetInputTensorNum(node) == 0) {
|
|
|
|
|
kernel_info->SetFeatureMapFlag(true);
|
|
|
|
|
kernel_info->set_feature_map_flag(true);
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::IsRealKernel(node)) {
|
|
|
|
|
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
|
|
|
|
@ -446,7 +468,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|
|
|
|
std::vector<TypeId> types;
|
|
|
|
|
std::vector<std::string> formats = {kOpFormat_DEFAULT};
|
|
|
|
|
if (node->isa<ValueNode>()) {
|
|
|
|
|
kernel_info->SetFeatureMapFlag(false);
|
|
|
|
|
kernel_info->set_feature_map_flag(false);
|
|
|
|
|
types.emplace_back(kTypeUnknown);
|
|
|
|
|
auto value_node = node->cast<ValueNodePtr>();
|
|
|
|
|
SyncDeviceInfoToValueNode(value_node, &formats, &types);
|
|
|
|
@ -455,7 +477,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|
|
|
|
auto parameter = node->cast<ParameterPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
|
bool is_weight = AnfAlgo ::IsParameterWeight(parameter);
|
|
|
|
|
kernel_info->SetFeatureMapFlag(!is_weight);
|
|
|
|
|
kernel_info->set_feature_map_flag(!is_weight);
|
|
|
|
|
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
|
|
|
|
|
}
|
|
|
|
|
// set parameter initaial device data type
|
|
|
|
|