|
|
|
@ -77,15 +77,15 @@ class GNNFeatureTransform(nn.Cell):
|
|
|
|
|
self.has_bias = check_bool(has_bias)
|
|
|
|
|
|
|
|
|
|
if isinstance(weight_init, Tensor):
|
|
|
|
|
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
|
|
|
|
weight_init.shape()[1] != in_channels:
|
|
|
|
|
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
|
|
|
|
weight_init.shape[1] != in_channels:
|
|
|
|
|
raise ValueError("weight_init shape error")
|
|
|
|
|
|
|
|
|
|
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
|
|
|
|
|
|
|
|
|
if self.has_bias:
|
|
|
|
|
if isinstance(bias_init, Tensor):
|
|
|
|
|
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
|
|
|
|
|
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
|
|
|
|
|
raise ValueError("bias_init shape error")
|
|
|
|
|
|
|
|
|
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
|
|
|
|