@aedoardo — the reason F1 and Accuracy gave identical values is that with torch.argmax on binary outputs + average="macro" + a balanced dataset, they can mathematically converge. But the deeper issue was the old API's confusing num_classes / multiclass interaction.
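To see why the two metrics can coincide, here is a small pure-Python sketch (toy data, helper names are mine): with a balanced label distribution and the same number of errors in each direction, the per-class F1 scores both equal the overall accuracy, so macro-F1 equals accuracy exactly.

```python
# Toy demonstration (hypothetical data): accuracy == macro-F1 when the
# binary dataset is balanced and the errors are symmetric.

def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def f1_for_class(y_true, y_pred, cls):
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

def macro_f1(y_true, y_pred):
    # Unweighted mean of the per-class F1 scores (average="macro").
    return (f1_for_class(y_true, y_pred, 0) + f1_for_class(y_true, y_pred, 1)) / 2

# Balanced labels (four of each class), one error in each direction.
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1, 1, 0]

print(accuracy(y_true, y_pred))  # 0.75
print(macro_f1(y_true, y_pred))  # 0.75
```

Break the symmetry (e.g. two false positives, zero false negatives) and the two numbers separate again.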

The fix with today's API (v1.9.0):

from torchmetrics.classification import BinaryF1Score, BinaryAccuracy
from torchmetrics import MetricCollection

metrics = MetricCollection({
    "acc": BinaryAccuracy(),
    "f1": BinaryF1Score(),
})

# In your step:
y_hat = self(x)                     # (N, 2) logits
loss = F.cross_entropy(y_hat, y)
probs = y_hat.softmax(dim=1)[:, 1]  # probability of positive class
metric_dict = metrics(probs, y)     # {"acc": ..., "f1": ...}
self.log_dict(metric_dict, prog_bar=True)

Answer selected by Borda