-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathras.py
More file actions
727 lines (650 loc) · 25.2 KB
/
ras.py
File metadata and controls
727 lines (650 loc) · 25.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
from dataclasses import dataclass
from types import MethodType
import torch
from torch import nn
from torch import Tensor
from einops import rearrange
from comfy.ldm.flux.model import Flux
from comfy.ldm.hunyuan_video.model import HunyuanVideo
from comfy.ldm.flux.layers import DoubleStreamBlock, SingleStreamBlock, LastLayer
from comfy.ldm.wan.model import (
WanModel,
VaceWanModel,
CameraWanModel,
WanModel_S2V,
HumoWanModel,
WanAttentionBlock,
VaceWanAttentionBlock,
Head,
)
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.modules.attention import optimized_attention
from comfy.model_patcher import ModelPatcher
import comfy.model_management
def apply_pe(x: Tensor, pe: Tensor) -> Tensor:
    """
    Apply the positional-embedding coefficients in ``pe`` to ``x``.

    This is the PE application from flux.math.attention, pulled out into its
    own function so that keys can be cached *after* the PE has been applied.

    The last dimension of ``x`` is regrouped into pairs, ``x`` is cast to
    float32 for the multiplication, and the result is returned with the shape
    and dtype of the input.
    """
    original_shape, original_dtype = x.shape, x.dtype
    # (..., d) -> (..., d / 2, 1, 2): one trailing pair per component
    pairs = x.float().reshape(*original_shape[:-1], -1, 1, 2)
    rotated = pe[..., 0] * pairs[..., 0] + pe[..., 1] * pairs[..., 1]
    return rotated.reshape(*original_shape).to(original_dtype)
def take_attributes_from(source, target, keys):
    """
    Copy every attribute named in ``keys`` from ``source`` onto ``target``.

    The values are assigned as-is (no copy is made), so ``target`` ends up
    sharing the very same objects as ``source``. A name missing on ``source``
    raises ``AttributeError`` from ``getattr``.
    """
    for name in keys:
        value = getattr(source, name)
        setattr(target, name, value)
@dataclass
class RASConfig:
    """Tunable settings read by RASManager when deciding which tokens to skip."""

    # Steps with an index below this value never skip any token.
    warmup_steps: int = 4
    # After warmup, every `hydrate_every`-th step skips nothing (a full
    # hydrate); a falsy value disables these periodic hydrates.
    hydrate_every: int = 5
    # Fraction of tokens kept on a skipping step (skip ratio = 1 - sample_ratio).
    sample_ratio: float = 0.5
    # The per-token metric is multiplied by exp(starvation_scale * drop_count).
    starvation_scale: float = 0.1
    # Share of the skipped tokens taken from the high-metric end of the
    # ranking; the remainder comes from the low-metric end. RASManager asserts
    # that it lies in [0, 1].
    high_ratio: float = 1.0
