|
|
|
@ -107,6 +107,8 @@ AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNode
|
|
|
|
|
auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
|
|
|
|
|
auto fg = all_reduce_pat.GetFuncGraph();
|
|
|
|
|
auto z_ = z.GetNode(node);
|
|
|
|
|
auto x_ = x.GetNode(node);
|
|
|
|
|
|
|
|
|
|
// If the addn inputs cross the graph, make the inputs the same as the allreduce node.
|
|
|
|
|
if (z_->isa<CNode>() && fg != z_->func_graph()) {
|
|
|
|
|
auto cnode_z = z_->cast<CNodePtr>();
|
|
|
|
@ -121,7 +123,43 @@ AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNode
|
|
|
|
|
auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0);
|
|
|
|
|
auto addn_maketuple = admktup_pat.GetOriginalNode();
|
|
|
|
|
|
|
|
|
|
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg);
|
|
|
|
|
ShapeVector x_shape, z_shape;
|
|
|
|
|
if (!x_->isa<ValueNode>()) {
|
|
|
|
|
if ((x_->abstract() == nullptr) || !x_->abstract()->isa<abstract::AbstractTensor>()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto x_abstract = x_->abstract()->cast<abstract::AbstractTensorPtr>();
|
|
|
|
|
x_shape = x_abstract->shape()->shape();
|
|
|
|
|
} else {
|
|
|
|
|
ValuePtr x_value = x_->cast<ValueNodePtr>()->value();
|
|
|
|
|
if (!x_value->isa<tensor::Tensor>()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto x_tensor = GetValueNode<tensor::TensorPtr>(x_->cast<ValueNodePtr>());
|
|
|
|
|
x_shape = x_tensor->shape();
|
|
|
|
|
}
|
|
|
|
|
if (!z_->isa<ValueNode>()) {
|
|
|
|
|
if ((z_->abstract() == nullptr) || !z_->abstract()->isa<abstract::AbstractTensor>()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto z_abstract = z_->abstract()->cast<abstract::AbstractTensorPtr>();
|
|
|
|
|
z_shape = z_abstract->shape()->shape();
|
|
|
|
|
} else {
|
|
|
|
|
ValuePtr z_value = z_->cast<ValueNodePtr>()->value();
|
|
|
|
|
if (!z_value->isa<tensor::Tensor>()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto z_tensor = GetValueNode<tensor::TensorPtr>(z_->cast<ValueNodePtr>());
|
|
|
|
|
z_shape = z_tensor->shape();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (x_shape != z_shape) {
|
|
|
|
|
// AddN requires x_ and z_ to have the same shape.
|
|
|
|
|
// If TensorAdd with broadcasting is supported, the following could be used instead:
|
|
|
|
|
// AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimTensorAdd), z_, x_}, fg);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
|
|
|
|
|
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
|
|
|
|
|
AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
|
|
|
|
|
AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);
|
|
|
|
|