change feature map flag

pull/9087/head
LianLiguang 4 years ago
parent 1321483749
commit 687cd623f9

@ -187,7 +187,7 @@ class CNodeDecoder {
if ((node.first)->isa<Parameter>()) { if ((node.first)->isa<Parameter>()) {
auto parameter = (node.first)->cast<ParameterPtr>(); auto parameter = (node.first)->cast<ParameterPtr>();
bool is_weight = AnfAlgo::IsParameterWeight(parameter); bool is_weight = AnfAlgo::IsParameterWeight(parameter);
kernel_info->SetFeatureMapFlag(!is_weight); kernel_info->set_feature_map_flag(!is_weight);
if (!is_weight) { if (!is_weight) {
feature_map_input_indexs.push_back(index - 1); feature_map_input_indexs.push_back(index - 1);
} }
@ -200,7 +200,7 @@ class CNodeDecoder {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode_); AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode_);
} }
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true); kernel_info->set_feature_map_flag(true);
} }
if (AnfAlgo::IsRealCNodeKernel(cnode_)) { if (AnfAlgo::IsRealCNodeKernel(cnode_)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_); AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_);

@ -31,6 +31,8 @@ namespace session {
namespace { namespace {
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
prim::kPrimAssignSub->name()};
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) { std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -417,21 +419,41 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
} }
} }
void KernelGraph::ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const {
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map "
"flag but got the node :"
<< cnode->DebugString();
}
auto input_node = AnfAlgo::GetInputNode(cnode, 0);
auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1);
if (AnfAlgo::IsFeatureMapOutput(input_node)) {
return;
}
if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) {
auto kernel_info = static_cast<device::KernelInfo *>(input_node->kernel_info());
kernel_info->set_feature_map_flag(true);
}
}
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto kernel_info = std::make_shared<device::KernelInfo>(); auto kernel_info = std::make_shared<device::KernelInfo>();
node->set_kernel_info(kernel_info); node->set_kernel_info(kernel_info);
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
ResetAssignInputFeaatureMapFlag(node->cast<CNodePtr>());
}
std::vector<size_t> feature_map_input_indexs; std::vector<size_t> feature_map_input_indexs;
kernel_info->SetFeatureMapFlag(false); kernel_info->set_feature_map_flag(false);
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
if (AnfAlgo::IsFeatureMapInput(node, index)) { if (AnfAlgo::IsFeatureMapInput(node, index)) {
kernel_info->SetFeatureMapFlag(true); kernel_info->set_feature_map_flag(true);
feature_map_input_indexs.push_back(index); feature_map_input_indexs.push_back(index);
} }
} }
if (AnfAlgo::GetInputTensorNum(node) == 0) { if (AnfAlgo::GetInputTensorNum(node) == 0) {
kernel_info->SetFeatureMapFlag(true); kernel_info->set_feature_map_flag(true);
} }
if (AnfAlgo::IsRealKernel(node)) { if (AnfAlgo::IsRealKernel(node)) {
// if the node only has the primitive(such as getNext) or the node's input has a feature map input // if the node only has the primitive(such as getNext) or the node's input has a feature map input
@ -446,7 +468,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
std::vector<TypeId> types; std::vector<TypeId> types;
std::vector<std::string> formats = {kOpFormat_DEFAULT}; std::vector<std::string> formats = {kOpFormat_DEFAULT};
if (node->isa<ValueNode>()) { if (node->isa<ValueNode>()) {
kernel_info->SetFeatureMapFlag(false); kernel_info->set_feature_map_flag(false);
types.emplace_back(kTypeUnknown); types.emplace_back(kTypeUnknown);
auto value_node = node->cast<ValueNodePtr>(); auto value_node = node->cast<ValueNodePtr>();
SyncDeviceInfoToValueNode(value_node, &formats, &types); SyncDeviceInfoToValueNode(value_node, &formats, &types);
@ -455,7 +477,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
auto parameter = node->cast<ParameterPtr>(); auto parameter = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
bool is_weight = AnfAlgo ::IsParameterWeight(parameter); bool is_weight = AnfAlgo ::IsParameterWeight(parameter);
kernel_info->SetFeatureMapFlag(!is_weight); kernel_info->set_feature_map_flag(!is_weight);
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
} }
// set parameter initaial device data type // set parameter initaial device data type

@ -100,6 +100,7 @@ class KernelGraph : public FuncGraph {
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override; CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
CNodePtr NewCNode(const CNodePtr &cnode); CNodePtr NewCNode(const CNodePtr &cnode);
void ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const;
ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr); ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract); ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value); ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value);

@ -837,7 +837,6 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
new_value_node->set_abstract(value_node->abstract()); new_value_node->set_abstract(value_node->abstract());
// create new kernel_info of new value_node // create new kernel_info of new value_node
auto kernel_info = std::make_shared<device::KernelInfo>(); auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->SetFeatureMapFlag(false);
new_value_node->set_kernel_info(kernel_info); new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node // create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();

@ -48,7 +48,7 @@ class KernelInfo : public KernelInfoDevice {
void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
select_kernel_build_info_ = select_kernel_build_info; select_kernel_build_info_ = select_kernel_build_info;
} }
void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; } void set_feature_map_flag(bool flag) { is_feature_map_ = flag; }
const DeviceAddress *GetOutputAddr(size_t index) const; const DeviceAddress *GetOutputAddr(size_t index) const;
DeviceAddressPtr GetMutableOutputAddr(size_t index) const; DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
bool OutputAddrExist(size_t index) const; bool OutputAddrExist(size_t index) const;

Loading…
Cancel
Save