Scores for subgraphs and non-subgraphs
jcrangel opened this issue · 0 comments
jcrangel commented
Hello, thanks for coding such a great project.
I'm trying to score whether a graph is a subgraph or not, using the code in alignment.py. I create a subgraph with `graph, neigh = utils.sample_neigh([target], 7)`
and score it with `score = model.predict(model(ttarget, tquery))`.
Also, for comparison, I create a non-subgraph using
Gno = nx.Graph()
Gno.add_edges_from([(43, 39), (43, 14),(43,60)]).
But I get larger values for the non-subgraph than for the subgraph:
Subgraph score 338.0877380371094
Non subgraph score 487.8809509277344
Am I computing the score correctly? Here is the complete code:
import sys, os
sys.path.insert(0, os.path.abspath(".."))
import argparse
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 16})
import random
import networkx as nx
from common import data
from common import models
from common import utils
from subgraph_matching.config import parse_encoder
import torch
def subgraph_score(emb_target, emb_query, order_model=None):
    """Score whether ``emb_query`` embeds as a subgraph of ``emb_target``.

    Both embeddings are NumPy arrays (as returned by
    ``model.emb_model(...).detach().cpu().numpy()``). They are converted to
    float32 tensors on ``utils.get_device()`` and scored with
    ``order_model.predict(order_model(target, query))``.

    Args:
        emb_target: NumPy embedding of the candidate supergraph.
        emb_query: NumPy embedding of the candidate subgraph.
        order_model: model used for scoring. Defaults to the script-level
            ``model`` so existing two-argument calls keep working.

    Returns:
        The prediction as a Python float via ``pred.item()``. ``pred`` must
        therefore contain exactly one element.

    NOTE(review): for the OrderEmbedder, ``predict`` looks like an
    order-violation energy rather than a probability. If so, LOWER means
    "more likely a subgraph" and a true subgraph ideally scores near 0.
    Confirm against ``common/models.py`` before interpreting the numbers.
    """
    if order_model is None:
        order_model = model  # module-level model loaded by this script
    device = utils.get_device()
    ttarget = torch.from_numpy(emb_target).float().to(device)
    tquery = torch.from_numpy(emb_query).float().to(device)
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        pred = order_model.predict(order_model(ttarget, tquery))
    return pred.item()
# Build the option namespace from the project's parsers. parse_args("") parses
# an empty argument list, so every option keeps its default value.
parser = argparse.ArgumentParser()
utils.parse_optimizer(parser)
parse_encoder(parser)
args = parser.parse_args("")

# The checkpoint path is resolved from the parent directory, which is the same
# directory pushed onto sys.path at the top of this script.
args.model_path = os.path.join("..", args.model_path)
# NOTE(review): this reports args.dataset, but the dataset loaded afterwards is
# hard-coded to "enzymes"; confirm which one is intended.
print("Using dataset {}".format(args.dataset))

# Now we load the model and a dataset to analyze embeddings on, here ENZYMES.
model = models.OrderEmbedder(1, args.hidden_dim, args)
model.to(utils.get_device())
model.load_state_dict(
    torch.load(args.model_path, map_location=utils.get_device())
)
model.eval()  # inference mode; independent of the weight loading above
# Sample 10 neighbourhoods (size argument 29) from the ENZYMES training split
# and embed them all in one batch.
train, test, task = data.load_dataset("enzymes")
motifs = []
for _ in range(10):
    graph, neigh = utils.sample_neigh(train, 29)
    motifs.append(graph.subgraph(neigh))
batch = utils.batch_nx_graphs(motifs)
with torch.no_grad():  # inference only; no autograd graph needed
    embs = model.emb_model(batch).detach().cpu().numpy()
max_n_edges = max(len(m.edges) for m in motifs)
max_n_nodes = max(len(m) for m in motifs)

# A single index selects both the target graph and its embedding, so the two
# cannot drift apart.
TARGET_IDX = 4
target = motifs[TARGET_IDX]
emb_target = embs[TARGET_IDX]
print('target nodes:', target.nodes)
# nx.draw(target, with_labels=True)

# Positive example: a neighbourhood (size argument 7) sampled from inside the
# target itself, so it should be a subgraph of the target by construction.
# Presumably sample_neigh returns a graph from the list it is given; confirm
# in common/utils.py.
graph, neigh = utils.sample_neigh([target], 7)
query_graph = graph.subgraph(neigh)  # build the node-induced view once
query = utils.batch_nx_graphs([query_graph])
with torch.no_grad():  # inference only
    emb_query = model.emb_model(query).detach().cpu().numpy()
# nx.draw(query_graph, with_labels=True)
print('subgraph nodes:', query_graph.nodes)
print('Subgraph score', subgraph_score(emb_target, emb_query))
# Intended negative example. NOTE: node IDs play no role in subgraph matching;
# only the structure does. Gno is a 3-node path (50-55-56). A path like this
# occurs as a (non-induced) subgraph of any graph that has a node with two or
# more neighbours, so if `target` is connected with at least 3 nodes, Gno is
# almost certainly a subgraph of it despite its name. The ground-truth check
# below makes this visible before the score is read. The check assumes
# `target` is an undirected nx.Graph; use DiGraphMatcher otherwise.
Gno = nx.Graph()
Gno.add_edges_from([(50, 55), (56, 55)])
# nx.draw(Gno, with_labels=True)
query = utils.batch_nx_graphs([Gno])
with torch.no_grad():  # inference only
    emb_query = model.emb_model(query).detach().cpu().numpy()
print('Non subgraph nodes:', Gno.nodes)
matcher = nx.algorithms.isomorphism.GraphMatcher(target, Gno)
print('Ground truth: non-induced subgraph of target?', matcher.subgraph_is_monomorphic())
print('Ground truth: induced subgraph of target?', matcher.subgraph_is_isomorphic())
print('Non subgraph score', subgraph_score(emb_target, emb_query))