Skip to content

Save logits with model predictions#74

Open
mihow wants to merge 4 commits intomainfrom
feat/add-logits
Open

Save logits with model predictions#74
mihow wants to merge 4 commits intomainfrom
feat/add-logits

Conversation

@mihow
Copy link
Copy Markdown
Collaborator

@mihow mihow commented Mar 13, 2025

No description provided.

@mihow
Copy link
Copy Markdown
Collaborator Author

mihow commented Mar 13, 2025

Take it away @rhine3!

@mihow
Copy link
Copy Markdown
Collaborator Author

mihow commented Mar 26, 2026

Code review

Found 1 issue:

  1. save_results methods expect 3-tuples but some post_process_batch implementations still return 2-tuples, causing ValueError at runtime. The PR updates Resnet50Classifier.post_process_batch() to return (label, score, logits) 3-tuples and updates both BinaryClassifier.save_results() and SpeciesClassifier.save_results() to unpack 3 values. However, EfficientNetClassifier.post_process_batch() and Resnet50Classifier_Turing.post_process_batch() still return 2-tuples (label, score). Concrete classes that inherit from these will crash:
    • MothNonMothClassifier2022(EfficientNetClassifier, BinaryClassifier) — MRO resolves post_process_batch to EfficientNetClassifier (2-tuple), but save_results to BinaryClassifier (expects 3-tuple)
    • TuringCostaRicaSpeciesClassifier, TuringAnguillaSpeciesClassifier, TuringUKSpeciesClassifier — all inherit post_process_batch from Resnet50Classifier_Turing (2-tuple), but save_results from SpeciesClassifier (expects 3-tuple)

EfficientNetClassifier.post_process_batch returns 2-tuples:

scores = predictions.max(axis=1).astype(float)
result = list(zip(labels, scores))
logger.debug(f"Post-processing result batch: {result}")
return result

Resnet50Classifier_Turing.post_process_batch returns 2-tuples:

scores = predictions.max(axis=1).astype(float)
result = list(zip(labels, scores))
logger.debug(f"Post-processing result batch: {result}")
return result

BinaryClassifier.save_results unpacks 3-tuples:

"model_name": self.name,
}
for label, score, _logits in batch_output
]
save_classified_objects(self.db_path, object_ids, classified_objects_data)

SpeciesClassifier.save_results unpacks 3-tuples:

"in_queue": True,
}
for label, top_score, logits in batch_output
]
save_classified_objects(self.db_path, object_ids, classified_objects_data)

🤖 Generated with Claude Code

- If this code review was useful, please react with 👍. Otherwise, react with 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant