|
|
|
@ -9,143 +9,25 @@
|
|
|
|
|
*
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.l
|
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "graph/passes/assign_pass.h"
|
|
|
|
|
|
|
|
|
|
#include "framework/common/debug/ge_log.h"
|
|
|
|
|
#include "framework/common/debug/log.h"
|
|
|
|
|
#include "graph/utils/graph_utils.h"
|
|
|
|
|
#include "graph/debug/ge_attr_define.h"
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr uint32_t kValidInputNodeOutputNum = 1;
|
|
|
|
|
constexpr int32_t kAssignRefInputIndex = 0;
|
|
|
|
|
constexpr int32_t kAssignValueInputIndex = 1;
|
|
|
|
|
static const std::set<std::string> kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
|
|
|
|
|
ge::CONSTANT, ge::CONSTANTOP,
|
|
|
|
|
ge::VARIABLE, ge::VARIABLEV2 };
|
|
|
|
|
const uint32_t kValidInputNodeOutputNum = 1;
|
|
|
|
|
const int32_t kAssignRefInputIndex = 0;
|
|
|
|
|
const int32_t kAssignValueInputIndex = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC
|
|
|
|
|
Status AssignPass::Run(NodePtr &node) {
|
|
|
|
|
GELOGD("AssignPass running");
|
|
|
|
|
|
|
|
|
|
if (TransformAttr(node) != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "Transform assign_var_name attr failed, node=%s", node->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node->GetType() == ASSIGN) {
|
|
|
|
|
if (OptimizedAssignNode(node) != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "Optimize for assign_node %s failed", node->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GELOGD("AssignPass success");
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
/// @brief Optimize for assign_node
|
|
|
|
|
/// @param [in] assign_node
|
|
|
|
|
/// @return Status
|
|
|
|
|
///
|
|
|
|
|
Status AssignPass::OptimizedAssignNode(NodePtr &assign_node) {
|
|
|
|
|
const auto &ref_in_anchor = assign_node->GetInDataAnchor(kAssignRefInputIndex);
|
|
|
|
|
const auto &value_in_anchor = assign_node->GetInDataAnchor(kAssignValueInputIndex);
|
|
|
|
|
if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) {
|
|
|
|
|
GELOGE(FAILED, "In data anchor is null, node:%s", assign_node->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor();
|
|
|
|
|
const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor();
|
|
|
|
|
if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) {
|
|
|
|
|
GELOGE(FAILED, "Peer data anchor is null, node:%s", assign_node->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsCondMatch(assign_node, ref_peer_anchor, value_peer_anchor)) {
|
|
|
|
|
///
|
|
|
|
|
/// variable not-const not-const
|
|
|
|
|
/// \ / |
|
|
|
|
|
/// \ / |
|
|
|
|
|
/// Assign ----> variable
|
|
|
|
|
/// | |
|
|
|
|
|
/// | |
|
|
|
|
|
/// node node
|
|
|
|
|
///
|
|
|
|
|
GELOGD("Optimization for assign_node %s start", assign_node->GetName().c_str());
|
|
|
|
|
if (IsolateAndDeleteNode(assign_node, {kAssignRefInputIndex}) != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "Isolate and delete assign_node %s failed.", assign_node->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc();
|
|
|
|
|
const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc();
|
|
|
|
|
if ((ref_input == nullptr) || (value_input == nullptr)) {
|
|
|
|
|
GELOGE(FAILED, "value input is null");
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// variable has and only has one input
|
|
|
|
|
if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s",
|
|
|
|
|
value_input->GetName().c_str(), ref_input->GetName().c_str());
|
|
|
|
|
if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME,
|
|
|
|
|
ref_input->GetName())) {
|
|
|
|
|
GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed.");
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
auto value_node = value_peer_anchor->GetOwnerNode();
|
|
|
|
|
AddRePassNode(value_node);
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
/// @brief Transform assign_var_name attr
|
|
|
|
|
/// @param [in] node
|
|
|
|
|
/// @return Status
|
|
|
|
|
///
|
|
|
|
|
Status AssignPass::TransformAttr(NodePtr &node) {
|
|
|
|
|
GE_CHECK_NOTNULL(node->GetOpDesc());
|
|
|
|
|
for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDesc()) {
|
|
|
|
|
int32_t inplace_input_idx = -1;
|
|
|
|
|
std::string assign_var_name;
|
|
|
|
|
if (AttrUtils::GetInt(output_desc, INPLACE_SUPPORT_INPUT_INDEX, inplace_input_idx) &&
|
|
|
|
|
AttrUtils::GetStr(output_desc, ASSIGN_VAR_NAME, assign_var_name)) {
|
|
|
|
|
GELOGD("Transform attr ASSIGN_VAR_NAME on node %s, assign_var_name=%s, inplace_input_idx=%d, ",
|
|
|
|
|
node->GetName().c_str(), assign_var_name.c_str(), inplace_input_idx);
|
|
|
|
|
const auto &in_data_anchor = node->GetInDataAnchor(inplace_input_idx);
|
|
|
|
|
GE_CHECK_NOTNULL(in_data_anchor);
|
|
|
|
|
const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
|
|
|
|
|
GE_CHECK_NOTNULL(peer_data_anchor);
|
|
|
|
|
auto in_node = peer_data_anchor->GetOwnerNode();
|
|
|
|
|
GE_CHECK_NOTNULL(in_node->GetOpDesc());
|
|
|
|
|
GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s", in_node->GetName().c_str(), assign_var_name.c_str());
|
|
|
|
|
if (!AttrUtils::SetStr(in_node->GetOpDesc()->MutableOutputDesc(peer_data_anchor->GetIdx()),
|
|
|
|
|
ASSIGN_VAR_NAME, assign_var_name)) {
|
|
|
|
|
GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed.");
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
AddRePassNode(in_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
Status AssignPass::Run(NodePtr &node) {
|
|
|
|
|
GELOGD("AssignPass running");
|
|
|
|
|
if (node->GetType() != ASSIGN) {
|
|
|
|
@ -209,7 +91,7 @@ Status AssignPass::Run(NodePtr &node) {
|
|
|
|
|
GELOGD("AssignPass success");
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
/// @brief Check if need optimize for assign_node
|
|
|
|
|
/// @param [in] assign_node
|
|
|
|
@ -223,8 +105,9 @@ bool AssignPass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_pe
|
|
|
|
|
node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(),
|
|
|
|
|
value_peer_anchor->GetOwnerNode()->GetName().c_str());
|
|
|
|
|
|
|
|
|
|
if (kNoTaskNodeTypes.count(value_peer_anchor->GetOwnerNode()->GetType()) > 0) {
|
|
|
|
|
GELOGD("value input is not calculate node");
|
|
|
|
|
const std::string &value_type = value_peer_anchor->GetOwnerNode()->GetType();
|
|
|
|
|
if ((value_type == CONSTANTOP) || (value_type == CONSTANT)) {
|
|
|
|
|
GELOGD("value input is const");
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|