|
|
|
@ -120,8 +120,8 @@ class AddByZero : public AnfVisitor {
|
|
|
|
|
AnfNodePtr x_{nullptr};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimTensorAdd, {PrimZerosLikeTensor, Y}, X},
|
|
|
|
|
// {prim::kPrimTensorAdd, X, {PrimZerosLikeTensor, Y}}
|
|
|
|
|
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
|
|
|
|
|
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
|
|
|
|
|
class TensorAddByZero : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
@ -135,7 +135,7 @@ class TensorAddByZero : public AnfVisitor {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override {
|
|
|
|
|
if (IsPrimitive(node, prim::kPrimZerosLikeTensor)) {
|
|
|
|
|
if (IsPrimitive(node, prim::kPrimZerosLike)) {
|
|
|
|
|
is_zero_ = true;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -153,7 +153,7 @@ class TensorAddByZero : public AnfVisitor {
|
|
|
|
|
AnfNodePtr x_{nullptr};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// {PrimMomentum, {PrimZerosLikeTensor, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
|
|
|
|
|
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
|
|
|
|
|
class OptUpdateZeroTensor : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
@ -163,13 +163,13 @@ class OptUpdateZeroTensor : public AnfVisitor {
|
|
|
|
|
|
|
|
|
|
// {PrimMomentum, {...}, Y, Z, Xs}
|
|
|
|
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
|
|
|
|
if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLikeTensor)) {
|
|
|
|
|
if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto y = inputs[2];
|
|
|
|
|
auto z = inputs[3];
|
|
|
|
|
|
|
|
|
|
// {PrimZerosLikeTensor, X}
|
|
|
|
|
// {kPrimZerosLike, X}
|
|
|
|
|
if (inputs[1]->cast<CNodePtr>()->size() != 2) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|