Skip to content

Commit fcd8a02

Browse files
committed
added tests with real pytorch models and real datasets
1 parent 08bda59 commit fcd8a02

24 files changed

Lines changed: 185 additions & 267 deletions

docs/examples/basic_example_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ def __getitem__(self, index):
115115

116116
# Create a set of rotation transformations
117117
from tmeasures.transformations.parameters import UniformRotation
118-
from tmeasures.pytorch.transformations.affine import AffineGenerator
118+
from tmeasures.pytorch.transformations.affine import AffineTransformationSet
119119

120120
rotation_parameters = UniformRotation(n=128, angles=1.0)
121-
transformations = AffineGenerator(r=rotation_parameters)
121+
transformations = AffineTransformationSet(r=rotation_parameters)
122122

123123

124124

docs/examples/resnet_pretrained_rotation_invariance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def __getitem__(self, index):
4949

5050
# Create a set of 128 rotation transformations, with angles from 0 to 360
5151
from tmeasures.transformations.parameters import UniformRotation
52-
from tmeasures.pytorch.transformations.affine import AffineGenerator
52+
from tmeasures.pytorch.transformations.affine import AffineTransformationSet
5353

5454
rotation_parameters = UniformRotation(n=128, angles=1.0)
55-
transformations = AffineGenerator(r=rotation_parameters)
55+
transformations = AffineTransformationSet(r=rotation_parameters)
5656

5757

5858
# evaluate measure

tests/old_prototypes/activations.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

tests/old_prototypes/activations2.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

tests/old_prototypes/activations3.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

tests/old_prototypes/activations4.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

tests/old_prototypes/activations5.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

tests/pytorch/test_real_models.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import itertools
2+
from pathlib import Path
3+
import tempfile
4+
from typing import Callable, Type, TypeVar
5+
import torch
6+
import pytest
7+
import torchvision
8+
import tmeasures as tm
9+
import numpy as np
10+
from numpy.testing import assert_allclose
11+
12+
import numpy as np
13+
from sklearn.model_selection import train_test_split
14+
from torch.utils.data import Subset
15+
16+
from tmeasures import transformations
17+
from .utils import ConstantModel,ConstantDataset,RandomModel,RepeatedIdentitySet
18+
from tmeasures.pytorch.transformations.affine import RotationTransformationSet,ScaleTransformationSet,TranslationTransformationSet
19+
20+
transformation_sets = {
21+
"rotation":RotationTransformationSet([0.0,.25,.5,.75,]),
22+
"scale":ScaleTransformationSet([(0.5,0.5),(0.5,1.5),(1.5,0.5),(1.5,1.5),(1.0,1.0)]),"translation":TranslationTransformationSet([(0.5,0),(0.0,0.5),(-0.5,0),(0.0,-0.5),(0.0,0.0)])
23+
}
24+
25+
import dataclasses
26+
27+
@dataclasses.dataclass
28+
class Fixture:
29+
model:Callable[[torch.device,torch.utils.data.Dataset],torch.nn.Module]
30+
measure:tm.pytorch.PyTorchMeasure
31+
transformations:tuple[str,tm.pytorch.transformations.TransformationSet]
32+
dataset:torch.utils.data.Dataset
33+
34+
35+
def model_loader(model_function,**kwargs):
36+
def loader(device,dataset):
37+
model = model_function(**kwargs)
38+
model = model.to(device)
39+
model.eval()
40+
return model
41+
return loader
42+
43+
44+
def options():
45+
datasets = [cifar10()]
46+
models = [
47+
model_loader(torchvision.models.resnet18,weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1),
48+
]
49+
average_fm=tm.pytorch.AverageFeatureMaps()
50+
measures = [
51+
tm.pytorch.NormalizedVarianceInvariance(average_fm),
52+
tm.pytorch.NormalizedVarianceInvariance(),
53+
]
54+
fixtures = [Fixture(*f) for f in itertools.product(models,measures,transformation_sets.items(),datasets)]
55+
return fixtures
56+
57+
T = TypeVar('T', bound=torch.utils.data.Dataset)
58+
59+
def dataset_for_tmeasures(dataset_class:Type[T],mean,std,N=20):
60+
tmp_path = tempfile.gettempdir()
61+
preprocessing_transforms = torchvision.transforms.Compose([
62+
torchvision.transforms.ToTensor(),
63+
torchvision.transforms.Normalize(mean=mean,
64+
std=std)
65+
])
66+
data_path = Path(tmp_path).expanduser()
67+
# Iterate over images from CIFAR10 without labels
68+
class Dataset(dataset_class,torchvision.datasets.VisionDataset):
69+
def __getitem__(self, index):
70+
x, y = super().__getitem__(index)
71+
return x
72+
dataset_nolabels = Dataset(data_path, train=False, download=True,
73+
transform=preprocessing_transforms,)
74+
# Get a subset of the whole dataset; no need for a large number of samples
75+
# to calculate the invariance
76+
indices, _ = train_test_split(np.arange(len(dataset_nolabels)), train_size=N, stratify=dataset_nolabels.targets,random_state=0)
77+
dataset_nolabels = Subset(dataset_nolabels, indices)
78+
return dataset_nolabels
79+
80+
def cifar10(N=20):
81+
return dataset_for_tmeasures(torchvision.datasets.CIFAR10,
82+
mean = [0.4914, 0.4822, 0.4465],
83+
std=[0.2470, 0.2435, 0.2616],
84+
N=N
85+
)
86+
def mnist(N=20):
87+
class RGBMNIST(torchvision.datasets.MNIST):
88+
def __getitem__(self, index):
89+
x, y = super().__getitem__(index)
90+
x = x.repeat(3,1,1)
91+
return x,y
92+
return dataset_for_tmeasures(RGBMNIST,
93+
mean = [0.1307,],
94+
std=[0.3081],
95+
N=N
96+
)
97+
98+
99+
@pytest.mark.parametrize("f",options())
100+
def test_cifar10(f:Fixture,):
101+
torch.manual_seed(0)
102+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103+
cpu_device =torch.device("cpu")
104+
t_name,t_set = f.transformations
105+
dataset = f.dataset
106+
model = f.model(device,dataset)
107+
108+
activations_module = tm.pytorch.AutoActivationsModule(model)
109+
print("Activations in model:")
110+
print(activations_module.activation_names())
111+
112+
print(f"Evaluating measure {f.measure} with model {model} and transformations {t_name}...")
113+
# evaluate measure
114+
115+
options = tm.pytorch.PyTorchMeasureOptions(batch_size=16, num_workers=0,model_device=device,measure_device=device,data_device=cpu_device)
116+
measure_result:tm.pytorch.PyTorchMeasureResult = f.measure.eval(dataset,t_set,activations_module,options)
117+
measure_result = measure_result.numpy()
118+

tmeasures/np/activations_iterator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def samples_first(self):
3232
pass
3333

3434
@abc.abstractmethod
35-
def layer_names(self) -> List[str]:
35+
def activation_names(self) -> List[str]:
3636
pass
3737

3838
@abc.abstractmethod
@@ -52,7 +52,7 @@ def row_from_iterator(self,transformation_activations_iterator):
5252
Get a row of the ST matrix from the :param transformation_activations_iterator
5353
:return: row of the ST matrix with the activations for all the transformations of sample, and also the transformed samples
5454
'''
55-
activations=[[] for i in range(len(self.layer_names()))]
55+
activations=[[] for i in range(len(self.activation_names()))]
5656
x_transformed=[]
5757
for x_transformed_batch,activations_batch in transformation_activations_iterator:
5858
x_transformed.append(x_transformed_batch)

0 commit comments

Comments
 (0)