
metrics function

Authored by Subham Sekhar Sahoo
mst.py
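    import numpy as np

    def metrics(preds, labels):
        # preds, labels: assumed square (batch, n, n) arrays; num_edges is the side length n.
        _, num_edges, _ = preds.shape
        # Flatten each n x n matrix into one row per example.
        preds = preds.reshape((-1, num_edges * num_edges))
        labels = labels.reshape((-1, num_edges * num_edges))
        # Nonzero wherever prediction and label disagree.
        edge_mismatches = np.abs(preds - labels)
        # Debug output: distinct gold-edge counts in the batch.
        print(np.unique(np.sum(labels, axis=1)))
        # Mean fraction of gold edges predicted correctly (nan if a row has no gold edges).
        accuracy = 1 - np.mean(np.sum(
            labels * edge_mismatches, axis=1) / np.sum(labels, axis=1))
        # Also return the fraction of examples whose matrix is predicted exactly.
        return accuracy, np.mean(np.sum(edge_mismatches, axis=1) == 0)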
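Assuming `preds` and `labels` hold 0/1 adjacency matrices, the first value returned by `metrics` is the per-example fraction of gold edges that were also predicted (a recall, despite the name `accuracy`), averaged over the batch. Extra predicted edges do not lower it; only the second value, the exact-match rate, penalizes them. Because it divides by each example's gold-edge count, an example with no gold edges yields nan. The following is a minimal sketch of the same two quantities with that case guarded. The name `edge_recall_and_exact_match`, the choice to score empty examples as fully recovered, and the 3-node toy graphs are all illustrative assumptions, not part of the original snippet.

```python
import numpy as np


def edge_recall_and_exact_match(preds, labels):
    """Hypothetical variant of `metrics`: no debug print, empty rows guarded.

    preds, labels: arrays of shape (batch, n, n) holding 0/1 adjacency matrices.
    Returns (mean fraction of gold edges that were predicted,
             fraction of examples whose matrix is predicted exactly).
    """
    batch = preds.shape[0]
    p = preds.reshape(batch, -1).astype(float)
    y = labels.reshape(batch, -1).astype(float)
    mismatches = np.abs(p - y)
    gold = y.sum(axis=1)                   # gold edges per example
    missed = (y * mismatches).sum(axis=1)  # gold edges the prediction lacks
    # Assumption: rows with no gold edges count as fully recovered
    # instead of producing 0/0 = nan.
    recall = np.where(gold > 0, 1.0 - missed / np.maximum(gold, 1.0), 1.0)
    exact = mismatches.sum(axis=1) == 0
    return recall.mean(), exact.mean()


# Toy batch of two 3-node graphs (symmetric matrices, so each edge appears twice).
gold_adj = np.array([
    [[0, 1, 0], [1, 0, 1], [0, 1, 0]],  # edges 0-1, 1-2
    [[0, 1, 1], [1, 0, 0], [1, 0, 0]],  # edges 0-1, 0-2
])
pred_adj = np.array([
    [[0, 1, 0], [1, 0, 1], [0, 1, 0]],  # exact match
    [[0, 1, 0], [1, 0, 1], [0, 1, 0]],  # edges 0-1, 1-2: misses 0-2
])

recall, exact = edge_recall_and_exact_match(pred_adj, gold_adj)
print(recall, exact)  # 0.75 0.5
```

On this batch the original `metrics` gives the same first value, 1 - mean([0, 0.5]) = 0.75; the two functions differ only on examples with no gold edges.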