class RASManager:
    """
    Coordinates the live indices, metrics, and model wrapping.

    One manager instance is shared by all wrapped layers of a model. The
    wrapped final layer reports the model output to `select_indices`, which
    decides which image tokens stay "live" for the next step; the wrapped
    blocks read that decision back through `live_indices`.
    """

    def __init__(self, config: RASConfig):
        """
        Store `config` and reset all per-run state.

        Raises:
            AssertionError: if `config.high_ratio` is outside [0, 1].
        """
        self.flipped_img_txt = False
        self.timestep: int = 0
        self.n_txt: int = 0
        self.n_img: int = 0
        self.cached_output: Tensor | None = None
        self.live_txt_indices: Tensor | None = None
        self.live_img_indices: Tensor | None = None
        self.drop_count: torch.Tensor | None = None
        self.config = config
        self.patch_size: list[int]
        self.model: Flux | HunyuanVideo
        assert (
            self.config.high_ratio >= 0 and self.config.high_ratio <= 1
        ), "High ratio should be in the range of [0, 1]"

    def wrap_layer(self, layer, first: bool = False, last: bool = False):
        """
        Return the RAS wrapper that matches the type of `layer`.

        Args:
            layer: an unwrapped block / final layer of a supported model.
            first: forwarded to wrappers that may drop tokens on entry.
            last: forwarded to wrappers that may restore tokens on exit.

        Raises:
            TypeError: if `layer` is already a wrapper, or is of a type this
                manager does not know how to wrap.
        """
        if isinstance(
            layer,
            (
                DoubleStreamBlockWrapper,
                SingleStreamBlockWrapper,
                LastLayerWrapper,
                WanAttentionBlockWrapper,
                VaceWanAttentionBlockWrapper,
                # HeadWrapper subclasses Head, so without this entry an
                # already-wrapped head would silently be wrapped a second time
                HeadWrapper,
            ),
        ):
            raise TypeError("Old wrapping wasn't removed!")
        if isinstance(layer, DoubleStreamBlock):
            wrapped = DoubleStreamBlockWrapper(layer, self, first)
        elif isinstance(layer, SingleStreamBlock):
            wrapped = SingleStreamBlockWrapper(layer, self, last)
        elif isinstance(layer, LastLayer):
            wrapped = LastLayerWrapper(layer, self)
        # note: VaceWanAttentionBlock is a subclass of WanAttentionBlock,
        # so we have to check for the vace block first
        elif isinstance(layer, VaceWanAttentionBlock):
            wrapped = VaceWanAttentionBlockWrapper(layer, self, first, last)
        elif isinstance(layer, WanAttentionBlock):
            wrapped = WanAttentionBlockWrapper(layer, self, first, last)
        elif isinstance(layer, Head):
            wrapped = HeadWrapper(layer, self)
        else:
            raise TypeError(f"Can't wrap layer of type {layer.__class__.__name__}")
        return wrapped

    def wrap_model(self, patcher: ModelPatcher):
        """
        Register object patches on `patcher` that swap in the RAS wrappers.

        Every block, the final layer / head, and `forward_orig` of
        `patcher.model.diffusion_model` are replaced through
        `patcher.add_object_patch`.

        Raises:
            TypeError: if the diffusion model is not a supported type.
        """
        wan_models = (
            WanModel,
            VaceWanModel,
            CameraWanModel,
            WanModel_S2V,
            HumoWanModel,
        )
        model = patcher.model.diffusion_model
        self.model = model
        if isinstance(model, Flux):
            self.patch_size = [model.patch_size, model.patch_size]
        elif isinstance(model, HunyuanVideo):
            self.patch_size = model.patch_size
        elif isinstance(model, wan_models):
            self.patch_size = model.patch_size
        else:
            raise TypeError(f"Can't wrap model of type {model.__class__.__name__}")

        # Handle different model architectures
        if isinstance(model, (Flux, HunyuanVideo)):
            # wrap the single and double blocks to have caching
            for i, v in enumerate(model.double_blocks):
                # first block has the special responsibility of removing tokens
                if i == 0:
                    self.flipped_img_txt = v.flipped_img_txt
                    layer = self.wrap_layer(v, first=True)
                else:
                    layer = self.wrap_layer(v)
                patcher.add_object_patch(f"diffusion_model.double_blocks.{i}", layer)
            for i, v in enumerate(model.single_blocks):
                # last block will put them back
                patcher.add_object_patch(
                    f"diffusion_model.single_blocks.{i}",
                    self.wrap_layer(v, last=(i == (len(model.single_blocks) - 1))),
                )
        elif isinstance(model, wan_models):
            # Wan models have a different structure with just 'blocks'
            self.flipped_img_txt = False  # Wan doesn't use the flipped pattern
            for i, v in enumerate(model.blocks):
                # first block has the special responsibility of removing tokens
                # last block will put them back
                layer = self.wrap_layer(
                    v, first=(i == 0), last=(i == (len(model.blocks) - 1))
                )
                patcher.add_object_patch(f"diffusion_model.blocks.{i}", layer)
            # models that carry vace blocks get those wrapped as well
            if hasattr(model, "vace_blocks"):
                for i, v in enumerate(model.vace_blocks):
                    layer = self.wrap_layer(
                        v,
                        first=(i == 0),
                        # we DONT put back the conditioning tokens
                        # because the last vace block is well before the last real block
                        last=False,
                    )
                    patcher.add_object_patch(f"diffusion_model.vace_blocks.{i}", layer)

        # wrap the forward_orig method, to be able to get the timestep
        forward_orig = model.forward_orig

        def new_forward(_self, *args, **kwargs):
            # Get transformer_options from kwargs for Wan models, or from args for Flux/Hunyuan
            if "transformer_options" in kwargs:
                transformer_options = kwargs["transformer_options"]
            else:
                transformer_options = args[-1]
            self.timestep = self.timestep_from_sigmas(
                transformer_options["sigmas"], transformer_options["sample_sigmas"]
            )
            if self.timestep == 0:
                # reset as much as possible
                self.live_img_indices = None
                self.live_txt_indices = None
                self.drop_count = None
            # forward_orig is already bound to the model, so _self is not passed on
            return forward_orig(*args, **kwargs)

        patcher.add_object_patch(
            "diffusion_model.forward_orig", MethodType(new_forward, model)
        )

        # wrap the last_layer, to be able to read the output and calculate the metric
        if isinstance(model, (Flux, HunyuanVideo)):
            patcher.add_object_patch(
                "diffusion_model.final_layer", self.wrap_layer(model.final_layer)
            )
        elif isinstance(model, wan_models):
            # Wan models use 'head' instead of 'final_layer'
            patcher.add_object_patch(
                "diffusion_model.head", self.wrap_layer(model.head)
            )

    @staticmethod
    def timestep_from_sigmas(sigmas: Tensor, sample_sigmas: Tensor):
        """
        Return the index in `sample_sigmas` of the entry closest to `sigmas`.

        `sigmas` may hold more than one element (presumably one identical
        value per batch element -- TODO confirm against the caller); only the
        first element is used, so a multi-element tensor no longer makes
        `.item()` raise.
        """
        # we assume that one element of sample_sigmas is exactly equal to sigmas
        # but we'll still check explicitly, using an argmin, in case of some loss of precision
        s = sigmas.flatten()[0].item()
        i = torch.argmin(torch.abs(sample_sigmas - s).flatten())
        return int(i.item())

    def skip_ratio(self, timestep: int) -> float:
        """
        Return the fraction of image tokens to skip at step `timestep`.

        The result is 0.0 during warmup and on every `hydrate_every`-th step
        after warmup; otherwise it is `1 - sample_ratio`.
        """
        if timestep < self.config.warmup_steps:
            return 0.0
        if self.config.hydrate_every:
            if (
                1 + timestep - self.config.warmup_steps
            ) % self.config.hydrate_every == 0:
                return 0.0
        result = 1.0 - self.config.sample_ratio
        return result

    def select_indices(self, diff: Tensor, timestep: int):
        """
        Choose the image tokens that stay live after step `timestep`.

        `diff` is the output reported by the wrapped final layer. Each token
        is scored by the standard deviation over its channels, averaged over
        the positions inside the patch and over the batch, then scaled up by
        how often the token has been dropped so far. Tokens are ranked by that
        score; the skipped ones are taken from the low and high ends of the
        ranking according to `config.high_ratio`, and the rest become
        `live_img_indices`. `live_img_indices` is set to None when nothing is
        skipped.

        Raises:
            TypeError: if the wrapped model type is not recognised.
        """
        if isinstance(self.model, Flux):
            # b (h w) (c ph pw) = model_out.shape
            metric = rearrange(
                diff,
                "b s (c ph pw) -> b s ph pw c",
                ph=self.patch_size[0],
                pw=self.patch_size[1],
            )
            metric = torch.std(metric, dim=-1).mean((-1, -2))
        elif isinstance(
            self.model,
            (
                HunyuanVideo,
                WanModel,
                VaceWanModel,
                CameraWanModel,
                WanModel_S2V,
                HumoWanModel,
            ),
        ):
            # b (h w) (c ph pw) = model_out.shape
            # Both HunyuanVideo and Wan models are video models with 3D patches
            metric = rearrange(
                diff,
                "b s (c pt ph pw ) -> b s pt ph pw c",
                pt=self.patch_size[0],
                ph=self.patch_size[1],
                pw=self.patch_size[2],
            )
            metric = torch.std(metric, dim=-1).mean((-1, -2, -3))
        else:
            raise TypeError("Unknown latent type!")

        # for batch size > 1, we pick separate indices per batch
        # for now, JUST FOR TESTING, we'll merge all the batches and use the indices that are the most relevant for all batches
        metric = metric.mean(dim=0)
        metric = metric.flatten()
        if self.drop_count is None:
            self.drop_count = torch.zeros(
                metric.shape, dtype=torch.int, device=diff.device
            )
        # hmm, what if we do a gaussian blur or some sort of spatial lowpass, to improve the spatial continuity of the patches?
        metric *= torch.exp(self.config.starvation_scale * self.drop_count)
        indices = torch.sort(metric, dim=-1, descending=False).indices

        skip_ratio = self.skip_ratio(timestep)
        if skip_ratio <= 0.01:
            # we're not dropping anything -- remove the live_indices
            # we use the value None to indicate a full hydrate
            self.live_img_indices = None
        else:
            n_tokens = len(metric)
            n_skip = skip_ratio * n_tokens
            low_bar = int(n_skip * (1 - self.config.high_ratio))
            high_bar = int(n_skip * self.config.high_ratio)
            # Slice with an explicit end index: a `-high_bar` bound would be
            # read as 0 when high_bar == 0, which marks every token as cached
            # and leaves no live token at all.
            high_start = n_tokens - high_bar
            cache_indices = torch.cat([indices[:low_bar], indices[high_start:]])
            self.live_img_indices = indices[low_bar:high_start]
            self.drop_count[cache_indices] += 1

        # TODO: for now we keep all txt tokens
        # in the future, we can probably do something like randomly keep a fraction of them
        if self.n_txt > 0:
            self.live_txt_indices = torch.arange(
                0, self.n_txt, dtype=torch.int, device=diff.device
            )

    def live_indices(self):
        """
        Return the live indices into the joint txt/img sequence.

        When either index set is None, `live_img_indices` is returned as-is
        (None means a full hydrate). Otherwise both sets are concatenated in
        the model's token order, with the second group shifted by the length
        of the first.
        """
        if self.live_img_indices is None or self.live_txt_indices is None:
            return self.live_img_indices
        if self.flipped_img_txt:
            result = torch.cat(
                (self.live_img_indices, self.live_txt_indices + self.n_img)
            )
        else:
            result = torch.cat(
                (self.live_txt_indices, self.live_img_indices + self.n_txt)
            )
        return result
