|
|
@ -13,7 +13,7 @@
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
#include "backend/optimizer/graph_kernel/split_assign.h"
|
|
|
|
#include "backend/optimizer/graph_kernel/split_umonad.h"
|
|
|
|
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
@ -35,31 +35,63 @@ const BaseRef SplitAssign::DefinePattern() const {
|
|
|
|
return VectorRef({v, Xs, Us, UMonad});
|
|
|
|
return VectorRef({v, Xs, Us, UMonad});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool CanSplit(const AnfNodePtr &node) {
|
|
|
|
bool CanSplit(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimAssign); }
|
|
|
|
return IsPrimitiveCNode(node, prim::kPrimAssignAdd) || IsPrimitiveCNode(node, prim::kPrimAssign) ||
|
|
|
|
|
|
|
|
IsPrimitiveCNode(node, prim::kPrimAssignSub);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
|
|
|
|
AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int input_idx) {
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
if (!CanSplit(node)) return node;
|
|
|
|
|
|
|
|
CNodePtr cnode = node->cast<CNodePtr>();
|
|
|
|
CNodePtr cnode = node->cast<CNodePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
CheckCNodeInputSize(cnode, kAssignInputTensorNum);
|
|
|
|
|
|
|
|
// Get original assign op's abstract and inputs
|
|
|
|
// Get original op's abstract and inputs
|
|
|
|
AbstractBasePtr original_abstract = cnode->abstract()->Clone();
|
|
|
|
AbstractBasePtr original_abstract = cnode->abstract()->Clone();
|
|
|
|
auto original_inputs = cnode->inputs();
|
|
|
|
auto original_inputs = cnode->inputs();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int input_node_size = cnode->size() - 1;
|
|
|
|
// Create depend node
|
|
|
|
// Create depend node
|
|
|
|
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]};
|
|
|
|
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx],
|
|
|
|
|
|
|
|
original_inputs[input_node_size]};
|
|
|
|
auto depend_cnode = func_graph->NewCNode(depend_inputs);
|
|
|
|
auto depend_cnode = func_graph->NewCNode(depend_inputs);
|
|
|
|
depend_cnode->set_abstract(original_inputs[1]->abstract());
|
|
|
|
depend_cnode->set_abstract(original_inputs[input_idx]->abstract());
|
|
|
|
depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
|
|
|
|
depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
|
|
|
|
// Create new assign node, delete U from inputs.
|
|
|
|
// Create new node, delete U from inputs.
|
|
|
|
AnfNodePtrList new_assign_inputs = {cnode->input(0), depend_cnode, original_inputs[2]};
|
|
|
|
AnfNodePtrList new_inputs = {cnode->input(0)};
|
|
|
|
auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs);
|
|
|
|
for (int i = 1; i < input_node_size; i++) {
|
|
|
|
new_assign_cnode->set_abstract(original_abstract);
|
|
|
|
if (i == input_idx) {
|
|
|
|
new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr());
|
|
|
|
new_inputs.push_back(depend_cnode);
|
|
|
|
return new_assign_cnode;
|
|
|
|
} else {
|
|
|
|
|
|
|
|
new_inputs.push_back(cnode->input(i));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto new_cnode = func_graph->NewCNode(new_inputs);
|
|
|
|
|
|
|
|
new_cnode->set_abstract(original_abstract);
|
|
|
|
|
|
|
|
new_cnode->set_kernel_info(cnode->kernel_info_ptr());
|
|
|
|
|
|
|
|
return new_cnode;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
|
|
|
if (!CanSplit(node)) return node;
|
|
|
|
|
|
|
|
return ProcessNode(node->func_graph(), node, 1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr OpUMonadExpander::Run(const AnfNodePtr &node) {
|
|
|
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool has_umonad = false;
|
|
|
|
|
|
|
|
for (unsigned int i = 1; i < cnode->size(); i++) {
|
|
|
|
|
|
|
|
if (HasAbstractUMonad(cnode->input(i))) {
|
|
|
|
|
|
|
|
has_umonad = true;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (has_umonad) {
|
|
|
|
|
|
|
|
auto new_node = ProcessNode(node->func_graph(), node, input_idx_);
|
|
|
|
|
|
|
|
return DefaultExpander::Run(new_node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return DefaultExpander::Run(node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace mindspore
|
|
|
|
} // namespace mindspore
|