-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampling.py
More file actions
62 lines (54 loc) · 2.25 KB
/
sampling.py
File metadata and controls
62 lines (54 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from scheduler import get_index_from_list
from dataset import show_tensor_image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from model import model
# Spatial size of generated samples: sample_plot_image draws noise of shape
# (1, 3, IMG_SIZE, IMG_SIZE), i.e. height == width == IMG_SIZE.
IMG_SIZE = 64

# Device string used for sampling: "cuda" when PyTorch reports a usable GPU,
# otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
@torch.no_grad()
def sample_timestep(x, t, betas):
    """Run one reverse-diffusion step: denoise ``x`` at timestep ``t``.

    The module-level ``model`` is called to predict the noise contained in
    ``x``. That prediction is used to compute the posterior mean of the
    less-noisy image. For every sample whose timestep is non-zero, Gaussian
    noise scaled by the posterior standard deviation is then added back. At
    timestep 0 the mean alone is returned.

    Args:
        x: Current noisy image tensor. The leading dimension is assumed to be
            the batch dimension (TODO confirm against callers).
        t: Long tensor of timestep indices. Presumably shape ``(batch,)`` or
            ``(1,)``, as built by ``sample_plot_image``; the per-sample noise
            mask below broadcasts in either case.
        betas: 1-D noise-schedule tensor indexed by timestep.

    Returns:
        Tensor with the same shape as ``x``: the sample for the previous
        timestep.
    """
    # Closed-form schedule terms. These are rebuilt from ``betas`` on every
    # call (a cumprod over the whole schedule), so each step costs O(len(betas)).
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Shift the cumulative product right by one, with 1.0 for the first step.
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    # Gather the per-timestep coefficients and reshape them to broadcast over x.
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Mean estimate: remove the model's noise prediction from the current image.
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    # ``if t == 0`` only works for a single-element tensor. For batches, work
    # per sample: samples at t == 0 get no noise, the rest get posterior noise.
    nonzero = t != 0
    if not bool(nonzero.any()):
        # Every sample is at the final step: return the mean (no noise drawn).
        return model_mean
    nonzero_mask = nonzero.to(x.dtype).reshape(-1, *([1] * (x.dim() - 1)))
    noise = torch.randn_like(x)
    return model_mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise
@torch.no_grad()
def sample_plot_image(image_name, T, betas, num_images=10):
    """Generate one image by reverse diffusion and plot snapshots along the way.

    Starts from Gaussian noise of shape ``(1, 3, IMG_SIZE, IMG_SIZE)`` on
    ``device``. It then calls ``sample_timestep`` for ``t = T-1, ..., 0``,
    clamping to ``[-1, 1]`` after each step. Every ``stepsize`` steps the
    current image is drawn into a panel of a single-row figure via
    ``show_tensor_image``. The figure is closed before returning.

    Args:
        image_name: Forwarded unchanged to ``show_tensor_image`` for every
            snapshot (presumably an output path — confirm in ``dataset``).
        T: Number of diffusion timesteps to run.
        betas: Noise schedule forwarded to ``sample_timestep``.
        num_images: Maximum number of snapshot panels (default 10, as before).

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    plt.figure(figsize=(15, 15))
    plt.axis('off')

    # Ceiling division keeps every panel index within 1..num_images. The old
    # int(T / num_images) could yield index num_images + 1 when T is not a
    # multiple of num_images, and 0 (modulo by zero) when T < num_images.
    # For T divisible by num_images the value is unchanged.
    stepsize = max(1, (T + num_images - 1) // num_images)
    try:
        for i in reversed(range(T)):
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img = sample_timestep(img, t, betas)
            # Clamp to keep the sample within the data's natural range [-1, 1].
            img = torch.clamp(img, -1.0, 1.0)
            if i % stepsize == 0:
                # Panel 1 holds the final (t == 0) image; higher panels are noisier.
                plt.subplot(1, num_images, i // stepsize + 1)
                show_tensor_image(img.detach().cpu(), image_name)
    finally:
        # Always close (even on error) to avoid matplotlib's
        # 'RuntimeWarning: More than 20 figures have been opened'.
        plt.close()