|
|
|
@ -35,16 +35,18 @@ log.setLevel(level=logging.ERROR)
|
|
|
|
|
relu_test = Primitive('relu_test')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_ops_f1(x, y):
|
|
|
|
|
foo = relu_test(x)
|
|
|
|
|
return foo
|
|
|
|
|
def test_ops_f1(x):
|
|
|
|
|
test = relu_test(x)
|
|
|
|
|
return test
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# use method2: create instance outside function use an operator with parameters
|
|
|
|
|
class Conv_test(Primitive):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, stride=0, pad=1):
|
|
|
|
|
print('in conv_test init', self.stride)
|
|
|
|
|
self.stride = stride
|
|
|
|
|
self.pad = pad
|
|
|
|
|
print('in conv_test init', self.stride, self.pad)
|
|
|
|
|
|
|
|
|
|
def __call__(self, x=0, y=1, z=2):
|
|
|
|
|
pass
|
|
|
|
@ -65,7 +67,7 @@ class ResNet(nn.Cell):
|
|
|
|
|
self.weight = Parameter(tensor, name="weight")
|
|
|
|
|
self.conv = Conv_test(3, 5)
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y, train="train"):
|
|
|
|
|
def construct(self, x, y):
|
|
|
|
|
return x + y * self.weight + self.conv(x)
|
|
|
|
|
|
|
|
|
|
def get_params(self):
|
|
|
|
@ -78,7 +80,7 @@ class SimpleNet(nn.Cell):
|
|
|
|
|
self.weight = Parameter(tensor, name="weight")
|
|
|
|
|
self.network = network
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y, train="train"):
|
|
|
|
|
def construct(self, x, y):
|
|
|
|
|
return self.network(x) + self.weight * y
|
|
|
|
|
|
|
|
|
|
def get_params(self):
|
|
|
|
@ -106,7 +108,7 @@ class SimpleNet_1(nn.Cell):
|
|
|
|
|
super(SimpleNet_1, self).__init__()
|
|
|
|
|
self.conv = Conv_test(2, 3)
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y, train="train"):
|
|
|
|
|
def construct(self, x, y):
|
|
|
|
|
return self.conv(x, y)
|
|
|
|
|
|
|
|
|
|
def get_params(self):
|
|
|
|
|