-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_text.py
More file actions
71 lines (55 loc) · 2.55 KB
/
generate_text.py
File metadata and controls
71 lines (55 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
from vocab_mapping.vocab_mapping import vocabulary_mapping
from min_lm.lm import MiniLM
from self_attention.self_attention import SelfAttention
from transformer.transformer import Transformer
from backbone_nn.embeddings.embed import Embedding
from trainer import trainer
from softmax.softm import softmax
import numpy as np
import argparse
def generate_text(model, seed_text, generate_len, vocab, seq_length):
    """
    Generate text autoregressively with a sliding window over token indices.

    Args:
        model: trained language model; called as ``model(seq)`` where ``seq``
            is a LongTensor of shape ``[seq_length]``, returning logits of
            shape ``[seq_length, vocab_size]``.
        seed_text: prompt string; tokenized by whitespace split.
        generate_len: number of new tokens to generate.
        vocab: dict mapping token string -> integer index.
        seq_length: fixed sequence length the model expects.

    Returns:
        A single space-joined string: the original seed tokens followed by
        the generated tokens.

    Raises:
        ValueError: if ``seed_text`` contains no tokens (the previous
            self-doubling pad loop would spin forever on an empty seed).
    """
    seed_tokens = seed_text.split()
    if not seed_tokens:
        # An empty seed can never be padded up to seq_length; fail fast
        # instead of looping forever.
        raise ValueError("seed_text must contain at least one token")

    # Tokens to indices, defaulting unknown tokens to index 0.
    seed_indices = [vocab.get(token, 0) for token in seed_tokens]
    inv_vocab = {idx: token for token, idx in vocab.items()}

    # Make the window exactly seq_length long: pad by cyclic repetition
    # (produces the same prefix as the old doubling loop, in O(seq_length))
    # or keep only the most recent seq_length tokens.
    if len(seed_indices) < seq_length:
        reps = -(-seq_length // len(seed_indices))  # ceil division
        seed_indices = (seed_indices * reps)[:seq_length]
    elif len(seed_indices) > seq_length:
        seed_indices = seed_indices[-seq_length:]

    current_seq = torch.tensor(seed_indices, dtype=torch.long)
    model.eval()
    # Output always starts with the full (untruncated) seed text.
    generated_tokens = seed_tokens.copy()

    for _ in range(generate_len):
        with torch.no_grad():
            logits = model(current_seq)  # shape: [seq_length, vocab_size]
        # Only the last position's logits predict the next token.
        last_logits = logits[-1]
        # torch.softmax is numerically stable (max-subtraction) and has the
        # same semantics as the project-local softmax helper it replaces.
        probs = torch.softmax(last_logits, dim=0)
        # Sample the next token index from the probability distribution.
        next_token_idx = torch.multinomial(probs, num_samples=1).item()
        # Index back to token; fall back gracefully if the model emits an
        # index outside the vocabulary mapping (previously a KeyError).
        generated_tokens.append(inv_vocab.get(next_token_idx, "<unk>"))
        # Slide the window: drop the oldest token, append the new one.
        current_seq = torch.cat(
            [current_seq[1:], torch.tensor([next_token_idx], dtype=torch.long)]
        )

    return ' '.join(generated_tokens)