metrics function
The snippet can be accessed without any authentication.
Authored by
Subham Sekhar Sahoo
mst.py 415 B
def metrics(preds, labels):
    """Compute edge-level and exact-match accuracy for a batch of matrices.

    Each sample is flattened to one row and compared entry by entry.

    Args:
        preds: Array-like with exactly three dimensions; the first axis is
            treated as the batch axis and the last two are flattened.
        labels: Array-like holding the targets. It is reshaped to the same
            number of entries per row as ``preds``.

    Returns:
        A ``(accuracy, exact_match)`` tuple where

        * ``accuracy`` is one minus the batch mean of, per sample, the
          label-weighted absolute mismatch divided by the sum of the labels.
          For 0/1 inputs this is the fraction of labelled edges that were
          predicted. A sample whose labels sum to zero contributes an error
          of 0 rather than NaN.
        * ``exact_match`` is the fraction of samples whose predictions equal
          the labels in every entry.

    Raises:
        ValueError: If ``preds`` does not have exactly three dimensions, or
            if ``labels`` cannot be reshaped to match.
    """
    # Work in float64 so boolean inputs do not raise and unsigned integers
    # do not wrap around (0 - 1 == 255 for uint8) in the subtraction below.
    preds = np.asarray(preds, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)

    # Use both trailing axes so non-square matrices are flattened per sample
    # instead of being silently regrouped across samples.
    _, num_rows, num_cols = preds.shape
    entries_per_sample = num_rows * num_cols
    preds = preds.reshape((-1, entries_per_sample))
    labels = labels.reshape((-1, entries_per_sample))

    edge_mismatches = np.abs(preds - labels)

    label_totals = np.sum(labels, axis=1)
    missed = np.sum(labels * edge_mismatches, axis=1)
    # A sample with no labelled edges has nothing to miss: give it an error
    # of 0 instead of letting 0 / 0 turn the whole batch mean into NaN.
    error_rates = np.divide(
        missed,
        label_totals,
        out=np.zeros_like(missed),
        where=label_totals != 0,
    )

    accuracy = 1 - np.mean(error_rates)
    exact_match = np.mean(np.sum(edge_mismatches, axis=1) == 0)
    return accuracy, exact_match
Please register or sign in to comment