class DoubleStreamBlockWrapper(DoubleStreamBlock):
    """
    Same as the DoubleStreamBlock, but uses a RASManager to do KV caching.

    The wrapper reuses the submodules of the block it replaces and keeps the
    post-PE keys and the values of the joint txt/img sequence in `k_cache` /
    `v_cache`. On a step where the manager reports live indices, only those
    tokens are recomputed and written into the caches.
    """

    def __init__(self, original: DoubleStreamBlock, manager: RASManager, first=False):
        """
        Args:
            original: the block whose attributes are taken over (not copied).
            manager: the shared RASManager that supplies the live indices.
            first: True for the first double block, which drops the non-live
                tokens from its inputs.
        """
        # only nn.Module is initialised; the layers come from `original`
        nn.Module.__init__(self)
        take_attributes_from(
            original,
            self,
            [
                "num_heads",
                "hidden_size",
                "img_mod",
                "img_norm1",
                "img_attn",
                "img_norm2",
                "img_mlp",
                "txt_mod",
                "txt_norm1",
                "txt_attn",
                "txt_norm2",
                "txt_mlp",
                "flipped_img_txt",
            ],
        )
        self.manager = manager
        # declared only; both caches are first assigned on a full hydrate in forward()
        self.k_cache: torch.Tensor
        self.v_cache: torch.Tensor
        self.first = first

    def forward(self, img, txt, vec, pe, attn_mask=None):
        """
        Run the block on the live tokens and return the updated `(img, txt)`.

        Queries are computed for the tokens passed in; attention reads keys
        and values from the caches, which cover the joint sequence of the last
        full hydrate.
        """
        # RAS: if this is the first doublestreamblock, then we should drop some of the img and txt tokens
        # NOTE: live_indices() is evaluated before n_txt / n_img are refreshed
        # below, so its offsets use the values stored by an earlier call
        idx = self.manager.live_indices()
        if self.first:
            self.manager.n_txt = txt.shape[1]
            self.manager.n_img = img.shape[1]
            img_idx = self.manager.live_img_indices
            txt_idx = self.manager.live_txt_indices
            if idx is not None:
                img = img[..., img_idx, :]
                txt = txt[..., txt_idx, :]

        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        # (B, L, 3, H, D) -> (3, B, H, L, D), unpacked into q, k, v
        img_q, img_k, img_v = img_qkv.view(
            img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1
        ).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(
            txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1
        ).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # RAS: KV Cache and Attention Call
        # select part of the PE
        if idx is not None:
            pe = pe[:, :, idx]
        # create queries, keys, and values
        # the concatenation order follows the model's img/txt token order
        if self.flipped_img_txt:
            queries = apply_pe(torch.cat((img_q, txt_q), dim=2), pe)
            keys = apply_pe(torch.cat((img_k, txt_k), dim=2), pe)
            values = torch.cat((img_v, txt_v), dim=2)
        else:
            queries = apply_pe(torch.cat((txt_q, img_q), dim=2), pe)
            keys = apply_pe(torch.cat((txt_k, img_k), dim=2), pe)
            values = torch.cat((txt_v, img_v), dim=2)
        # fill in the KV cache
        if idx is None:
            # full hydrate: the caches are replaced by this step's tensors
            self.k_cache = keys
            self.v_cache = values
        else:
            # partial update: overwrite only the live positions, in place
            self.k_cache[..., idx, :] = keys
            self.v_cache[..., idx, :] = values
        # actual attention call
        attn = optimized_attention(
            queries,
            self.k_cache,
            self.v_cache,
            # dim 1 of q after the permute above is the head count
            img_q.shape[1],
            skip_reshape=True,
            mask=attn_mask,
        )
        # split the joint attention output back into its img and txt parts
        if self.flipped_img_txt:
            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        else:
            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
        # End of RAS code

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        # NOTE: these are in-place updates of the `txt` tensor
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )

        # clamp NaN / inf to the float16 range limits
        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
