pull/915/head
chang zherui 5 years ago
commit 3c1785a121

.gitmodules

@@ -12,4 +12,4 @@
 	url = https://github.com/protocolbuffers/protobuf.git
 [submodule "graphengine"]
 	path = graphengine
-	url = https://gitee.com/mindspore/graphengine.git
+	url = https://gitee.com/ms-incubator/graphengine.git

@@ -1 +1 @@
-Subproject commit 43f5d24337bf785251eefae2d810c7d5684194d6
+Subproject commit cfc99f95f722918025b0eaeb93440d92265f09fe

@@ -230,7 +230,6 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
-const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
 // Debug ops
 const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");

@@ -234,7 +234,6 @@ extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
 // Comm ops
-extern const PrimitivePtr kPrimAllReduce;
 extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;

@@ -48,7 +48,7 @@ namespace irpass {
 OptimizeIRPassLib::OptimizeIRPassLib() {
   arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                           {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
-                                           prim::kPrimAddN, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
+                                           prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
   special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
                                            {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
                                             prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});

@@ -228,86 +228,6 @@ class ConstantDuplicateMul : public AnfVisitor {
   CNodePtr cnode_;
 };
-
-// grad = AllReduce(grad) / worker_number
-// grad = grad + weight * decy
-// ->
-// grad = grad + weight * decy
-// grad = AllReduce(grad) / worker_number
-//
-// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
-// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
-class AdjustAllReduceMulAdd : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    Reset();
-    // {prim::kPrimAddN, Zs}
-    if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
-      return nullptr;
-    }
-
-    auto addn = node->cast<CNodePtr>();
-    if (addn->size() != 2) {
-      return nullptr;
-    }
-
-    AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
-    if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
-      return nullptr;
-    }
-
-    auto addn_op_node = addn->input(0);
-    auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
-    auto fg = node->func_graph();
-    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
-    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
-    AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
-    return NewCNode({mul_, all_reduce, y_}, fg);
-  }
-
-  void Visit(const AnfNodePtr &node) override {
-    if (level_ == 0) {
-      level_ = 1;
-      is_reduce_match_ = false;
-      // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
-      AnfVisitor::Match(prim::kPrimMul)(node);
-      level_ = 0;
-      if (is_reduce_match_) {
-        mul_ = node->cast<CNodePtr>()->input(0);
-        y_ = tmp_;
-      } else {
-        z_ = node;
-      }
-    }
-
-    if (level_ == 1) {
-      // {prim::kPrimAllReduce, X}
-      if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
-        auto cnode = node->cast<CNodePtr>();
-        if (cnode->size() > 1) {
-          all_reduce_ = cnode->input(0);
-          x_ = cnode->input(1);
-          is_reduce_match_ = true;
-        }
-      } else {
-        tmp_ = node;
-      }
-    }
-  }
-
-  void Reset() {
-    level_ = 0;
-    is_reduce_match_ = false;
-    x_ = nullptr;
-    y_ = nullptr;
-    z_ = nullptr;
-    tmp_ = nullptr;
-  }
-
- private:
-  int level_{0};
-  bool is_reduce_match_{false};
-  AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
-  AnfNodePtr all_reduce_{nullptr}, mul_{nullptr};
-};
-
 class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
@@ -323,7 +243,6 @@ class ArithmeticSimplify {
     eliminaters_.emplace_back(identity_);
     eliminaters_.emplace_back(opt_update_zero_tensor_);
     eliminaters_.emplace_back(constant_duplicate_mul_);
-    eliminaters_.emplace_back(adjust_allreduce_mul_add_);
   }
   ~ArithmeticSimplify() = default;
@@ -345,7 +264,6 @@ class ArithmeticSimplify {
   PrimEliminater identity_;
   OptUpdateZeroTensor opt_update_zero_tensor_;
   ConstantDuplicateMul constant_duplicate_mul_;
-  AdjustAllReduceMulAdd adjust_allreduce_mul_add_;
   std::vector<TransformFuncType> eliminaters_{};
 };
 }  // namespace irpass
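
The rewrite removed above is pure algebra: in the data-parallel pattern, z (the weight * decay term) and y (the 1 / worker_number factor) are identical on every worker, so AllReduce(z) * y == z, and the decay addition can be hoisted beneath the AllReduce. A minimal NumPy sketch of that identity (illustrative only, not MindSpore code; all_reduce_sum is a hypothetical stand-in for the collective, and n = 4 is arbitrary):

    import numpy as np

    def all_reduce_sum(shards):
        # Stand-in for AllReduce(sum): every worker gets the sum of all shards.
        return sum(shards)

    n = 4
    rng = np.random.default_rng(0)
    x = [rng.normal(size=3) for _ in range(n)]  # per-worker local gradients
    z = rng.normal(size=3)                      # weight * decay, replicated on all workers
    y = 1.0 / n                                 # averaging factor, replicated on all workers

    # Before: AddN((z, Mul(AllReduce(x), y))) -- average the gradient, then add decay.
    before = z + all_reduce_sum(x) * y

    # After: Mul(AllReduce(AddN((z, x))), y) -- fold decay in first; AllReduce(z) * y == z.
    after = all_reduce_sum([z + xi for xi in x]) * y

    np.testing.assert_allclose(before, after)

The reordering is exact in real arithmetic; in floating point it only changes the rounding order.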

@@ -96,7 +96,6 @@ const char kNameConfusionMatrix[] = "ConfusionMatrix";
 const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor";
 const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad";
 const char kNameApplyAdam[] = "Adam";
-const char kNameExtractImagePatches[] = "ExtractImagePatches";
 const char kNameReLU6[] = "ReLU6";
 const char kNameReLU6Grad[] = "ReLU6Grad";
 const char kNameElu[] = "Elu";
@@ -111,8 +110,6 @@ const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits
 const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad";
 const char kNameScatterNdD[] = "ScatterNd";
 const char kNamePadD[] = "Pad";
-const char kNameMirrorPad[] = "MirrorPad";
-const char kNameMirrorPadGrad[] = "MirrorPadGrad";
 const char kNameGatherNd[] = "GatherNd";
 const char kNameArgmax[] = "Argmax";
 const char kNameArgmin[] = "Argmin";
@@ -176,6 +173,7 @@ const char kNameAbsGrad[] = "AbsGrad";
 const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy";
 const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad";
 const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad";
+const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD";
 const char kNameAcosh[] = "Acosh";
 const char kNameFloorMod[] = "FloorMod";
 const char kNameSpaceToDepth[] = "SpaceToDepth";
@@ -204,7 +202,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameMaxPool), ADPT_DESC(MaxPool)},
     {string(kNameAvgPool), ADPT_DESC(AvgPool)},
     {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
-    {string(kNameTopK), ADPT_DESC(TopKV2)},
+    {string(kNameTopK), ADPT_DESC(TopK)},
     {string(kNamePack), ADPT_DESC(Pack)},
     {string(kNameUnpack), ADPT_DESC(Unpack)},
     {string(kNameSplitD), ADPT_DESC(SplitD)},
@@ -215,7 +213,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
     {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
     {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
-    {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},
     {prim::kPrimAssign->name(), ADPT_DESC(Assign)},
     {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)},
     {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)},
@@ -240,15 +237,15 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameSquare), ADPT_DESC(Square)},
     {prim::kPrimTanh->name(), ADPT_DESC(Tanh)},
     {prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)},
-    {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborD)},
-    {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborGrad)},
+    {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)},
+    {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)},
     {string(kNameApplyAdam), ADPT_DESC(ApplyAdam)},
     {string(kNameReLU6), ADPT_DESC(Relu6)},
     {string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)},
     {string(kNameElu), ADPT_DESC(Elu)},
     {string(kNameEluGrad), ADPT_DESC(EluGrad)},
-    {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearGrad)},
-    {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearD)},
+    {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearV2Grad)},
+    {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
     {string(kNameZerosLike), ADPT_DESC(ZerosLike)},
     {string(kNameOnesLike), ADPT_DESC(OnesLike)},
     {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
@@ -260,8 +257,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)},
     {string(kNameScatterNdD), ADPT_DESC(ScatterNdD)},
     {string(kNamePadD), ADPT_DESC(PadD)},
-    {string(kNameMirrorPad), ADPT_DESC(MirrorPad)},
-    {string(kNameMirrorPadGrad), ADPT_DESC(MirrorPadGrad)},
     {string(kNameGatherNd), ADPT_DESC(GatherNd)},
     {string(kNameArgmax), ADPT_DESC(ArgMaxD)},
     {string(kNameArgmin), ADPT_DESC(ArgMinD)},
@@ -329,7 +324,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {prim::kPrimMinimum->name(), ADPT_DESC(Minimum)},
     {prim::kPrimSelect->name(), ADPT_DESC(Select)},
     {string(kNameLessEqual), ADPT_DESC(LessEqual)},
-    {prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmax)},
+    {prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmaxV2)},
     {string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)},
     {string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)},
     {prim::kPrimGelu->name(), ADPT_DESC(Gelu)},
@@ -363,7 +358,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {prim::kPrimMatMul->name(), ADPT_DESC(MatMul)},
     {string(kNameConst), ADPT_DESC(Constant, Const)},
-    {string(kNameSoftmax), ADPT_DESC(Softmax)},
+    {string(kNameSoftmax), ADPT_DESC(SoftmaxV2)},
     {string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)},
     {string(kNameParam), ADPT_DESC(Data)},
     {string(kNameROIAlign), ADPT_DESC(ROIAlign)},
@@ -373,6 +368,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)},
     {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
     {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
+    {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)},
     {string(kNameAcosh), ADPT_DESC(Acosh)},
     {string(kNameFloorMod), ADPT_DESC(FloorMod)},
     {string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)},
@@ -1126,8 +1122,8 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
   if (desc == nullptr) {
     MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
   } else {
-    (void)std::static_pointer_cast<Data>(op)->update_input_desc_data(*desc);
-    (void)std::static_pointer_cast<Data>(op)->update_output_desc_out(*desc);
+    (void)std::static_pointer_cast<Data>(op)->update_input_desc_x(*desc);
+    (void)std::static_pointer_cast<Data>(op)->update_output_desc_y(*desc);
   }
 }

(File diff suppressed because it is too large.)

@@ -95,8 +95,6 @@ DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax)
 DECLARE_OP_ADAPTER(Conv2D)
 DECLARE_OP_USE_ENUM(Conv2D)
 DECLARE_OP_USE_OUTPUT(Conv2D)
-DECLARE_OP_ADAPTER(ExtractImagePatches)
-DECLARE_OP_USE_OUTPUT(ExtractImagePatches)
 DECLARE_OP_ADAPTER(Conv2DBackpropInputD)
 DECLARE_OP_USE_ENUM(Conv2DBackpropInputD)
 DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD)
@@ -118,20 +116,20 @@ DECLARE_OP_ADAPTER(Reshape)
 DECLARE_OP_USE_OUTPUT(Reshape)
 DECLARE_OP_ADAPTER(Iou)
 DECLARE_OP_USE_OUTPUT(Iou)
-DECLARE_OP_ADAPTER(ResizeNearestNeighborD)
-DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborD)
-DECLARE_OP_ADAPTER(ResizeNearestNeighborGrad)
-DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborGrad)
+DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)
+DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D)
+DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad)
+DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
 DECLARE_OP_ADAPTER(ApplyAdam)
 DECLARE_OP_USE_OUTPUT(ApplyAdam)
 DECLARE_OP_ADAPTER(Relu6)
 DECLARE_OP_USE_OUTPUT(Relu6)
 DECLARE_OP_ADAPTER(Relu6Grad)
 DECLARE_OP_USE_OUTPUT(Relu6Grad)
-DECLARE_OP_ADAPTER(ResizeBilinearD)
-DECLARE_OP_USE_OUTPUT(ResizeBilinearD)
-DECLARE_OP_ADAPTER(ResizeBilinearGrad)
-DECLARE_OP_USE_OUTPUT(ResizeBilinearGrad)
+DECLARE_OP_ADAPTER(ResizeBilinearV2D)
+DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D)
+DECLARE_OP_ADAPTER(ResizeBilinearV2Grad)
+DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad)
 DECLARE_OP_ADAPTER(ZerosLike)
 DECLARE_OP_USE_OUTPUT(ZerosLike)
 DECLARE_OP_ADAPTER(OnesLike)
@@ -157,10 +155,6 @@ DECLARE_OP_USE_INPUT_ATTR(ScatterNdD)
 DECLARE_OP_USE_OUTPUT(ScatterNdD)
 DECLARE_OP_ADAPTER(PadD)
 DECLARE_OP_USE_OUTPUT(PadD)
-DECLARE_OP_ADAPTER(MirrorPad)
-DECLARE_OP_USE_OUTPUT(MirrorPad)
-DECLARE_OP_ADAPTER(MirrorPadGrad)
-DECLARE_OP_USE_OUTPUT(MirrorPadGrad)
 DECLARE_OP_ADAPTER(BoundingBoxEncode)
 DECLARE_OP_USE_OUTPUT(BoundingBoxEncode)
 DECLARE_OP_ADAPTER(BoundingBoxDecode)
@@ -217,8 +211,8 @@ DECLARE_OP_USE_OUTPUT(Merge)
 DECLARE_OP_ADAPTER(Switch)
 DECLARE_OP_USE_OUTPUT(Switch)
-DECLARE_OP_ADAPTER(TopKV2)
-DECLARE_OP_USE_OUTPUT(TopKV2)
+DECLARE_OP_ADAPTER(TopK)
+DECLARE_OP_USE_OUTPUT(TopK)
 DECLARE_OP_ADAPTER(RealDiv)
 DECLARE_OP_USE_OUTPUT(RealDiv)
@@ -268,8 +262,8 @@ DECLARE_OP_ADAPTER(Select)
 DECLARE_OP_USE_OUTPUT(Select)
 DECLARE_OP_ADAPTER(LessEqual)
 DECLARE_OP_USE_OUTPUT(LessEqual)
-DECLARE_OP_ADAPTER(LogSoftmax)
-DECLARE_OP_USE_OUTPUT(LogSoftmax)
+DECLARE_OP_ADAPTER(LogSoftmaxV2)
+DECLARE_OP_USE_OUTPUT(LogSoftmaxV2)
 DECLARE_OP_ADAPTER(TruncatedNormal)
 DECLARE_OP_USE_OUTPUT(TruncatedNormal)
 DECLARE_OP_ADAPTER(StridedSliceGrad)
@@ -402,8 +396,8 @@ DECLARE_OP_ADAPTER(Sigmoid)
 DECLARE_OP_USE_OUTPUT(Sigmoid)
 DECLARE_OP_ADAPTER(SigmoidGrad)
 DECLARE_OP_USE_OUTPUT(SigmoidGrad)
-DECLARE_OP_ADAPTER(Softmax)
-DECLARE_OP_USE_OUTPUT(Softmax)
+DECLARE_OP_ADAPTER(SoftmaxV2)
+DECLARE_OP_USE_OUTPUT(SoftmaxV2)
 DECLARE_OP_ADAPTER(SoftmaxGrad)
 DECLARE_OP_USE_OUTPUT(SoftmaxGrad)
 DECLARE_OP_ADAPTER(Greater)
@@ -446,6 +440,8 @@ DECLARE_OP_ADAPTER(Round)
 DECLARE_OP_USE_OUTPUT(Round)
 DECLARE_OP_ADAPTER(ApplyFtrl)
 DECLARE_OP_USE_OUTPUT(ApplyFtrl)
+DECLARE_OP_ADAPTER(SparseApplyFtrlD)
+DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
 DECLARE_OP_ADAPTER(Diag)
 DECLARE_OP_USE_OUTPUT(Diag)
 DECLARE_OP_ADAPTER(DiagPart)

@@ -1235,8 +1235,8 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.

     Examples:
-        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
-        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
+        >>> input_x = [1, 2, 3, 4]
+        >>> segment_ids = [0, 0, 1, 2]
         >>> num_segments = 4
         >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
         [3, 3, 4, 0]
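
The expected output [3, 3, 4, 0] checks out by hand: elements 1 and 2 carry segment id 0 (sum 3), element 3 maps to segment 1, element 4 to segment 2, and segment 3 receives nothing. A plain-NumPy model of the semantics (an illustrative sketch, not the MindSpore kernel):

    import numpy as np

    def unsorted_segment_sum(x, segment_ids, num_segments):
        # Scatter-add each element of x into the output slot named by its segment id.
        out = np.zeros(num_segments, dtype=x.dtype)
        np.add.at(out, segment_ids, x)
        return out

    x = np.array([1, 2, 3, 4], dtype=np.float32)
    ids = np.array([0, 0, 1, 2])
    print(unsorted_segment_sum(x, ids, 4))  # [3. 3. 4. 0.]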

@@ -1630,7 +1630,7 @@ class LayerNorm(Primitive):
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.

     .. math::
-        y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
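
Read literally, the formula normalizes each sample to zero mean and unit variance, then applies the affine gamma/beta transform. A short NumPy transcription of the docstring formula (a sketch; the normalization axis and the epsilon default here are assumptions, not taken from this diff):

    import numpy as np

    def layer_norm(x, gamma, beta, eps=1e-7):
        # y = (x - mean) / sqrt(variance + eps) * gamma + beta, per sample.
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * gamma + beta

    x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
    print(layer_norm(x, gamma=np.ones(3), beta=np.zeros(3)))
    # Approximately [[-1.2247, 0.0, 1.2247]]: zero mean, unit variance.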

@@ -556,24 +556,5 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
   ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
   ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
 }
-
-TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
-  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
-  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
-  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
-  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
-  FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
-  FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
-  FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
-  FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
-
-  auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_});
-  ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
-  ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
-  ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
-}
 }  // namespace opt
 }  // namespace mindspore

@@ -908,8 +908,8 @@ def test_print_tuple_wrapper(tag):
 def test_constant_duplicate_mul(tag):
     fns = FnDict()
-    Mul = Primitive('Mul');
-    Sqrt = Primitive('Sqrt');
+    Mul = Primitive('Mul')
+    Sqrt = Primitive('Sqrt')
     x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
@@ -936,44 +936,3 @@ def test_constant_duplicate_mul(tag):
         return Mul(Sqrt(x), Mul(tensor1, tensor2))

     return fns[tag]
-
-
-def test_adjust_allreduce_mul_add(tag):
-    fns = FnDict()
-    Mul = Primitive('Mul')
-    AddN = Primitive('AddN')
-    AllReduce = Primitive('AllReduce')
-
-    @fns
-    def beforell(x, y, z):
-        return AddN((z, Mul(y, AllReduce(x))))
-
-    @fns
-    def beforelr(x, y, z):
-        return AddN((z, Mul(AllReduce(x), y)))
-
-    @fns
-    def beforerl(x, y, z):
-        return AddN((Mul(y, AllReduce(x)), z))
-
-    @fns
-    def beforerr(x, y, z):
-        return AddN((Mul(AllReduce(x), y), z))
-
-    @fns
-    def after1(x, y, z):
-        return Mul(AllReduce(AddN((z, x))), y)
-
-    @fns
-    def before2r(x, y, z):
-        return AddN((Mul(AllReduce(x), y), Mul(z, z)))
-
-    @fns
-    def before2l(x, y, z):
-        return AddN((Mul(z, z), Mul(AllReduce(x), y)))
-
-    @fns
-    def after2(x, y, z):
-        return Mul(AllReduce(AddN((Mul(z, z), x))), y)
-
-    return fns[tag]
