-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathheads.py
More file actions
949 lines (672 loc) · 33.7 KB
/
heads.py
File metadata and controls
949 lines (672 loc) · 33.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
import functools
import math
import einops
from sympy import symbols, Matrix, Rational
import numpy as onp
import copy
import cvxpy as cp
from typing import Sequence, Optional, Callable
from ml_collections import FrozenConfigDict
import jax
from jax import numpy as jnp, jit, lax
from flax import linen as nn
from flax.linen.linear import _conv_dimension_numbers
def compress(images, patch_size):
    """Rearrange images into non-overlapping patches.

    Layout change: ``b (i0 p0) (i1 p1) ... c -> b i0 i1 ... c (p0 p1 ...)``.

    Args:
        images: array whose first axis is the batch and last axis is the
            channel; every axis in between is treated as spatial.
        patch_size: patch edge length, applied to every spatial axis. einops
            must be able to factor each spatial extent as ``grid * patch_size``.

    Returns:
        The rearranged array: the spatial axes become the patch grid and the
        cells of each patch are flattened into a new trailing axis.
    """
    num_spatial = len(images.shape) - 2
    grid_axes = [f"i{i}" for i in range(num_spatial)]
    patch_axes = [f"p{i}" for i in range(num_spatial)]
    source_axes = " ".join(f"({g} {p})" for g, p in zip(grid_axes, patch_axes))
    grid_axes_str = " ".join(grid_axes)
    patch_axes_str = " ".join(patch_axes)
    # The original also built a meshgrid of patch positions here, but never
    # used or returned it, so that dead computation is gone.
    return einops.rearrange(
        images,
        f'b {source_axes} c -> b {grid_axes_str} c ({patch_axes_str})',
        **{p: patch_size for p in patch_axes},
    )
def prod_fn(x):
    """Return the product of all items of the iterable ``x``.

    An empty iterable yields 1 (the multiplicative identity); the previous
    ``functools.reduce`` form had no initial value and raised on empty input.
    """
    return math.prod(x)
@functools.partial(jit, static_argnums=(2,), donate_argnums=(0,))
def mask_out_place_holder(x, is_size, const):
    """Keep ``x`` where ``is_size`` is True and write ``const`` everywhere else.

    ``is_size`` covers the leading axes of ``x``; it is padded with trailing
    singleton axes so that it broadcasts over the remaining axes of ``x``.
    ``const`` is a static jit argument and ``x`` is declared as donated.
    """
    trailing_axes = tuple(range(is_size.ndim, x.ndim))
    keep = jnp.expand_dims(is_size, axis=trailing_axes)
    fill = jnp.array(const)
    return jnp.where(keep, x, fill)
def binomial_coefficients(n):
    """Return the n+1 binomial coefficients C(n, 0..n) as a float32 array."""
    coefficients = []
    value = 1
    for k in range(n + 1):
        coefficients.append(value)
        # C(n, k+1) = C(n, k) * (n - k) / (k + 1); the division is always exact.
        value = value * (n - k) // (k + 1)
    return jnp.array(coefficients, dtype=jnp.float32)
def finite_diff_1d_kernel(k):
    """
    Build the first-derivative finite-difference weights for the symmetric
    stencil of offsets -k..k (2k+1 points).

    The weights solve the moment system sum_j w_j * j**m = [m == 1] for
    m = 0..2k, solved exactly with sympy rationals and then converted to a
    float32 array.
    """
    stencil = range(-k, k + 1)
    num_points = 2 * k + 1
    # Row m holds the m-th power of every stencil offset.
    moments = Matrix([[Rational(j) ** m for j in stencil] for m in range(num_points)])
    # Only the first-order moment is required to equal one.
    target = Matrix([1 if m == 1 else 0 for m in range(num_points)])
    weights = moments.LUsolve(target)
    return jnp.array([float(w) for w in weights], dtype=jnp.float32)
