Hi, there seems to be a bug in the calculation of `final_window_start`:
```python
# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.

# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (self.max_pieces - self.start_tokens - self.end_tokens) // 2
stride_offset = stride // 2 + self.start_tokens

first_window = list(range(stride_offset))

max_context_windows = [i for i in range(full_seq_len)
                       if stride_offset - 1 < i % self.max_pieces < stride_offset + stride]

final_window_start = full_seq_len - (full_seq_len % self.max_pieces) + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))

select_indices = first_window + max_context_windows + final_window
```
On the test case from your comment, `final_window_start` is greater than `full_seq_len`:
```python
full_seq_len = 16
max_pieces = 8
start_tokens = 1
end_tokens = 1

# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.

# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (max_pieces - start_tokens - end_tokens) // 2
stride_offset = stride // 2 + start_tokens

first_window = list(range(stride_offset))

max_context_windows = [i for i in range(full_seq_len)
                       if stride_offset - 1 < i % max_pieces < stride_offset + stride]

final_window_start = full_seq_len - (full_seq_len % max_pieces) + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))

select_indices = first_window + max_context_windows + final_window
print(select_indices)
```
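The output is `[0, 1, 2, 3, 4, 10, 11, 12]`; the final window `[14, 15]` is missing. Because 16 is an exact multiple of `max_pieces`, `full_seq_len % max_pieces` is 0, so `final_window_start` = 16 - 0 + 2 + 3 = 21 and `range(21, 16)` is empty.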
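One possible fix, sketched as a standalone function: when `full_seq_len` is an exact multiple of `max_pieces`, look back over a whole window instead of zero tokens. The function name `compute_select_indices` and the `lookback` variable are hypothetical and not part of udify; this is only a suggestion, assuming the rest of the quoted logic stays unchanged. For lengths that are not a multiple of `max_pieces`, it computes the same `final_window_start` as the current code.

```python
from typing import List


def compute_select_indices(full_seq_len: int, max_pieces: int,
                           start_tokens: int = 1, end_tokens: int = 1) -> List[int]:
    """Hypothetical standalone version of the quoted snippet with a patched final window."""
    # Stride is half the usable window, ignoring the special start and end tokens.
    stride = (max_pieces - start_tokens - end_tokens) // 2
    # Offset of the centermost embeddings inside each window.
    stride_offset = stride // 2 + start_tokens

    first_window = list(range(stride_offset))

    max_context_windows = [i for i in range(full_seq_len)
                           if stride_offset - 1 < i % max_pieces < stride_offset + stride]

    # Hypothetical fix: look back over the last window. If full_seq_len is an exact
    # multiple of max_pieces, the remainder is 0, so look back a full window instead.
    lookback = full_seq_len % max_pieces or max_pieces

    final_window_start = full_seq_len - lookback + stride_offset + stride
    final_window = list(range(final_window_start, full_seq_len))

    return first_window + max_context_windows + final_window


print(compute_select_indices(16, 8))
# [0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15]
```

With this change the example gives `[13, 14, 15]` as the final window rather than the `[14, 15]` stated in the comment. Index 13 is position 5 of the second window, just past its max-context slice `[10, 11, 12]`, so if the final window started at 14, that token would not be selected at all; the `[14, 15]` in the comment may be off by one.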
Code reference: `udify/udify/modules/bert_pretrained.py`, lines 488 to 509 at commit `cbabef6`.