def metrics(preds, labels):
    """Score predicted square matrices against ground-truth ones.

    Both inputs are flattened per sample to ``num_edges * num_edges``
    entries, where ``num_edges`` is the size of axis 1 of ``preds``.

    Args:
        preds: Array-like of shape ``(batch, num_edges, num_edges)``; the
            3-way unpack of ``preds.shape`` enforces rank 3. Presumably
            0/1 edge predictions (TODO confirm with callers); non-binary
            values contribute their absolute error.
        labels: Array-like with ``batch * num_edges * num_edges`` elements,
            reshaped the same way as ``preds``. Presumably a 0/1 (or bool)
            adjacency target.

    Returns:
        ``(accuracy, exact_match)``:

        * ``accuracy``: one minus the mean, over samples, of
          ``sum(labels * |preds - labels|) / sum(labels)``. For binary
          inputs this is the per-sample fraction of positive label entries
          that were predicted correctly, averaged over samples. Samples
          whose labels contain no positive entries are skipped. The value
          is NaN if every sample is skipped.
        * ``exact_match``: fraction of samples whose predictions match the
          labels exactly.
    """
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    _, num_edges, _ = preds.shape
    preds = preds.reshape((-1, num_edges * num_edges))
    labels = labels.reshape((-1, num_edges * num_edges))

    # Subtract in a signed floating dtype. With bool operands, `-` raises
    # TypeError, and unsigned ints wrap around (0 - 1 == 255). Float inputs
    # keep their own dtype, so their results are unchanged.
    work_dtype = np.result_type(preds.dtype, labels.dtype, np.float32)
    edge_mismatches = np.abs(preds.astype(work_dtype) - labels.astype(work_dtype))

    positives = np.sum(labels, axis=1)
    missed = np.sum(labels * edge_mismatches, axis=1)
    # A sample with no positive labels would give 0/0 = NaN and poison the
    # mean, so only samples that have something to recall are averaged.
    has_positives = positives > 0
    if np.any(has_positives):
        accuracy = 1 - np.mean(missed[has_positives] / positives[has_positives])
    else:
        accuracy = np.nan

    exact_match = np.mean(np.sum(edge_mismatches, axis=1) == 0)
    return accuracy, exact_match