@@ -64,7 +64,7 @@ class GNNFeatureTransform(nn.Cell):
         [[ 2.5246444   2.2738023   0.5711005  -3.9399147 ]
          [ 1.0739875   4.0155234   0.94188046 -5.459526  ]]
     """
-    @cell_attr_register(attrs=['has_bias', 'activation'])
+    @cell_attr_register
     def __init__(self,
                  in_channels,
                  out_channels,
@@ -125,7 +125,7 @@ class _BaseAggregator(nn.Cell):
             same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
         dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None.
-        activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
+        activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.

     Examples:
         >>> class MyAggregator(_BaseAggregator):
@@ -203,12 +203,12 @@ class MeanAggregator(_BaseAggregator):
         super(MeanAggregator, self).__init__(
             feature_in_dim,
             feature_out_dim,
-            use_fc=True,
-            weight_init="normal",
-            bias_init="zeros",
-            has_bias=True,
-            dropout_ratio=None,
-            activation=None)
+            use_fc,
+            weight_init,
+            bias_init,
+            has_bias,
+            dropout_ratio,
+            activation)
         self.reduce_mean = P.ReduceMean(keep_dims=False)

     def construct(self, input_feature):
@@ -220,3 +220,157 @@ class MeanAggregator(_BaseAggregator):
             input_feature = self.activation(input_feature)
         output_feature = self.reduce_mean(input_feature, 1)
         return output_feature
+
+
+class AttentionHead(nn.Cell):
+    """
+    Attention Head for Graph Attention Networks.
+
+    Args:
+        in_channel (int): The number of input channels, i.e. the input feature dim.
+        out_channel (int): The number of output channels, i.e. the output feature dim.
+        in_drop_ratio (float): Input feature dropout ratio, default 0.0.
+        coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
+        residual (bool): Whether to use residual connection, default False.
+        coef_activation (Cell): The attention coefficient activation function,
+            default nn.LeakyReLU().
+        activation (Cell): The output activation function, default nn.ELU().
+
+    Inputs:
+        - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
+        - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes).
+
+    Examples:
+        >>> head = AttentionHead(1433,
+                                 8,
+                                 in_drop_ratio=0.6,
+                                 coef_drop_ratio=0.6,
+                                 residual=False)
+        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
+        >>> bias_mat = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
+        >>> output = head(input_data, bias_mat)
+    """
+    def __init__(self,
+                 in_channel,
+                 out_channel,
+                 in_drop_ratio=0.0,
+                 coef_drop_ratio=0.0,
+                 residual=False,
+                 coef_activation=nn.LeakyReLU(),
+                 activation=nn.ELU()):
+        super(AttentionHead, self).__init__()
+        self.in_channel = check_int_positive(in_channel)
+        self.out_channel = check_int_positive(out_channel)
+        self.in_drop_ratio = in_drop_ratio
+        self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
+        self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
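+        # two independent dropout cells: in_drop is applied to the raw input
+        # features, in_drop_2 to the transformed features inside construct()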
+        self.feature_transform = GNNFeatureTransform(
+            in_channels=self.in_channel,
+            out_channels=self.out_channel,
+            has_bias=False)
+
+        self.f_1_transform = GNNFeatureTransform(
+            in_channels=self.out_channel,
+            out_channels=1)
+        self.f_2_transform = GNNFeatureTransform(
+            in_channels=self.out_channel,
+            out_channels=1)
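+        # f_1 and f_2 each map a node's transformed feature to a scalar score;
+        # construct() broadcasts their sum into the pairwise attention logits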
+        self.softmax = nn.Softmax()
+
+        self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
+        self.batch_matmul = P.BatchMatMul()
+        self.bias_add = P.BiasAdd()
+        self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
+        self.residual = check_bool(residual)
+        if self.residual:
+            if in_channel != out_channel:
+                self.residual_transform_flag = True
+                self.residual_transform = GNNFeatureTransform(
+                    in_channels=self.in_channel,
+                    out_channels=self.out_channel)
+            else:
+                self.residual_transform = None
+        self.coef_activation = coef_activation
+        self.activation = activation
+
+    def construct(self, input_feature, bias_mat):
+        input_feature = self.in_drop(input_feature)
+
+        feature = self.feature_transform(input_feature)
+        # self attention following the author
+        f_1 = self.f_1_transform(feature)
+        f_2 = self.f_2_transform(feature)
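+        # f_1: (batch, num_nodes, 1); transposing f_2 gives (batch, 1, num_nodes),
+        # so the broadcast sum is a (batch, num_nodes, num_nodes) logit matrix.
+        # bias_mat is expected to carry large negative values for non-adjacent
+        # pairs, restricting the softmax to each node's neighbourhood.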
+        logits = f_1 + P.Transpose()(f_2, (0, 2, 1))
+        logits = self.coef_activation(logits) + bias_mat
+        coefs = self.softmax(logits)
+
+        coefs = self.coef_drop(coefs)
+        feature = self.in_drop_2(feature)
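+
+        # aggregate: weight the (dropout-regularised) node features by the
+        # attention coefficients of each node's neighbourhood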
+        ret = self.batch_matmul(coefs, feature)
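+        # BiasAdd is applied on the (num_nodes, out_channel) view, so the batch
+        # dimension is squeezed away and restored afterwards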
+        ret = P.Squeeze(0)(ret)
+        ret = self.bias_add(ret, self.bias)
+        ret = P.ExpandDims()(ret, 0)
+        # residual connection
+        if self.residual:
+            if self.residual_transform_flag:
+                res = self.residual_transform(input_feature)
+                ret = ret + res
+            else:
+                ret = ret + input_feature
+        # activation
+        ret = self.activation(ret)
+        return ret
+
+
+class AttentionAggregator(nn.Cell):
+    """
+    Aggregator of attention heads for Graph Attention Networks; it can be
+    regarded as one GAT layer.
+
+    Args:
+        in_channels (int): Input channel.
+        out_channels (int): Output channel.
+        num_heads (int): Number of attention heads for this layer, default 1.
+        in_drop (float): Input feature dropout ratio, default 0.0.
+        coef_drop (float): Coefficient dropout ratio, default 0.0.
+        activation (Cell): The output activation function, default nn.ELU().
+        residual (bool): Whether to use residual connection, default False.
+
+    Inputs:
+        - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
+        - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes).
+
+    Examples:
+        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
+        >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
+        >>> net = AttentionAggregator(1433,
+                                      8,
+                                      8)
+        >>> net(input_data, biases)
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_heads=1,
+                 in_drop=0.0,
+                 coef_drop=0.0,
+                 activation=nn.ELU(),
+                 residual=False):
+        super(AttentionAggregator, self).__init__()
+        self.num_heads = num_heads
+        self.attns = []
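+        # build num_heads independent attention heads; construct() concatenates
+        # their outputs along the feature axis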
+        for _ in range(num_heads):
+            self.attns.append(AttentionHead(in_channels,
+                                            out_channels,
+                                            in_drop_ratio=in_drop,
+                                            coef_drop_ratio=coef_drop,
+                                            activation=activation,
+                                            residual=residual))
+        self.attns = nn.layer.CellList(self.attns)
+
+    def construct(self, input_data, bias_mat):
+        res = ()
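+        # run every head on the same input and bias matrix, then concatenate
+        # the per-head outputs along the last (feature) dimension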
+        for i in range(self.num_heads):
+            res += (self.attns[i](input_data, bias_mat),)
+        return P.Concat(-1)(res)
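
For reference, a minimal usage sketch of the AttentionAggregator added by this diff, mirroring its docstring example. It assumes a working MindSpore environment and that the patched file is importable as a module named `aggregator` (the module name is an assumption for illustration, not part of the diff):

    import numpy as np
    from mindspore import Tensor

    # assumed import path for the classes defined in the patched file
    from aggregator import AttentionAggregator

    # Cora-sized toy inputs, as in the docstring examples above
    input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
    # in real use bias_mat is derived from the adjacency matrix (large negative
    # entries for non-adjacent node pairs); random values only exercise the shapes
    biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))

    # 1433 input features, 8 output channels per head, 8 heads
    net = AttentionAggregator(1433, 8, 8)
    output = net(input_data, biases)  # concatenated over heads: (1, 2708, 64)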