feat: initial implementation for rapidata #581

Open
begumcig wants to merge 4 commits into main from feat/rapiddata-metric

Conversation


@begumcig begumcig commented Mar 19, 2026

Description

This PR introduces a new stateful, asynchronous metric that submits generative model outputs (images, videos, etc.) to the Rapidata platform for human evaluation. Raters compare outputs across models on configurable criteria (e.g. image quality, prompt alignment), and results are retrieved later once enough votes are collected.

New Features Implemented

  • New AsyncEvaluationMixin: an abstract mixin defining the create_request() / retrieve_results() contract for metrics that delegate evaluation to external services asynchronously.
  • New CompositeMetricResult: a result type for metrics that return multiple labeled scores (e.g. one score per model), alongside a MetricResultProtocol to unify the interface between MetricResult and CompositeMetricResult.
  • EvaluationAgent updates: the agent now calls set_current_context(model_name=...) on all stateful metrics before each evaluation run, accepts a model_name parameter in evaluate(), and handles None returns from compute() (for async metrics that don't produce immediate results).
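The pieces above can be sketched roughly as follows. This is a simplified sketch: the names AsyncEvaluationMixin, CompositeMetricResult, create_request(), and retrieve_results() come from the PR, but the exact signatures and attributes here are assumptions, not the PR's actual code.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class AsyncEvaluationMixin(ABC):
    """Contract for metrics that hand evaluation off to an external service."""

    @abstractmethod
    def create_request(self, *args: Any, **kwargs: Any) -> None:
        """Register an evaluation task (e.g. a leaderboard) with the service."""

    @abstractmethod
    def retrieve_results(self, *args: Any, **kwargs: Any) -> Any:
        """Fetch results once the external evaluation has completed."""


class CompositeMetricResult:
    """Result type holding one labeled score per model."""

    def __init__(self, name: str, scores: Dict[str, float]) -> None:
        self.name = name
        self.scores = scores
```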

Details

RapidataMetric lifecycle:

  • Authenticate via client ID/secret or interactive browser login
  • Create a benchmark from a prompt list or PrunaDataModule (or attach an existing one via from_benchmark() / from_benchmark_id())
  • Create one or more leaderboards, each with a different evaluation instruction
  • For each model: accumulate outputs via update(), submit via compute()
  • Retrieve aggregated or per-leaderboard results once human evaluation completes
  • Media handling supports str (URLs/paths), PIL.Image, and torch.Tensor: tensors and PIL images are saved to a temp directory for upload, then cleaned up.
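The media-handling step in the last bullet can be sketched like this. The helper name and the save logic are hypothetical, not the PR's actual implementation; the torch.Tensor branch is elided to keep the sketch dependency-free.

```python
import os
import uuid


def to_uploadable_path(media, tmp_dir: str) -> str:
    """Normalize one output (URL/path string, PIL image, or tensor) to a path.

    Strings pass through unchanged; anything with a .save() method (PIL
    images) is written to tmp_dir for upload and cleaned up later by the
    caller.
    """
    if isinstance(media, str):  # URL or local path: upload as-is
        return media
    path = os.path.join(tmp_dir, f"media_{uuid.uuid4().hex}.png")
    if hasattr(media, "save"):  # PIL.Image duck-typing
        media.save(path)
        return path
    # a torch.Tensor would be converted and saved here, e.g. via
    # torchvision.utils.save_image(media, path); elided in this sketch
    raise TypeError(f"unsupported media type: {type(media)!r}")
```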

Other changes:

  • Added set_current_context() as a no-op hook on StatefulMetric so the agent can uniformly notify all metrics of the current model without changing the structure of update() and compute()
  • Typed EvaluationAgent result lists as MetricResultProtocol instead of concrete MetricResult
  • Added rapidata as an optional dependency extra in pyproject.toml
  • Updated CI to install the rapidata extra
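The no-op hook pattern looks roughly like this (a sketch; ModelAwareMetric is a hypothetical subclass, not part of the PR):

```python
class StatefulMetric:
    """Simplified base class: the hook defaults to a no-op, so existing
    stateful metrics need no changes."""

    def set_current_context(self, *args, **kwargs) -> None:
        return None  # no-op by default


class ModelAwareMetric(StatefulMetric):
    """Hypothetical metric that tracks which model produced its inputs."""

    def __init__(self) -> None:
        self.current_model = None

    def set_current_context(self, model_name: str, **kwargs) -> None:
        self.current_model = model_name
```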

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

Usage

from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric

# Initialization
rdm = RapidataMetric()
rdm.create_benchmark("my_bench_standalone", prompt_to_test)
rdm.create_request("Quality", instruction="Which video looks better?")

# Use with EvaluationAgent
agent = EvaluationAgent(
    request=[rdm],
    datamodule=datamodule,
)
results_a = agent.evaluate(model_a, model_name="model_a")
results_b = agent.evaluate(model_b, model_name="model_b")

# Use standalone
rdm.set_current_context("model_a")
rdm.update(prompts, model_a_gt, model_a_outputs)
rdm.compute()

rdm.set_current_context("model_b")
rdm.update(prompts, model_b_gt, model_b_outputs)
rdm.compute()

# Retrieve rankings once human evaluation has completed
rdm.retrieve_results()

@begumcig begumcig force-pushed the feat/rapiddata-metric branch 2 times, most recently from c5e302a to 312d056 on March 20, 2026 10:48
@begumcig begumcig force-pushed the feat/rapiddata-metric branch from 312d056 to aa2b198 on March 20, 2026 11:01
@begumcig begumcig marked this pull request as ready for review on March 20, 2026 11:03

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 3 potential issues.


Bugbot Autofix prepared fixes for all 3 issues found in the latest run.

  • ✅ Fixed: Unconditional import of optional dependency breaks metrics package
    • I wrapped the RapidataMetric import in metrics/__init__.py with a ModuleNotFoundError guard for missing rapidata and only export it when available.
  • ✅ Fixed: Missing newline between concatenated warning message strings
    • I added the missing newline separator in the warning string so the message now renders as separate sentences.
  • ✅ Fixed: Missing benchmark validation in compute() method
    • I added self._require_benchmark() at the start of compute() so missing benchmark state raises the intended ValueError.


Or push these changes by commenting:

@cursor push fc62a066b6
Preview (fc62a066b6)
diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py
--- a/src/pruna/evaluation/metrics/__init__.py
+++ b/src/pruna/evaluation/metrics/__init__.py
@@ -22,10 +22,15 @@
 from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric
 from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric
 from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore
-from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric
 from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric
 from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper
 
+try:
+    from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric
+except ModuleNotFoundError as e:
+    if e.name != "rapidata":
+        raise
+
 __all__ = [
     "MetricRegistry",
     "TorchMetricWrapper",
@@ -44,5 +49,7 @@
     "DinoScore",
     "SharpnessMetric",
     "AestheticLAION",
-    "RapidataMetric",
 ]
+
+if "RapidataMetric" in globals():
+    __all__.append("RapidataMetric")

diff --git a/src/pruna/evaluation/metrics/metric_rapiddata.py b/src/pruna/evaluation/metrics/metric_rapiddata.py
--- a/src/pruna/evaluation/metrics/metric_rapiddata.py
+++ b/src/pruna/evaluation/metrics/metric_rapiddata.py
@@ -299,6 +299,7 @@
         :meth:`retrieve_granular_results` once enough votes have been
         collected.
         """
+        self._require_benchmark()
         self._require_model()
         if not self.media_cache:
             raise ValueError("No data accumulated. Call update() before compute().")
@@ -348,7 +349,7 @@
             if "ValidationError" in type(e).__name__:
                 pruna_logger.warning(
                     "The benchmark hasn't finished yet.\n "
-                    "Please wait for more votes and try again."
+                    "Please wait for more votes and try again.\n "
                     "Skipping."
                 )
                 return None


Comment @cursor review or bugbot run to trigger another review on this PR

from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric
from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric
from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore
from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric


Unconditional import of optional dependency breaks metrics package

High Severity

RapidataMetric is unconditionally imported in __init__.py, but metric_rapidata.py has top-level imports of rapidata (an optional dependency under [project.optional-dependencies]). This causes an ImportError for any user who hasn't installed the rapidata extra, breaking the entire pruna.evaluation.metrics package — including unrelated metrics like MetricRegistry, CMMD, DinoScore, etc. Other modules like benchmarks.py also import from this package and would break.


"The benchmark hasn't finished yet.\n "
"Please wait for more votes and try again."
"Skipping."
)

Missing newline between concatenated warning message strings

Low Severity

Adjacent string literals on lines 351–352 are implicitly concatenated without a separator, producing "Please wait for more votes and try again.Skipping.". A \n is likely intended before "Skipping." to match the formatting pattern used elsewhere in this message and throughout the file.
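The bug is easy to reproduce in isolation: Python concatenates adjacent string literals at compile time with no separator.

```python
# Adjacent string literals fuse with no separator, producing the run-on
# message described above; adding "\n" restores the intended line break.
broken = (
    "Please wait for more votes and try again."
    "Skipping."
)
fixed = (
    "Please wait for more votes and try again.\n"
    "Skipping."
)
assert broken == "Please wait for more votes and try again.Skipping."
assert fixed == "Please wait for more votes and try again.\nSkipping."
```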


"https://app.rapidata.ai/mri/benchmarks/%s",
self.current_benchmarked_model,
self.benchmark.id,
)

Missing benchmark validation in compute() method

Low Severity

The compute() method accesses self.benchmark.evaluate_model(...) and self.benchmark.id without calling _require_benchmark() first. Every other public method that accesses self.benchmark (create_request(), update(), retrieve_results(), retrieve_granular_results()) properly calls _require_benchmark(). This inconsistency means a user who calls compute() without a benchmark configured gets a confusing AttributeError on NoneType instead of the clear ValueError produced by _require_benchmark().
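The fail-fast guard pattern in question looks roughly like this (a minimal sketch; the method names match the PR, but the bodies are illustrative):

```python
class BenchmarkBackedMetric:
    """Sketch of the guard pattern: validate state up front so the caller
    gets a clear error instead of an AttributeError on None."""

    def __init__(self) -> None:
        self.benchmark = None

    def _require_benchmark(self) -> None:
        if self.benchmark is None:
            raise ValueError(
                "No benchmark configured. Call create_benchmark() or "
                "from_benchmark() before this method."
            )

    def compute(self):
        self._require_benchmark()  # clear ValueError, not AttributeError
        return self.benchmark.evaluate_model()
```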


Member

@davidberenstein1957 davidberenstein1957 left a comment


Thanks for the PR, left some comments. I feel that some documentation is required for this too, right? Or is this generated automatically?

"""

media_cache: List[torch.Tensor | PIL.Image.Image | str]
prompt_cache: List[str]


we don't do image editing yet?

Member Author


hmm maybe I don't understand the question, but rapidata doesn't support image editing models, so the input can only be "prompts", which makes it not possible to support image editing!

Member

@davidberenstein1957 davidberenstein1957 Mar 20, 2026


So, currently it seems we can pass prompts and media as generations, but I understood they also support image-editing tasks where we'd pass custom formatting, but this is not something we want to support for now?

default_call_type: str = "x_y"
higher_is_better: bool = True
metric_name: str = METRIC_RAPIDATA
runs_on: List[str] = ["cpu", "cuda"]


why do we need cuda?

Member Author


We don't need it but we support it! do you think it should be only supported in cpu?


It can be both in this case, or should it require CUDA? I assumed it was a recommendation, but was not sure.

Member


does it make sense to add something like runs_on "any"

Member Author


oh you are so correct, because by default the runs_on in StatefulMetric is configured to run on everything. So I could remove this parameter altogether actually!

vbench = [
"vbench-pruna; sys_platform != 'darwin'",
]
rapidata = [


alright, so we already start with the extra separation, nice!
@begumcig I know that we can also do something like.

evaluation = [
    "pruna[rapidata]",
    "pruna[vbench]",
]
could be nice to already start structuring like this, right?

Member Author


Hmm, I don't think vbench and rapidata have a lot of shared dependencies, so doesn't really make sense to me to group them together


I assumed it could be a shared dependency group for all evaluation metrics. You can add extras to extras, but perhaps you'd like to keep them separate?

self.benchmark.id,
)

def retrieve_results(self, *args, **kwargs) -> CompositeMetricResult | None:


It could be nice to add the functionality to wait till done, WDYT? It feels a bit harsh to fail outright without giving the option to do so.

Member Author


Hmm, do you think it makes sense to wait 15 mins - 1 hour?
If we do not have any results yet, I am catching the error that rapidata is throwing, suppressing it, and returning None as a result. I thought it to be a middle ground between failing and waiting indefinitely. I think the user also gets some sort of notification from the platform when the benchmark is finished anyway; does it make sense to wait for a long time?


I'm not sure, but currently, you have to check, then get an error, then wait 15 minutes, and then manually check again.
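A middle ground could be an optional blocking wrapper that polls on the user's behalf. This is a hypothetical sketch, not part of the PR: it treats retrieve_results() returning None as "not finished yet", matching the error-suppression behavior described above.

```python
import time


def wait_for_results(metric, timeout_s: float = 3600.0, poll_every_s: float = 900.0):
    """Poll metric.retrieve_results() until it returns a result or timeout_s
    elapses; returns None if the benchmark never finished in time."""
    deadline = time.monotonic() + timeout_s
    while True:
        results = metric.retrieve_results()
        if results is not None:
            return results
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return None
        time.sleep(min(poll_every_s, remaining))
```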

@begumcig
Member Author

@cursor push fc62a06

Member

@simlang simlang left a comment


just a first quick pass, david already covered most of my comments, so waiting for the second iteration!
but already looks super amazing 💅

The keyword arguments to pass to the metric.
"""

def set_current_context(self, *args, **kwargs) -> None:
Member


yes agree, maybe also with a mixin, as you did for the async metric. some kind of MultiStateMetricMixin or something

----------
call_type : str
How to extract inputs from (x, gt, outputs). Default is "single".
client : RapidataClient | None
Member


similar to what we spoke about with benchmark and benchmark_id
maybe a single rapidata client, which can be RapidataClient | str | None? and then based on the type do different things
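The suggested single-parameter dispatch could look roughly like this. Everything here is a hypothetical sketch: RapidataClientStub and both helpers are placeholders, not real rapidata APIs.

```python
class RapidataClientStub:
    """Stand-in for rapidata's client class in this sketch."""


def _interactive_login():
    return RapidataClientStub()  # placeholder for a browser login flow


def _client_from_id(client_id: str):
    return RapidataClientStub()  # placeholder for id/secret authentication


def resolve_client(client):
    """Dispatch on the type of a single `client` argument:
    None -> interactive login, str -> authenticate by id,
    anything else -> assume it is already a ready client."""
    if client is None:
        return _interactive_login()
    if isinstance(client, str):
        return _client_from_id(client)
    return client
```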

default_call_type: str = "x_y"
higher_is_better: bool = True
metric_name: str = METRIC_RAPIDATA
runs_on: List[str] = ["cpu", "cuda"]
Member


does it make sense to add something like runs_on "any"


self.benchmark = self.client.mri.create_new_benchmark(name, prompts=data, **kwargs)

def create_request(
Member


agree that the naming is not optimal, as we already have requests which are something different - maybe something which includes async?

self._require_benchmark()
self.benchmark.create_leaderboard(name, instruction, show_prompt, **kwargs)

def set_current_context(self, model_name: str, **kwargs) -> None:
Member


i guess rapidata only supports comparing two models, right?
i had a multistate (better name maybe multimodel, idk) mixin comment before. does it make sense to have a max_number of models in there? so we keep all contexts and if a user tries to add a third model we say nope?

Member Author


You can actually compare as many models as you like!


self._cleanup_temp_media()

pruna_logger.info(
Member


i think i disagree, but just from the naming.
it is an info, but it probably should be printed by default - so like a warning, hahahaha

@begumcig begumcig force-pushed the feat/rapiddata-metric branch from 2f66795 to a9f0e40 Compare March 20, 2026 16:33