class SingleStreamBlockWrapper(SingleStreamBlock):
    """
    Same as the SingleStreamBlock, but uses a RASManager to do KV caching.

    The wrapper reuses the submodules of the block it replaces and keeps the
    post-PE keys and the values in `k_cache` / `v_cache`. The last single
    block also merges its output back into the manager's cached full-length
    output.
    """

    def __init__(self, original: SingleStreamBlock, manager: RASManager, last=False):
        """
        Args:
            original: the block whose attributes are taken over (not copied).
            manager: the shared RASManager that supplies the live indices.
            last: True for the last single block, which restores the
                full-length sequence on exit.
        """
        nn.Module.__init__(self)
        # steal all the attributes from the SingleStreamBlock
        take_attributes_from(
            original,
            self,
            [
                "hidden_dim",
                "num_heads",
                "scale",
                "mlp_hidden_dim",
                "linear1",
                "linear2",
                "norm",
                "hidden_size",
                "pre_norm",
                "mlp_act",
                "modulation",
            ],
        )
        self.manager = manager
        # declared only; both caches are first assigned on a full hydrate in forward()
        self.k_cache: torch.Tensor
        self.v_cache: torch.Tensor
        self.last = last

    def forward(self, x, vec, pe, attn_mask=None):
        """
        Run the block on `x` and return the updated sequence.

        When `last` is set, the return value is the manager's cached output
        with the live positions overwritten by this step's result.
        """
        idx = self.manager.live_indices()
        mod, _ = self.modulation(vec)
        # linear1 produces the attention qkv and the mlp input in one pass
        qkv, mlp = torch.split(
            self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift),
            [3 * self.hidden_size, self.mlp_hidden_dim],
            dim=-1,
        )

        # (B, L, 3, H, D) -> (3, B, H, L, D), unpacked into q, k, v
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(
            2, 0, 3, 1, 4
        )
        q, k = self.norm(q, k, v)

        # RAS: KV Cache
        # keep only the PE entries of the live tokens
        if idx is not None:
            pe = pe[:, :, idx]
        q = apply_pe(q, pe)
        k = apply_pe(k, pe)
        # full hydrate
        if idx is None:
            self.k_cache = k
            self.v_cache = v
        # partial update
        else:
            self.k_cache[..., idx, :] = k
            self.v_cache[..., idx, :] = v
        attn = optimized_attention(
            q,
            self.k_cache,
            self.v_cache,
            # dim 1 of q after the permute above is the head count
            q.shape[1],
            skip_reshape=True,
            mask=attn_mask,
        )
        # End of RAS code

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # NOTE: in-place update of the `x` tensor
        x += mod.gate * output
        # clamp NaN / inf to the float16 range limits
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        if self.last:
            # put the relevant tokens back into the cached output
            if idx is None or self.manager.cached_output is None:
                self.manager.cached_output = x.clone()
            else:
                self.manager.cached_output[..., idx, :] = x
            # the returned tensor is the cache object itself, not a copy
            return self.manager.cached_output
        return x
