|
|
|
@ -17,7 +17,7 @@
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
from mindspore import context, Tensor, Parameter, ParameterTuple
|
|
|
|
|
from mindspore import context, Tensor, Parameter, ParameterTuple, nn
|
|
|
|
|
from mindspore._checkparam import _check_str_by_regular
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from mindspore.common.initializer import initializer
|
|
|
|
@ -229,3 +229,25 @@ def test_parameter_lazy_init():
|
|
|
|
|
para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True)
|
|
|
|
|
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2)))
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_parameter_as_output():
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
|
|
|
|
initial_input = initializer('One', shape=(2,), dtype=mstype.int32)
|
|
|
|
|
updated_input = Tensor([2, 2], mstype.int32)
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
def __init__(self, initial, updated):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.initial = initial
|
|
|
|
|
self.updated = updated
|
|
|
|
|
self.p = Parameter(self.initial, name="weight")
|
|
|
|
|
self.new_p = self.p.init_data()
|
|
|
|
|
self.new_p.set_parameter_data(self.updated)
|
|
|
|
|
def construct(self):
|
|
|
|
|
return self.new_p
|
|
|
|
|
|
|
|
|
|
net = Net(initial_input, updated_input)
|
|
|
|
|
output = net()
|
|
|
|
|
assert np.array_equal(output.asnumpy(), np.array([2, 2], np.int32))
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|