def metrics(preds, labels):
  """Score predicted edge sets against ground-truth edge sets.

  Each sample's edges are flattened into one row, and two batch-level
  scores are computed from the per-edge absolute differences.

  Args:
    preds: Array-like whose first axis indexes samples; all remaining axes
      (e.g. an ``(n, n)`` edge matrix) are flattened per sample.
    labels: Array-like with the same number of elements as ``preds``; it is
      reshaped to the flattened layout of ``preds``. Presumably 0/1 edge
      indicators -- TODO confirm with callers.

  Returns:
    A tuple ``(accuracy, exact_match)``:
      accuracy: ``1 - mean(missed / true)``, where per sample ``true`` is
        ``sum(labels)`` and ``missed`` is ``sum(labels * |preds - labels|)``.
        With binary inputs this is the mean recall on true edges (false
        positives are not penalised). Samples with no true edges are
        skipped because their ratio is 0/0; if every sample is skipped the
        result is NaN.
      exact_match: Fraction of samples whose predictions equal the labels
        on every edge (NaN for an empty batch).
  """
  # Work in float64: subtracting unsigned ints would wrap around
  # (0 - 1 -> 255 for uint8) and subtracting bools raises TypeError.
  preds = np.asarray(preds, dtype=np.float64)
  num_samples = preds.shape[0]
  # One row per sample; works for square and non-square trailing shapes.
  edges_per_sample = int(np.prod(preds.shape[1:]))
  preds = preds.reshape(num_samples, edges_per_sample)
  labels = np.asarray(labels, dtype=np.float64).reshape(preds.shape)
  edge_mismatches = np.abs(preds - labels)

  num_true_edges = labels.sum(axis=1)
  # Mismatches weighted by the label, i.e. only errors on true edges count.
  missed_true_edges = (labels * edge_mismatches).sum(axis=1)
  # A sample without true edges has an undefined ratio (0/0); exclude it
  # rather than letting one NaN poison the whole batch mean.
  has_true_edges = num_true_edges > 0
  if has_true_edges.any():
    miss_rate = (missed_true_edges[has_true_edges]
                 / num_true_edges[has_true_edges])
    accuracy = 1 - np.mean(miss_rate)
  else:
    accuracy = np.nan

  if num_samples:
    exact_match = np.mean(edge_mismatches.sum(axis=1) == 0)
  else:
    exact_match = np.nan
  return accuracy, exact_match