|
|
|
@ -141,6 +141,17 @@ def test_init_abnormal():
|
|
|
|
|
with py.raises(TypeError):
|
|
|
|
|
init.initializer([''], [5, 4], ms.float32)
|
|
|
|
|
|
|
|
|
|
def test_initializer_reinit():
|
|
|
|
|
weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16)
|
|
|
|
|
assert weights.dtype == ms.float16
|
|
|
|
|
assert weights.shape == (10, 1, 10, 10)
|
|
|
|
|
weights = init.initializer(weights)
|
|
|
|
|
assert weights.dtype == ms.float16
|
|
|
|
|
assert weights.shape == (10, 1, 10, 10)
|
|
|
|
|
weights.shape = None
|
|
|
|
|
weights = init.initializer(weights, (10, 1))
|
|
|
|
|
assert weights.dtype == ms.float16
|
|
|
|
|
assert weights.shape == (10, 1)
|
|
|
|
|
|
|
|
|
|
def test_init_xavier_uniform():
|
|
|
|
|
""" test_init_xavier_uniform """
|
|
|
|
|