|
|
|
@ -23,6 +23,10 @@ SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_graphdata_getfullneighbor():
|
|
|
|
|
"""
|
|
|
|
|
Test get all neighbors
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test get all neighbors.\n')
|
|
|
|
|
g = ds.GraphData(DATASET_FILE, 2)
|
|
|
|
|
nodes = g.get_all_nodes(1)
|
|
|
|
|
assert len(nodes) == 10
|
|
|
|
@ -33,6 +37,10 @@ def test_graphdata_getfullneighbor():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_graphdata_getnodefeature_input_check():
|
|
|
|
|
"""
|
|
|
|
|
Test get node feature input check
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test getnodefeature input check.\n')
|
|
|
|
|
g = ds.GraphData(DATASET_FILE)
|
|
|
|
|
with pytest.raises(TypeError):
|
|
|
|
|
input_list = [1, [1, 1]]
|
|
|
|
@ -80,6 +88,10 @@ def test_graphdata_getnodefeature_input_check():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_graphdata_getsampledneighbors():
|
|
|
|
|
"""
|
|
|
|
|
Test sampled neighbors
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test get sampled neighbors.\n')
|
|
|
|
|
g = ds.GraphData(DATASET_FILE, 1)
|
|
|
|
|
edges = g.get_all_edges(0)
|
|
|
|
|
nodes = g.get_nodes_from_edges(edges)
|
|
|
|
@ -90,6 +102,10 @@ def test_graphdata_getsampledneighbors():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_graphdata_getnegsampledneighbors():
|
|
|
|
|
"""
|
|
|
|
|
Test neg sampled neighbors
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test get negative sampled neighbors.\n')
|
|
|
|
|
g = ds.GraphData(DATASET_FILE, 2)
|
|
|
|
|
nodes = g.get_all_nodes(1)
|
|
|
|
|
assert len(nodes) == 10
|
|
|
|
@ -98,6 +114,10 @@ def test_graphdata_getnegsampledneighbors():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_graphdata_graphinfo():
|
|
|
|
|
"""
|
|
|
|
|
Test graph info
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test graph info.\n')
|
|
|
|
|
g = ds.GraphData(DATASET_FILE, 2)
|
|
|
|
|
graph_info = g.graph_info()
|
|
|
|
|
assert graph_info['node_type'] == [1, 2]
|
|
|
|
@ -155,6 +175,10 @@ class GNNGraphDataset():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_graphdata_generatordataset():
|
|
|
|
|
"""
|
|
|
|
|
Test generator dataset
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test generator dataset.\n')
|
|
|
|
|
g = ds.GraphData(DATASET_FILE)
|
|
|
|
|
batch_num = 2
|
|
|
|
|
edge_num = g.graph_info()['edge_num'][0]
|
|
|
|
@ -173,7 +197,11 @@ def test_graphdata_generatordataset():
|
|
|
|
|
assert i == 40
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_graphdata_randomwalk():
|
|
|
|
|
def test_graphdata_randomwalkdefault():
|
|
|
|
|
"""
|
|
|
|
|
Test random walk defaults
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test randomwalk with default parameters.\n')
|
|
|
|
|
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
|
|
|
|
|
nodes = g.get_all_nodes(1)
|
|
|
|
|
print(len(nodes))
|
|
|
|
@ -184,18 +212,27 @@ def test_graphdata_randomwalk():
|
|
|
|
|
assert walks.shape == (33, 40)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_graphdata_randomwalk():
|
|
|
|
|
"""
|
|
|
|
|
Test random walk
|
|
|
|
|
"""
|
|
|
|
|
logger.info('test random walk with given parameters.\n')
|
|
|
|
|
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
|
|
|
|
|
nodes = g.get_all_nodes(1)
|
|
|
|
|
print(len(nodes))
|
|
|
|
|
assert len(nodes) == 33
|
|
|
|
|
|
|
|
|
|
meta_path = [1 for _ in range(39)]
|
|
|
|
|
walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1)
|
|
|
|
|
assert walks.shape == (33, 40)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
test_graphdata_getfullneighbor()
|
|
|
|
|
logger.info('test_graphdata_getfullneighbor Ended.\n')
|
|
|
|
|
test_graphdata_getnodefeature_input_check()
|
|
|
|
|
logger.info('test_graphdata_getnodefeature_input_check Ended.\n')
|
|
|
|
|
test_graphdata_getsampledneighbors()
|
|
|
|
|
logger.info('test_graphdata_getsampledneighbors Ended.\n')
|
|
|
|
|
test_graphdata_getnegsampledneighbors()
|
|
|
|
|
logger.info('test_graphdata_getnegsampledneighbors Ended.\n')
|
|
|
|
|
test_graphdata_graphinfo()
|
|
|
|
|
logger.info('test_graphdata_graphinfo Ended.\n')
|
|
|
|
|
test_graphdata_generatordataset()
|
|
|
|
|
logger.info('test_graphdata_generatordataset Ended.\n')
|
|
|
|
|
test_graphdata_randomwalkdefault()
|
|
|
|
|
test_graphdata_randomwalk()
|
|
|
|
|
logger.info('test_graphdata_randomwalk Ended.\n')
|
|
|
|
|