class LastLayerWrapper(LastLayer):
    """
    A LastLayer that additionally hands every output to a RASManager.

    The layers themselves are taken over from the wrapped instance.
    """

    def __init__(self, original: LastLayer, manager: RASManager):
        # only nn.Module is initialised; the layers come from `original`
        nn.Module.__init__(self)
        inherited = ["norm_final", "linear", "adaLN_modulation"]
        take_attributes_from(original, self, inherited)
        self.manager = manager

    def forward(self, x, vec) -> Tensor:
        """Compute the regular LastLayer output and report it to the manager."""
        result = super().forward(x, vec)
        manager = self.manager
        # lets the manager pick the live tokens for the next step
        manager.select_indices(result, manager.timestep)
        return result
class WanAttentionBlockWrapper(WanAttentionBlock):
    """
    Same as the WanAttentionBlock, but uses a RASManager to do KV caching.

    The wrapper reuses the submodules of the block it replaces and keeps the
    post-RoPE keys and the values of the self-attention in `k_cache` /
    `v_cache`. The first block drops the non-live tokens; the last block
    merges its output back into the manager's cached full-length output.
    """

    def __init__(self, original: WanAttentionBlock, manager: RASManager, first, last):
        """
        Args:
            original: the block whose attributes are taken over (not copied).
            manager: the shared RASManager that supplies the live indices.
            first: True if this block should drop the non-live tokens on entry.
            last: True if this block should restore the full sequence on exit.
        """
        # only nn.Module is initialised; the layers come from `original`
        nn.Module.__init__(self)
        take_attributes_from(
            original,
            self,
            [
                "dim",
                "ffn_dim",
                "num_heads",
                "window_size",
                "qk_norm",
                "cross_attn_norm",
                "eps",
                "norm1",
                "self_attn",
                "norm3",
                "cross_attn",
                "norm2",
                "ffn",
                "modulation",
            ],
        )
        self.manager = manager
        # declared only; both caches are first assigned on a full hydrate
        self.k_cache: torch.Tensor
        self.v_cache: torch.Tensor
        self.first = first
        self.last = last

    def forward(
        self, x, e, freqs, context, context_img_len=257, transformer_options=None
    ):
        """
        Run the block on the live tokens of `x`.

        `transformer_options` defaults to a fresh empty dict per call and is
        passed through to the attention calls. When `last` is set, the return
        value is the manager's cached output with the live positions
        overwritten by this step's result.
        """
        # a None default avoids sharing one mutable dict between calls
        if transformer_options is None:
            transformer_options = {}

        # RAS: if this is the first block, then we should drop some tokens
        idx = self.manager.live_indices()
        if self.first:
            # we just have img tokens
            self.manager.n_txt = 0
            self.manager.n_img = x.shape[1]
            if idx is not None:
                x = x[..., idx, :]

        # Modulation handling (copied from original)
        if e.ndim < 4:
            e = (
                comfy.model_management.cast_to(
                    self.modulation, dtype=x.dtype, device=x.device
                )
                + e
            ).chunk(6, dim=1)
        else:
            e = (
                comfy.model_management.cast_to(
                    self.modulation, dtype=x.dtype, device=x.device
                ).unsqueeze(0)
                + e
            ).unbind(2)

        # Self-attention with RAS caching
        y = self.self_attn_with_cache(
            torch.addcmul(
                self.repeat_e(e[0], x), self.norm1(x), 1 + self.repeat_e(e[1], x)
            ),
            freqs,
            idx,
            transformer_options=transformer_options,
        )
        x = torch.addcmul(x, y, self.repeat_e(e[2], x))

        # Cross-attention & ffn (unchanged from original)
        x = x + self.cross_attn(
            self.norm3(x),
            context,
            context_img_len=context_img_len,
            transformer_options=transformer_options,
        )
        y = self.ffn(
            torch.addcmul(
                self.repeat_e(e[3], x), self.norm2(x), 1 + self.repeat_e(e[4], x)
            )
        )
        x = torch.addcmul(x, y, self.repeat_e(e[5], x))

        if self.last:
            # put the relevant tokens back into the cached output
            if idx is None or self.manager.cached_output is None:
                self.manager.cached_output = x.clone()
            else:
                self.manager.cached_output[..., idx, :] = x
            # the returned tensor is the cache object itself, not a copy
            return self.manager.cached_output
        return x

    def repeat_e(self, e, x):
        """
        Helper function for modulation broadcasting.

        Returns `e` unchanged when it has a single entry along dim 1 or
        already lines up with `x`; otherwise each entry is repeated along
        dim 1 so that the result covers `x.size(1)` positions.
        """
        # NOTE(review): the repetition assumes x holds a contiguous sequence;
        # when tokens were dropped via the live indices the mapping between
        # the entries of e and the tokens of x may no longer line up -- confirm.
        repeats = 1
        if e.size(1) > 1:
            repeats = x.size(1) // e.size(1)
        if repeats == 1:
            return e
        if repeats * e.size(1) == x.size(1):
            return torch.repeat_interleave(e, repeats, dim=1)
        else:
            # not an exact multiple: over-repeat, then trim to the length of x
            return torch.repeat_interleave(e, repeats + 1, dim=1)[:, : x.size(1)]

    def self_attn_with_cache(self, x, freqs, idx, transformer_options=None):
        """
        Modified self-attention that uses KV caching.

        Queries are computed for the tokens in `x`; keys and values are
        written into the caches (replacing them when `idx` is None, updating
        the `idx` positions otherwise) and attention reads the full caches.
        """
        # a None default avoids sharing one mutable dict between calls
        if transformer_options is None:
            transformer_options = {}

        b, s, n, d = *x.shape[:2], self.num_heads, self.self_attn.head_dim

        # just pull out part of the frequencies
        if idx is not None:
            freqs = freqs[:, idx]

        # Compute QKV like original Wan self-attention
        q = self.self_attn.norm_q(self.self_attn.q(x)).view(b, s, n, d)
        q = apply_rope1(q, freqs).view(b, s, n * d)
        k = self.self_attn.norm_k(self.self_attn.k(x)).view(b, s, n, d)
        k = apply_rope1(k, freqs).view(b, s, n * d)
        v = self.self_attn.v(x).view(b, s, n * d)

        # RAS: KV Cache management
        if idx is None:
            # full hydrate: the caches are replaced by this step's tensors
            self.k_cache = k
            self.v_cache = v
        else:
            # partial update: overwrite only the live positions, in place
            self.k_cache[:, idx, :] = k
            self.v_cache[:, idx, :] = v

        x = optimized_attention(
            q,
            self.k_cache,
            self.v_cache,
            heads=n,
            transformer_options=transformer_options,
        )
        x = self.self_attn.o(x)
        return x
