-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
43 lines (32 loc) · 1.35 KB
/
generate.py
File metadata and controls
43 lines (32 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import torch
import torch.nn as nn
from vocab_mapping.vocab_mapping import vocabulary_mapping
from min_lm.lm import MiniLM
from self_attention.self_attention import SelfAttention
from transformer.transformer import Transformer
from backbone_nn.embeddings.embed import Embedding
from trainer import trainer
from generate_text import generate_text
import numpy as np
import argparse
import shutil
def parse_args():
    """Parse command-line arguments for text generation.

    Returns:
        argparse.Namespace with a single attribute ``prompt`` (str),
        the required seed text for the language model.
    """
    arg_parser = argparse.ArgumentParser(description="Mini LLM")
    arg_parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Insert the prompt",
    )
    return arg_parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    # Check for the checkpoint file itself, not just the directory:
    # an empty ./model/ dir would otherwise crash inside torch.load.
    model_path = "./model/trained_model.pth"
    if os.path.exists(model_path):
        # map_location="cpu" lets a GPU-trained checkpoint load on a
        # CPU-only machine; the state dict is restored onto the model below.
        model_config = torch.load(model_path, map_location="cpu")
        vocab_size = model_config["vocab_size"]
        embed_dim = model_config["embed_dim"]
        hidden_dim = model_config["hidden_dim"]
        seq_length = model_config["seq_length"]
        # Rebuild the architecture from the saved hyperparameters, then
        # load the trained weights into it.
        model = MiniLM(vocab_size, embed_dim, hidden_dim, seq_length)
        model.load_state_dict(model_config["state_dict"])
        model.eval()
        seed_text = args.prompt
        # Inference only: disable autograd to save memory and compute.
        with torch.no_grad():
            generated_text = generate_text(model, seed_text, generate_len=20, vocab=model_config["vocab"], seq_length=seq_length)
        print("Generated text:", generated_text)
    else:
        # Previously the script exited silently when no model was found;
        # tell the user what is missing and how to produce it.
        print(f"No trained model found at {model_path}. Run the trainer first to create it.")