|
|
|
@ -253,11 +253,11 @@ class TestSparseAdamOp(unittest.TestCase):
|
|
|
|
|
row_numel = 12
|
|
|
|
|
self.row_numel = row_numel
|
|
|
|
|
self.dense_inputs = {
|
|
|
|
|
"Param": np.full((height, row_numel), 1.0).astype("float32"),
|
|
|
|
|
"Moment1": np.full((height, row_numel), 1.0).astype("float32"),
|
|
|
|
|
"Moment2": np.full((height, row_numel), 1.0).astype("float32"),
|
|
|
|
|
'Beta1Pow': np.array([beta1**3]).astype("float32"),
|
|
|
|
|
'Beta2Pow': np.array([beta2**3]).astype("float32"),
|
|
|
|
|
"Param": np.full((height, row_numel), 5.0).astype("float32"),
|
|
|
|
|
"Moment1": np.full((height, row_numel), 5.0).astype("float32"),
|
|
|
|
|
"Moment2": np.full((height, row_numel), 5.0).astype("float32"),
|
|
|
|
|
'Beta1Pow': np.array([beta1**10]).astype("float32"),
|
|
|
|
|
'Beta2Pow': np.array([beta2**10]).astype("float32"),
|
|
|
|
|
"LearningRate": np.full((1), 2.0).astype("float32")
|
|
|
|
|
}
|
|
|
|
|
self.init_output = np.full((height, row_numel), 0.0).astype("float32")
|
|
|
|
|