class VaceWanAttentionBlockWrapper(WanAttentionBlockWrapper):
    """
    Same as the VaceWanAttentionBlock, but uses a RASManager to do KV caching.

    All caching behaviour comes from WanAttentionBlockWrapper; this class only
    adds the vace-specific attributes and the projections around the block.
    """

    def __init__(
        self,
        original: VaceWanAttentionBlock,
        manager: RASManager,
        first=False,
        last=False,
    ):
        # The parent constructor initialises nn.Module, takes over the shared
        # Wan block attributes and stores manager / first / last.
        super().__init__(original, manager, first, last)
        self.block_id = original.block_id
        # These projections are taken over only when the original has them
        for name in ("before_proj", "after_proj"):
            if hasattr(original, name):
                setattr(self, name, getattr(original, name))
        # declared only; both caches are first assigned on a full hydrate
        self.k_cache: torch.Tensor
        self.v_cache: torch.Tensor

    def forward(self, c, x, **kwargs):
        """
        Run the cached Wan block on `c` and return `(after_proj(c), c)`.

        When the wrapper has a `before_proj`, `c` is first projected and `x`
        is added to it. `kwargs` go to WanAttentionBlockWrapper.forward.
        """
        if hasattr(self, "before_proj"):
            c = self.before_proj(c) + x
        c = super().forward(c, **kwargs)
        return self.after_proj(c), c
class HeadWrapper(Head):
    """
    A Head (the Wan final layer) that additionally hands every output to a
    RASManager.

    The attributes are taken over from the wrapped instance.
    """

    def __init__(self, original: Head, manager: RASManager):
        # only nn.Module is initialised; everything else comes from `original`
        nn.Module.__init__(self)
        inherited = [
            "dim",
            "out_dim",
            "patch_size",
            "eps",
            "norm",
            "head",
            "modulation",
        ]
        take_attributes_from(original, self, inherited)
        self.manager = manager

    def forward(self, x, e) -> Tensor:
        """Compute the regular Head output and report it to the manager."""
        result = super().forward(x, e)
        manager = self.manager
        # lets the manager pick the live tokens for the next step
        manager.select_indices(result, manager.timestep)
        return result