@jax.ensure_compile_time_eval()
def SobelKernel(ndim, kernel_size=7):
    """Build an N-D Sobel-style kernel stack.

    Component ``d`` of the last axis is the outer product of the 1-D
    finite-difference taps along axis ``d`` with binomial smoothing taps along
    every other axis, scaled to (almost) unit L2 norm.

    Returns an array of shape ``(kernel_size,) * ndim + (ndim,)``.
    """
    assert kernel_size % 2 == 1, "kernel size should be odd"
    derivative_taps = finite_diff_1d_kernel(kernel_size // 2)
    smoothing_taps = binomial_coefficients(kernel_size - 1)
    full_shape = (kernel_size,) * ndim
    # coords[a] holds, for every kernel cell, its index along axis a.
    coords = jnp.indices(full_shape)
    components = []
    for axis in range(ndim):
        taps = jnp.ones(full_shape) * derivative_taps[coords[axis]]
        for other in range(ndim):
            # The derivative axis gets no smoothing factor (a factor of one).
            if other != axis:
                taps = taps * smoothing_taps[coords[other]]
        # L2-normalise each component; the epsilon guards a zero kernel.
        taps = taps / (jnp.sqrt(jnp.sum(taps ** 2)) + 1e-5)
        components.append(taps)
    return jnp.stack(components, axis=-1)
@functools.partial(jit, static_argnums=(1,), donate_argnums=(0,))
def SobelConv(x, kernel_size):
    """Depthwise Sobel filtering followed by a per-pixel gradient magnitude.

    Every channel of ``x`` (batch first, channels last) is convolved with each
    directional component of ``SobelKernel``; the directional responses are
    then reduced with an L2 norm. The input is reflect-padded so the 'VALID'
    convolution keeps the spatial extent. ``x`` is declared as donated.
    """
    num_spatial = x.ndim - 2
    num_channels = x.shape[-1]
    # k ... k 1 D  ->  k ... k 1 (C D): one group of D filters per channel
    kernel = SobelKernel(num_spatial, kernel_size=kernel_size)[..., None, :]
    kernel = einops.repeat(kernel, '... c-> ... (k c)', k=num_channels).astype(x.dtype)
    halo = [[k // 2, k // 2] for k in kernel.shape[:-2]]
    x = jnp.pad(x, [[0, 0]] + halo + [[0, 0]], mode='reflect')
    y = lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1,) * num_spatial,
        padding='VALID',
        dimension_numbers=_conv_dimension_numbers(x.shape),
        feature_group_count=num_channels,
    )
    # Separate the direction axis from the channel axis, then take the
    # l2 norm of the image gradients.
    y = einops.rearrange(y, '... (f c) -> ... f c', c=num_spatial)
    return jnp.linalg.norm(y, axis=-1, ord=2)
# @functools.partial(jit, static_argnums=(1,))
def get_pixel_score(x, normalize, sobel_kernel_size):
    """Compute per-pixel importance score via Sobel gradient magnitude.

    Args:
        x: array with the batch on axis 0 and channels on the last axis.
        normalize: if truthy, standardise ``x`` per sample and per channel
            over all axes in between before filtering.
        sobel_kernel_size: kernel size forwarded to ``SobelConv``.

    Returns:
        The output of ``SobelConv``.

    Note:
        ``SobelConv`` is jitted with ``donate_argnums=(0,)``. With
        ``normalize`` falsy, the caller's own array is what gets passed to it.
    """
    if normalize:
        # Reductions take an int or a tuple of ints as ``axis``; a bare range
        # object is not a documented axis type, so materialise it.
        spatial_axes = tuple(range(1, x.ndim - 1))
        mean = x.mean(spatial_axes, keepdims=True)
        # The epsilon keeps constant images from dividing by zero.
        std = x.std(spatial_axes, keepdims=True) + 1e-5
        x = (x - mean) / std
    return SobelConv(x, sobel_kernel_size)
def split(vertices):
    """Split each axis-aligned box at its midpoint into 2^D children.

    vertices: ... 2 D (start corner, end corner). Returns ... S 2 D with
    S = 2^D. Every child has extent ``mid - start`` along each axis, where
    ``mid`` is the floor-divided midpoint.
    """
    num_axes = vertices.shape[-1]
    lower = vertices[..., 0, :]
    upper = vertices[..., 1, :]
    # S D: every 0/1 choice of "lower half" / "upper half" per axis
    half_choice = jnp.moveaxis(jnp.indices([2] * num_axes), 0, -1)
    half_choice = einops.rearrange(half_choice, '... d -> (...) d')
    # broadcastable against the leading axes of `vertices`
    half_choice = jnp.expand_dims(half_choice, axis=tuple(range(vertices.ndim - 2)))
    # ... D
    midpoint = (lower + upper) // 2
    half_extent = midpoint - lower
    # ... S D: a child starts at the midpoint where the choice is 1
    child_lower = jnp.where(half_choice, midpoint[..., None, :], vertices[..., 0:1, :])
    child_upper = child_lower + half_extent[..., None, :]
    return jnp.stack([child_lower, child_upper], axis=-2)
def get_block(x, start_vertices, block_size):
    """Gather fixed-size blocks from x at the given start corners.

    x: N ... (F), start_vertices: N S D. The result has shape
    N S block_size^D followed by any axes of x that are not indexed.
    """
    num_axes = start_vertices.shape[-1]
    # D bs ... bs: offset of every cell inside a block
    cell_offsets = jnp.indices([block_size] * num_axes)
    # N S D 1 ... 1
    trailing_axes = tuple(range(start_vertices.ndim, start_vertices.ndim + num_axes))
    corners = jnp.expand_dims(start_vertices, axis=trailing_axes)
    # N S D bs ... bs: absolute coordinate of every block cell
    absolute = corners + cell_offsets[None, None]
    # D N S bs ... bs: one index array per spatial axis
    absolute = jnp.moveaxis(absolute, 2, 0)
    batch_index = jnp.expand_dims(
        jnp.arange(x.shape[0]), axis=tuple(range(1, absolute.ndim - 1))
    )
    return x[(batch_index, *absolute)]
def get_score(x, vertices, pow, min_length, block_size):
    """Score each candidate patch by summing pixel scores inside its box.

    Args:
        x: per-pixel score map, batch first: N ...
        vertices: N S 2 D boxes (start corner, end corner).
        pow: if non-zero, the raw sum is divided by ``area ** pow``.
            pow is legacy and is set to zero exclusively.
        min_length: boxes whose edge length is not strictly greater than this
            along every axis receive a score of ``-inf``.
        block_size: edge length of the fixed-size window gathered per box.

    Returns:
        N S array of scores.
    """
    num_axes = vertices.shape[-1]
    grid_size = jnp.array(x.shape[1:1 + num_axes])
    # N S D: clamp the window start so the block_size window stays in-bounds
    start_indices = jnp.minimum(grid_size[None, None] - block_size, vertices[..., 0, :])
    # N S 2 D: box corners relative to the gathered window
    vertices = vertices - start_indices[..., None, :]
    # N S ... (F)
    blocks = get_block(x, start_indices, block_size)
    # D ... -> 1 1 ... D: coordinate of every cell inside a window
    idx = jnp.indices(blocks.shape[2:2 + num_axes])
    idx = jnp.moveaxis(idx, 0, -1)[None, None]
    edge_length = vertices[..., 1, :] - vertices[..., 0, :]
    # Axes are passed as tuples: a bare range object is not a documented axis
    # type for jnp functions.
    vertices = jnp.expand_dims(vertices, axis=tuple(range(2, idx.ndim - 1)))
    start_vertices = vertices[..., 0, :]
    end_vertices = vertices[..., 1, :]
    # N S ...: True for the cells of the window that lie inside the box
    inside = (idx >= start_vertices).all(-1) & (idx < end_vertices).all(-1)
    blocks = mask_out_place_holder(blocks, inside, 0)
    raw_score = blocks.sum(tuple(range(2, blocks.ndim)))
    area = edge_length.prod(-1)
    if pow == 0:
        adjusted_score = raw_score
    else:
        adjusted_score = raw_score / jnp.power(area, pow)
    return jnp.where((edge_length > min_length).all(-1), adjusted_score, -jnp.inf)
def insert_sorted(existing_scores, new_scores, existing_vertices, new_vertices):
    """Concatenate two (scores, vertices) lists along axis 1 and sort both by ascending score."""
    merged_scores = jnp.concatenate([existing_scores, new_scores], axis=1)
    merged_vertices = jnp.concatenate([existing_vertices, new_vertices], axis=1)
    order = jnp.argsort(merged_scores, axis=1)
    score_rows = jnp.arange(merged_scores.shape[0])[:, None]
    vertex_rows = jnp.arange(merged_vertices.shape[0])[:, None]
    # The same permutation is applied to both arrays to keep them aligned.
    return merged_scores[score_rows, order], merged_vertices[vertex_rows, order]
@functools.partial(jit, donate_argnums=(1, 2), static_argnums=(3, 4, 5, 6))
def split_body_fn(x, sorted_list_of_vertices, sorted_list_of_scores, block_size, pow, min_length, split_per_step):
    """One greedy splitting step on the score-sorted patch buffers.

    The ``split_per_step`` entries at the end of the buffers (the highest
    scores, since the buffers are sorted ascending) are removed, each is split
    into 2^D children, the children are scored and merged back, and the
    buffers are cut back to their original length by keeping the tail.

    Returns ``(vertices, scores)`` with the same buffer length as the inputs.
    """
    # fixed buffer length to restore at the end
    capacity = sorted_list_of_vertices.shape[1]
    # pop the best patches off the end of the ascending lists
    popped = sorted_list_of_vertices[:, -split_per_step:]
    kept_vertices = sorted_list_of_vertices[:, :-split_per_step]
    kept_scores = sorted_list_of_scores[:, :-split_per_step]
    # b s n k d -> b (s n) k d: children of all popped patches as one sequence
    children = einops.rearrange(split(popped), 'b s n k d -> b (s n) k d')
    child_scores = get_score(x, children, pow, min_length, block_size)
    # order the children by score before merging them in
    child_order = jnp.argsort(child_scores, axis=1)
    child_scores = child_scores[jnp.arange(child_scores.shape[0])[:, None], child_order]
    children = children[jnp.arange(children.shape[0])[:, None], child_order]
    merged_scores, merged_vertices = insert_sorted(
        kept_scores, child_scores, kept_vertices, children
    )
    # keep the highest-scoring `capacity` entries
    return merged_vertices[:, -capacity:], merged_scores[:, -capacity:]
def iterative_splitting(x, init_vertices, block_size, pow, min_length, split_per_step, num_split):
    """Run ``num_split`` greedy splitting steps with ``lax.scan``.

    The initial patches are scored and sorted ascending, then padded at the
    front with placeholder entries (vertices of -1, scores of -inf) so the
    buffers already have room for every patch the splits will add.

    Returns the final ``(vertices, scores)`` buffers.
    """
    batch = x.shape[0]
    num_axes = init_vertices.shape[-1]
    # each split of one patch adds 2^D - 1 patches net
    allocated_length = num_split * (2 ** num_axes - 1) * split_per_step
    scores = get_score(x, init_vertices, pow, min_length, block_size)
    order = jnp.argsort(scores, axis=1)
    vertices = init_vertices[jnp.arange(init_vertices.shape[0])[:, None], order]
    scores = scores[jnp.arange(scores.shape[0])[:, None], order]
    vertex_padding = -jnp.ones(
        (batch, allocated_length, 2, vertices.shape[-1]), dtype=vertices.dtype
    )
    score_padding = jnp.full((batch, allocated_length), -jnp.inf, dtype=scores.dtype)
    vertices = jnp.concatenate([vertex_padding, vertices], axis=1)
    scores = jnp.concatenate([score_padding, scores], axis=1)

    def scan_step(carry, _):
        # carry is (vertices, scores); the scan emits no per-step output
        return split_body_fn(x, *carry, block_size, pow, min_length, split_per_step), ()

    (vertices, scores), _ = lax.scan(scan_step, (vertices, scores), None, length=num_split)
    return vertices, scores
def batch_pseudo_kdtree(
    x,
    num_tokens,
    min_length,
    max_length,
    pow,
    split_per_step=1,
    sobel_kernel_size=5,
    normalize=True,
    pre_compress=False,
):
    """
    Iterative Adaptive Patchification (IAP): produce variable-size patches by
    greedily splitting high-gradient regions. Starts from a coarse max_length grid
    and iteratively splits the patches with the highest gradient score.

    Args:
        x: input batch, batch first and channels last.
        num_tokens: upper bound on the number of patches; the number of
            splits is the largest count that does not exceed it.
        min_length: finest patch edge length; the score map is pooled by it.
        max_length: edge length of the initial coarse patches.
        pow: forwarded to ``get_score``.
        split_per_step: number of patches split in each iteration.
        sobel_kernel_size: forwarded to ``get_pixel_score``.
        normalize: forwarded to ``get_pixel_score``.
        pre_compress: pool the input by ``min_length`` before scoring instead
            of pooling the score map afterwards.

    Returns:
        dict with ``vertice_list`` (boxes scaled back by ``min_length``),
        ``score_list`` and ``score_mat`` (the channel-averaged pixel scores).

    Raises:
        ValueError: if ``num_tokens`` is smaller than the number of initial
            coarse patches.
    """
    if pre_compress:
        x = compress(x, min_length).mean(-1)
        print('Pre-compressed input shape:', x.shape)
    # average the channel scores
    x = get_pixel_score(x, normalize, sobel_kernel_size).mean(-1)
    score_mat = x
    if not pre_compress:
        x = compress(x[..., None], min_length).mean(-1).squeeze(-1)
    # from now on x is compressed, so lengths are in units of min_length
    max_length = max_length // min_length
    original_min_length = min_length
    min_length = 1
    num_axes = x.ndim - 1
    increment_per_step = (2 ** num_axes - 1) * split_per_step
    existing_patch = prod_fn(x.shape[1:]) // max_length ** num_axes
    num_split = math.floor((num_tokens - existing_patch) / increment_per_step)
    if num_split < 0:
        # A negative count would otherwise surface as an obscure shape/scan
        # error when the placeholder buffers are allocated.
        raise ValueError(
            f"num_tokens={num_tokens} is smaller than the {existing_patch} "
            f"initial patches of the coarse grid"
        )
    # start corner of every coarse cell, in compressed coordinates
    init_vertices = jnp.indices([s // max_length for s in x.shape[1:]]) * max_length
    init_vertices = einops.repeat(init_vertices, 'c ... -> b (...) c', b=x.shape[0])
    # b n 2 d: (start corner, end corner)
    init_vertices = jnp.stack([init_vertices, init_vertices + max_length], axis=-2)
    sorted_list_of_vertices, sorted_list_of_scores = iterative_splitting(
        x, init_vertices, max_length, pow, min_length, split_per_step, num_split
    )
    # back to the coordinates of the uncompressed input
    sorted_list_of_vertices = sorted_list_of_vertices * original_min_length
    return dict(
        vertice_list=sorted_list_of_vertices,
        score_list=sorted_list_of_scores,
        score_mat=score_mat,
    )
@jax.ensure_compile_time_eval()
def get_max_size(lengths, grid_size, num_tokens):
    """Solve one LP per patch size for the maximum count of that size.

    The LP variables are the (real-valued) counts of each patch size, subject
    to: counts sum to ``num_tokens``, counts are non-negative, and the patch
    areas sum to the total grid area. The results are used for pre-allocation.

    Args:
        lengths: patch edge lengths.
        grid_size: spatial extent of the grid, one entry per axis.
        num_tokens: total number of patches.

    Returns:
        Tuple with one integer bound per entry of ``lengths``, each capped at
        ``num_tokens``.

    Raises:
        ValueError: if the solver returns no solution for one of the LPs.
    """
    sizes = onp.array(lengths) ** len(grid_size)
    total_size = math.prod(grid_size)
    max_patch_num = []
    for i, patch_area in enumerate(sizes):
        x = cp.Variable(len(lengths))
        constraints = [cp.sum(x) == num_tokens, x >= 0, sizes @ x == total_size]
        prob = cp.Problem(cp.Maximize(x[i]), constraints)
        prob.solve()
        if x.value is None:
            # Without a solution the floor() below would fail with an
            # unrelated-looking TypeError.
            raise ValueError(
                f"no patch-count solution for lengths={tuple(lengths)}, "
                f"grid_size={tuple(grid_size)}, num_tokens={num_tokens} "
                f"(solver status: {prob.status})"
            )
        floored = onp.floor(x.value).astype(onp.int32)
        # occupied area after the flooring
        filled_area = onp.dot(sizes, floored)
        # additional offset to the max (floored) number of patches
        unfilled_area = total_size - filled_area
        bound = onp.ceil(unfilled_area / patch_area) + floored[i]
        max_patch_num.append(int(min(bound, num_tokens)))
    return tuple(max_patch_num)
def get_trunk_for_size(tokens, length_list, vertice_list, size_preallocation, size):
    """Pick ``size_preallocation`` entries per sample, preferring those whose length equals ``size``.

    Returns the selected tokens and vertices, plus a mask telling which of
    the selected entries actually have the requested size.
    """
    matches = length_list == size
    # Matching entries get key -1, the rest 0, so the partition puts matches first.
    chosen = jnp.argpartition(
        -matches.astype(jnp.int32), kth=size_preallocation - 1, axis=-1
    )[:, :size_preallocation]
    tokens = tokens[jnp.arange(tokens.shape[0])[:, None], chosen]
    batch_index = jnp.arange(tokens.shape[0])[:, None]
    return tokens, vertice_list[batch_index, chosen], matches[batch_index, chosen]
def merge_patch_features(patch_list, vertice_list, length_list, is_size_list, num_tokens):
    """Concatenate the per-size groups along axis 1 and keep ``num_tokens`` entries, valid ones first.

    Returns the selected ``(patches, vertices, lengths)``.
    """
    patches, vertices, lengths, valid = (
        jnp.concatenate(group, axis=1)
        for group in (patch_list, vertice_list, length_list, is_size_list)
    )
    # Valid entries get key -1, padding 0, so the partition puts valid ones first.
    chosen = jnp.argpartition(-valid.astype(jnp.int32), kth=num_tokens - 1, axis=-1)[:, :num_tokens]
    patches = patches[jnp.arange(patches.shape[0])[:, None], chosen]
    batch_index = jnp.arange(patches.shape[0])[:, None]
    return patches, vertices[batch_index, chosen], lengths[batch_index, chosen]
def get_patches(x, vertices, sizes, size_preallocations, flatten):
    """Gather raw patches from x, grouped by patch size.

    For every (size, budget) pair, ``budget`` boxes are selected per sample
    (boxes of that size first) and their blocks are read with ``get_block``.

    Returns a list with one dict per size holding ``patches``, ``vertices``
    and ``is_size`` (mask of the selected boxes that really have that size).
    """
    # N S: edge length measured along the first spatial axis
    edge_length = (vertices[..., 1, 0] - vertices[..., 0, 0]).astype(jnp.int32)
    batch_index = jnp.arange(vertices.shape[0])[:, None]
    groups = []
    for size, budget in zip(sizes, size_preallocations):
        matches = edge_length == size
        chosen = jnp.argpartition(-matches.astype(jnp.int32), kth=budget - 1, axis=-1)[:, :budget]
        group_vertices = vertices[batch_index, chosen]
        group_mask = matches[batch_index, chosen]
        group_patches = get_block(x, group_vertices[..., 0, :], size)
        if flatten:
            group_patches = einops.rearrange(group_patches, 'b s ... -> b s (...)')
        groups.append(dict(patches=group_patches, vertices=group_vertices, is_size=group_mask))
    return groups
class VariableSizePatchEmbedder(nn.Module):
    """Encoder front-end: run IAP (``batch_pseudo_kdtree``) to obtain
    variable-size patches, then embed the patches of each size with that
    size's own Dense layer so all tokens share one embedding width."""
    tree_config: FrozenConfigDict
    levels: Sequence[int]  # available patch sizes, e.g. [2, 4, 8]
    qkv_dim: int  # output embedding dimension
    compute_precision: jnp.dtype = jnp.float32
    params_precision: jnp.dtype = jnp.float32

    def setup(self):
        self.sizes = self.levels
        # one embedding layer per patch size, created whether or not it ends up used
        self.embed = {
            size: nn.Dense(
                self.qkv_dim,
                name=f'patch_embd_{size}',
                dtype=self.compute_precision,
                param_dtype=self.params_precision,
            )
            for size in self.sizes
        }

    @nn.compact
    def __call__(self, images, num_tokens, *args, **kwargs):
        """Return ``(tokens, vertices, lengths)`` for a batch of images (batch first, channels last)."""
        grid_size = images.shape[1:-1]
        num_axes = len(grid_size)
        # token-count bounds: all patches at the largest / smallest size
        min_num_tokens = prod_fn(grid_size) // (max(self.sizes) ** num_axes)
        max_num_tokens = prod_fn(grid_size) // (min(self.sizes) ** num_axes)
        cfg = self.tree_config
        tree_out = batch_pseudo_kdtree(
            images,
            num_tokens=num_tokens,
            min_length=cfg.min_length,
            max_length=cfg.max_length,
            split_per_step=cfg.split_per_step,
            sobel_kernel_size=cfg.sobel_kernel_size,
            pow=cfg.pow,
            pre_compress=cfg.pre_compress,
        )
        vertices = tree_out['vertice_list']
        self.sow('intermediates', 'score_mat', tree_out['score_mat'])
        # from here on, use the number of patches actually produced
        num_tokens = vertices.shape[1]
        assert min_num_tokens <= num_tokens <= max_num_tokens, f"Number of tokens {num_tokens} is out of range [{min_num_tokens}, {max_num_tokens}] for image size {images.shape[1:-1]}"
        max_num_patch = get_max_size(self.sizes, images.shape[1:-1], num_tokens)
        # N S: edge length measured along the first spatial axis
        edge_length = vertices[..., 1, 0] - vertices[..., 0, 0]
        batch_index = jnp.arange(vertices.shape[0])[:, None]
        patch_list, vertices_list, is_size_list, length_list = [], [], [], []
        for size, budget in zip(self.sizes, max_num_patch):
            matches = edge_length == size
            # pick `budget` boxes per sample, boxes of this size first
            chosen = jnp.argpartition(-matches.astype(jnp.int32), kth=budget - 1, axis=-1)[:, :budget]
            group_vertices = vertices[batch_index, chosen]
            group_mask = matches[batch_index, chosen]
            # N S ... F -> N S (... F) -> N S qkv_dim
            group_patches = get_block(images, group_vertices[..., 0, :], size)
            group_patches = einops.rearrange(group_patches, 'b s ... -> b s (...)')
            group_patches = self.embed[size](group_patches)
            patch_list.append(group_patches)
            vertices_list.append(group_vertices)
            is_size_list.append(group_mask)
            # entries of another size get length 0
            length_list.append(group_mask * size)
        patch_list, vertice_list, length_list = merge_patch_features(
            patch_list, vertices_list, length_list, is_size_list, num_tokens
        )
        return patch_list, vertice_list, length_list
def apply_rope(q, positions, max_scale):
    """Apply 1D Rotary Position Embedding.

    q: ... d, positions: ... (the shape of q without its last axis).
    The last axis of q is split into two halves that are rotated against each
    other; frequency i uses the wavelength ``max_scale ** (i / (d // 2))``.
    """
    half_dim = q.shape[-1] // 2
    exponents = jnp.linspace(0, 1, half_dim, endpoint=False)
    wavelengths = max_scale ** exponents
    # ... f: rotation angle for every position and frequency
    angles = positions[..., None] / wavelengths.reshape((1,) * positions.ndim + (-1,))
    first_half, second_half = jnp.split(q, 2, axis=-1)
    sin, cos = jnp.sin(angles), jnp.cos(angles)
    rotated_first = first_half * cos - second_half * sin
    rotated_second = first_half * sin + second_half * cos
    return jnp.concatenate([rotated_first, rotated_second], axis=-1)
def apply_nd_rope(q, positions, max_scale):
    """Apply RoPE for N-D positions by giving each spatial axis its own slice of head_dim.

    q: ... l h d, positions: ... l c. The leading axes are flattened for the
    computation and restored afterwards.
    """
    num_heads = q.shape[-2]
    lead_shape = q.shape[:-3]
    lead_names = [f'b{i}' for i in range(len(lead_shape))]
    lead = ' '.join(lead_names)
    lead_sizes = dict(zip(lead_names, lead_shape))
    num_axes = positions.shape[-1]
    # flatten the leading axes, then give every spatial axis d / c features
    q = einops.rearrange(q, f'{lead} l h d -> ({lead}) l h d')
    q = einops.rearrange(q, 'b l h (c f) -> b l h c f', c=num_axes)
    # the same positions are used for every head
    positions = einops.repeat(positions, f'{lead} l c -> ({lead}) l h c', h=num_heads)
    q = apply_rope(q, positions, max_scale)
    q = einops.rearrange(q, 'b l h c f -> b l h (c f)')
    return einops.rearrange(q, f'({lead}) l h d -> {lead} l h d', **lead_sizes)
def make_rope_attention(attention_fn, positions_q, positions_k, max_scale):
    """Return a drop-in replacement for ``attention_fn`` that rotates queries
    and keys with N-D RoPE (by ``positions_q`` / ``positions_k``) first.
    Values and all other arguments are passed through untouched."""
    def rope_attention_fn(query, key, value, *args, **kwargs):
        rotated_query = apply_nd_rope(query, positions_q, max_scale)
        rotated_key = apply_nd_rope(key, positions_k, max_scale)
        return attention_fn(rotated_query, rotated_key, value, *args, **kwargs)

    return rope_attention_fn
class FeedForward(nn.Module):
    """
    Two Dense layers with a GELU in between (one hidden layer), no skip connection.
    """
    hidden_dim: int
    out_dim: int
    compute_precision: jnp.dtype = jnp.float32
    params_precision: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        precision = dict(dtype=self.compute_precision, param_dtype=self.params_precision)
        hidden = nn.gelu(nn.Dense(self.hidden_dim, **precision)(x))
        return nn.Dense(self.out_dim, **precision)(hidden)
class DecoderBlock(nn.Module):
    """Cross-attention block: attention over layer-normed (q, kv) plus a
    residual from q, then a layer-normed FFN plus a residual."""
    num_heads: int
    mlp_ratio: int
    attention_fn: Callable = None
    compute_precision: jnp.dtype = jnp.float32
    params_precision: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs_q, inputs_kv, mask=None, attention_fn=None):
        in_dim = inputs_q.shape[-1]
        # Priority: module attribute, then call argument, then plain dot-product attention.
        if self.attention_fn is not None:
            attention_fn = self.attention_fn
        elif attention_fn is None:
            attention_fn = nn.dot_product_attention

        def layer_norm():
            return nn.LayerNorm(
                dtype=self.params_precision,
                param_dtype=self.params_precision,
            )

        attention_layer = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            deterministic=True,
            attention_fn=attention_fn,
            dtype=self.compute_precision,
            param_dtype=self.params_precision,
        )
        normed_q = layer_norm()(inputs_q)
        normed_kv = layer_norm()(inputs_kv)
        x = attention_layer(normed_q, normed_kv, mask=mask) + inputs_q
        ffl_layer = FeedForward(
            hidden_dim=self.mlp_ratio * in_dim,
            out_dim=in_dim,
            compute_precision=self.compute_precision,
            params_precision=self.params_precision,
        )
        return x + ffl_layer(layer_norm()(x))
def tokens_to_patch_grid(
    tokens: jnp.ndarray,  # (B, N, F)
    boxes: jnp.ndarray,  # (B, N, 2, D) [[start...], [end...]], half-open
    shapes: Sequence[int],  # spatial size, one entry per axis, e.g. (H, W)
    p: int,
    lengths: jnp.ndarray = None,  # (B, N)
    avail_sizes: Sequence[int] = None
):
    """Map variable-size token embeddings back onto a dense grid of
    ``shapes // p`` cells by tiling each token into all grid cells it covers.

    Args:
        tokens: (B, N, F) token features.
        boxes: (B, N, 2, D) token boxes in pixel coordinates. Not modified.
        shapes: spatial size in pixels; every entry must be divisible by ``p``.
        p: edge length of one grid cell in pixels.
        lengths: (B, N) patch edge lengths in pixels; needed only together
            with ``avail_sizes``. Not modified.
        avail_sizes: patch sizes in pixels, in ascending order (they are
            looked up with ``searchsorted``).

    Returns:
        ``grid_feats`` of shape (B, *grid, F) if ``avail_sizes`` is None,
        otherwise ``(grid_feats, index)`` where ``index`` (B, *grid) addresses
        a table with one row per (patch size, cell position inside the patch).
    """
    assert all(s % p == 0 for s in shapes), "Incoming sizes must be divisible by p"
    B, _, _ = tokens.shape
    D = boxes.shape[-1]
    # Rescale boxes from pixel coords to patch-grid coords. The division is
    # out of place: `boxes //= p` would overwrite the caller's buffer when a
    # NumPy array is passed in.
    boxes = jnp.asarray(boxes) // p
    patch_grid_shape = [s // p for s in shapes]
    # patch_loc[i0, i1, ..., d] = coordinate of grid cell (i0, i1, ...) along axis d
    patch_loc = jnp.stack(
        jnp.meshgrid(*[jnp.arange(s) for s in patch_grid_shape], indexing="ij"),
        axis=-1,
    )
    # corners of each token box, broadcast over the grid: (B, 1, ..., 1, N, D)
    grid_axes = tuple(range(1, 1 + D))
    start = jnp.expand_dims(boxes[:, :, 0, :], axis=grid_axes)
    end = jnp.expand_dims(boxes[:, :, 1, :], axis=grid_axes)
    # each grid cell as the unit box [loc, loc + 1): (1, *grid, 1, D)
    cell_start = patch_loc[None, ..., None, :]
    cell_end = cell_start + 1
    # a cell belongs to a token if it lies entirely within the token's box
    inside = jnp.all((cell_start >= start) & (cell_end <= end), axis=-1)  # (B, *grid, N)
    # argmax returns the first True; a cell that no box covers falls back to token 0
    idx = jnp.argmax(inside, axis=-1)  # (B, *grid)
    batch_index = jnp.expand_dims(jnp.arange(B), axis=tuple(range(1, idx.ndim)))
    # gather token features onto the dense grid (tiles large patches)
    grid_feats = tokens[batch_index, idx]
    if avail_sizes is None:
        return grid_feats

    # also compute within-patch positional indices for learnable in-patch embeddings
    lengths = jnp.asarray(lengths) // p  # out of place, same reason as for boxes
    avail_sizes = [s // p for s in avail_sizes]
    # patch size and bounding box of the token covering each grid cell
    grid_of_sizes = lengths[batch_index, idx]
    grid_of_boxes = boxes[batch_index, idx]
    # offset of this grid cell relative to the patch's start corner: (B, *grid, D)
    offset_in_box = patch_loc - grid_of_boxes[..., 0, :]
    # flatten the N-D offset into one index; axis d has stride size**d, so the
    # first axis varies fastest
    strides = grid_of_sizes[..., None] ** jnp.expand_dims(
        jnp.arange(D), axis=tuple(range(grid_of_sizes.ndim))
    )
    index_in_box = jnp.sum(offset_in_box * strides, axis=-1)  # (B, *grid)
    # cumulative offsets so each patch size maps to its own range of table rows
    rows_per_size = jnp.array(avail_sizes) ** D
    size_steps = jnp.concatenate([jnp.array([0]), jnp.cumsum(rows_per_size)[:-1]], axis=0)
    # map each cell's patch size to its size-group index, then to the group's offset
    size_indices = jnp.searchsorted(jnp.array(avail_sizes), grid_of_sizes.astype(jnp.int32))
    grid_of_steps = size_steps[size_indices]
    # final index = size_group_offset + within_patch_offset
    return grid_feats, index_in_box + grid_of_steps
class LatentDecoder(nn.Module):
    """Decoder: tile the encoder tokens onto a dense latent grid, optionally
    apply a learnable per-in-patch-position gate and bias, cross-attend from
    the grid latents to the encoder tokens with RoPE, and project each latent
    to its block of output pixels. Attention can be run in chunks ("trunks")
    of queries for large grids."""
    grid_size: tuple  # original spatial size, e.g. (128, 384)
    out_dim: int  # number of output channels per pixel
    num_heads: int
    mlp_ratio: int
    max_pos_embed_scale: int
    num_decoder_layers: int = 1
    patch_size: int = 1
    available_sizes: Sequence[int] = None
    compute_precision: jnp.dtype = jnp.float32
    params_precision: jnp.dtype = jnp.float32
    trunk_size: Optional[int] = None
    in_patch_latents: bool = True

    @nn.compact
    def __call__(self, inputs_kv, positions):
        """inputs_kv: (B, N, F) encoder tokens; positions: (B, N, 2, D) token boxes."""
        num_axes = positions.shape[-1]
        feature_dim = inputs_kv.shape[-1]
        latent_grid_size = [s // self.patch_size for s in self.grid_size]
        # (B, N): edge length measured along the first spatial axis
        lengths = (positions[:, :, 1, 0] - positions[:, :, 0, 0]).astype(jnp.int32)
        # one table row per (patch size, position inside a patch of that size)
        table_rows = sum(s ** num_axes for s in self.available_sizes)
        gate = self.param(
            'gate',
            nn.initializers.normal(stddev=1.0, dtype=self.params_precision),
            (table_rows, feature_dim),
        )
        bias = self.param(
            'bias',
            lambda *args: nn.initializers.zeros(*args, dtype=self.params_precision),
            (table_rows, feature_dim),
        )
        # this is the tiled residual
        tiled, tiled_indices = tokens_to_patch_grid(
            inputs_kv,
            positions,
            self.grid_size,
            self.patch_size,
            lengths,
            self.available_sizes,
        )
        self.sow('intermediates', 'residual_indices', tiled_indices)
        tiled = einops.rearrange(tiled, 'b ... d -> b (...) d')
        if self.in_patch_latents:
            # legacy, did not improve the performance.
            tiled_indices = einops.rearrange(tiled_indices, 'b ... -> b (...)')
            # a learnable bias to break up the uniformity from tiling
            tiled_bias = bias[tiled_indices].astype(self.compute_precision)
            tiled_gate = nn.sigmoid(gate[tiled_indices]).astype(self.compute_precision)
            latents = tiled.astype(self.compute_precision) * tiled_gate + tiled_bias
        else:
            latents = tiled.astype(self.compute_precision)
            print('LatentDecoder: not using in-patch latents!')
        # keys are positioned at the centre of their box
        token_centers = positions.mean(axis=-2)
        # queries are positioned at the centre of their latent cell
        latent_centers = jnp.stack(
            jnp.meshgrid(*[jnp.arange(g) for g in latent_grid_size], indexing='ij'), axis=-1
        ) * self.patch_size + self.patch_size / 2.0
        latent_centers = einops.repeat(latent_centers, '... c-> b (...) c', b=inputs_kv.shape[0])

        def rope_attention(query_positions):
            return make_rope_attention(
                nn.dot_product_attention,
                max_scale=self.max_pos_embed_scale,
                positions_q=query_positions,
                positions_k=token_centers,
            )

        for _ in range(self.num_decoder_layers):
            use_trunks = self.trunk_size is not None and self.trunk_size < latents.shape[1]
            if use_trunks:
                if self.is_initializing():
                    print(f'Using trunked attention with trunk size {self.trunk_size} on decoder with {latents.shape[1]} tokens')
                num_trunks = math.ceil(latents.shape[1] / self.trunk_size)
                # one block per layer; every trunk goes through the same instance
                decoder_block = DecoderBlock(
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    compute_precision=self.compute_precision,
                    params_precision=self.params_precision,
                )
                trunk_outputs = []
                for t in range(num_trunks):
                    window = slice(t * self.trunk_size, (t + 1) * self.trunk_size)
                    trunk_latents = latents[:, window, :]
                    trunk_centers = latent_centers[:, window, :]
                    trunk_outputs.append(
                        decoder_block(
                            trunk_latents,
                            inputs_kv,
                            attention_fn=rope_attention(trunk_centers),
                        )
                    )
                latents = jnp.concatenate(trunk_outputs, axis=1)
            else:
                latents = DecoderBlock(
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    attention_fn=rope_attention(latent_centers),
                    compute_precision=self.compute_precision,
                    params_precision=self.params_precision,
                )(latents, inputs_kv)
        # project every latent to the pixels of its patch_size^D block
        out = FeedForward(
            hidden_dim=latents.shape[-1],
            out_dim=self.out_dim * self.patch_size ** len(latent_grid_size),
            compute_precision=self.compute_precision,
            params_precision=self.params_precision,
        )(latents)
        # un-patchify: b (i0 i1 ...) (p0 p1 ... d) -> b (i0 p0) (i1 p1) ... d
        grid_extent = [h // self.patch_size for h in self.grid_size]
        grid_names = [f"i{i}" for i in range(len(grid_extent))]
        patch_names = [f"p{i}" for i in range(len(grid_extent))]
        patch_dims_str = " ".join(patch_names)
        grid_dims_str = " ".join(grid_names)
        original_dims_str = " ".join([f"({g} {p})" for g, p in zip(grid_names, patch_names)])
        axis_sizes = {name: self.patch_size for name in patch_names}
        axis_sizes.update(zip(grid_names, grid_extent))
        return einops.rearrange(
            out,
            f'b ({grid_dims_str}) ({patch_dims_str} d) -> b {original_dims_str} d',
            **axis_sizes,
        )