From 793737ab62b5db47e9b35f4eecec0a6552be4f7c Mon Sep 17 00:00:00 2001 From: duxiutao Date: Thu, 2 Jul 2020 16:03:01 +0800 Subject: [PATCH] add primitive operator to test_lamb --- tests/st/ops/graph_kernel/test_lamb.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/st/ops/graph_kernel/test_lamb.py b/tests/st/ops/graph_kernel/test_lamb.py index d34c0eea57..dfe975c91a 100644 --- a/tests/st/ops/graph_kernel/test_lamb.py +++ b/tests/st/ops/graph_kernel/test_lamb.py @@ -33,7 +33,8 @@ class LambNet(Cell): def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3, x1, x2, x3, x4, x5, gy, se, my): - return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0, + i1_ = i1 + i3 + return self.lamb_next(i1_, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0, ix1, ix2, ix3), \ self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my) @@ -113,7 +114,8 @@ def test_graph_kernel_lamb(): context.set_context(enable_graph_kernel=False) - a3, a0, a1, up = LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, + i1_ = i1 + i3 + a3, a0, a1, up = LambNextMVNumpy(i1_, i2, i3, i4, i5, i6, i7, i8, i9, ix0, ix1, ix2, ix3) np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my)