1+ import itertools
2+ from pathlib import Path
3+ import tempfile
4+ from typing import Callable , Type , TypeVar
5+ import torch
6+ import pytest
7+ import torchvision
8+ import tmeasures as tm
9+ import numpy as np
10+ from numpy .testing import assert_allclose
11+
12+ import numpy as np
13+ from sklearn .model_selection import train_test_split
14+ from torch .utils .data import Subset
15+
16+ from tmeasures import transformations
17+ from .utils import ConstantModel ,ConstantDataset ,RandomModel ,RepeatedIdentitySet
18+ from tmeasures .pytorch .transformations .affine import RotationTransformationSet ,ScaleTransformationSet ,TranslationTransformationSet
19+
20+ transformation_sets = {
21+ "rotation" :RotationTransformationSet ([0.0 ,.25 ,.5 ,.75 ,]),
22+ "scale" :ScaleTransformationSet ([(0.5 ,0.5 ),(0.5 ,1.5 ),(1.5 ,0.5 ),(1.5 ,1.5 ),(1.0 ,1.0 )]),"translation" :TranslationTransformationSet ([(0.5 ,0 ),(0.0 ,0.5 ),(- 0.5 ,0 ),(0.0 ,- 0.5 ),(0.0 ,0.0 )])
23+ }
24+
25+ import dataclasses
26+
27+ @dataclasses .dataclass
28+ class Fixture :
29+ model :Callable [[torch .device ,torch .utils .data .Dataset ],torch .nn .Module ]
30+ measure :tm .pytorch .PyTorchMeasure
31+ transformations :tuple [str ,tm .pytorch .transformations .TransformationSet ]
32+ dataset :torch .utils .data .Dataset
33+
34+
35+ def model_loader (model_function ,** kwargs ):
36+ def loader (device ,dataset ):
37+ model = model_function (** kwargs )
38+ model = model .to (device )
39+ model .eval ()
40+ return model
41+ return loader
42+
43+
44+ def options ():
45+ datasets = [cifar10 ()]
46+ models = [
47+ model_loader (torchvision .models .resnet18 ,weights = torchvision .models .ResNet18_Weights .IMAGENET1K_V1 ),
48+ ]
49+ average_fm = tm .pytorch .AverageFeatureMaps ()
50+ measures = [
51+ tm .pytorch .NormalizedVarianceInvariance (average_fm ),
52+ tm .pytorch .NormalizedVarianceInvariance (),
53+ ]
54+ fixtures = [Fixture (* f ) for f in itertools .product (models ,measures ,transformation_sets .items (),datasets )]
55+ return fixtures
56+
57+ T = TypeVar ('T' , bound = torch .utils .data .Dataset )
58+
59+ def dataset_for_tmeasures (dataset_class :Type [T ],mean ,std ,N = 20 ):
60+ tmp_path = tempfile .gettempdir ()
61+ preprocessing_transforms = torchvision .transforms .Compose ([
62+ torchvision .transforms .ToTensor (),
63+ torchvision .transforms .Normalize (mean = mean ,
64+ std = std )
65+ ])
66+ data_path = Path (tmp_path ).expanduser ()
67+ # Iterate over images from CIFAR10 without labels
68+ class Dataset (dataset_class ,torchvision .datasets .VisionDataset ):
69+ def __getitem__ (self , index ):
70+ x , y = super ().__getitem__ (index )
71+ return x
72+ dataset_nolabels = Dataset (data_path , train = False , download = True ,
73+ transform = preprocessing_transforms ,)
74+ # Get a subset of the whole dataset; no need for a large number of samples
75+ # to calculate the invariance
76+ indices , _ = train_test_split (np .arange (len (dataset_nolabels )), train_size = N , stratify = dataset_nolabels .targets ,random_state = 0 )
77+ dataset_nolabels = Subset (dataset_nolabels , indices )
78+ return dataset_nolabels
79+
80+ def cifar10 (N = 20 ):
81+ return dataset_for_tmeasures (torchvision .datasets .CIFAR10 ,
82+ mean = [0.4914 , 0.4822 , 0.4465 ],
83+ std = [0.2470 , 0.2435 , 0.2616 ],
84+ N = N
85+ )
86+ def mnist (N = 20 ):
87+ class RGBMNIST (torchvision .datasets .MNIST ):
88+ def __getitem__ (self , index ):
89+ x , y = super ().__getitem__ (index )
90+ x = x .repeat (3 ,1 ,1 )
91+ return x ,y
92+ return dataset_for_tmeasures (RGBMNIST ,
93+ mean = [0.1307 ,],
94+ std = [0.3081 ],
95+ N = N
96+ )
97+
98+
99+ @pytest .mark .parametrize ("f" ,options ())
100+ def test_cifar10 (f :Fixture ,):
101+ torch .manual_seed (0 )
102+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
103+ cpu_device = torch .device ("cpu" )
104+ t_name ,t_set = f .transformations
105+ dataset = f .dataset
106+ model = f .model (device ,dataset )
107+
108+ activations_module = tm .pytorch .AutoActivationsModule (model )
109+ print ("Activations in model:" )
110+ print (activations_module .activation_names ())
111+
112+ print (f"Evaluating measure { f .measure } with model { model } and transformations { t_name } ..." )
113+ # evaluate measure
114+
115+ options = tm .pytorch .PyTorchMeasureOptions (batch_size = 16 , num_workers = 0 ,model_device = device ,measure_device = device ,data_device = cpu_device )
116+ measure_result :tm .pytorch .PyTorchMeasureResult = f .measure .eval (dataset ,t_set ,activations_module ,options )
117+ measure_result = measure_result .numpy ()
118+
0 commit comments