@@ -23,7 +23,7 @@ def complete(prompt, steps=10,model=None,config=None,sp=None,device=None):
2323 import random
2424 import torch
2525 with torch .no_grad ():
26- words = prompt [:]
26+ words = sp . encode ( prompt , out_type = str )
2727 for _ in range (steps ):
2828 dist = predict_next (words , topk = 5 , model = model , config = config , sp = sp , device = device )
2929 words_probs = [(word , prob ) for word , prob in dist ]
@@ -36,7 +36,6 @@ def run(response, answer, params: Params) -> Result:
3636 print ("Loading Bengio-style Neural N-gram Language Model for inference..." )
3737 import torch
3838 import sentencepiece as spm
39- sp = spm .SentencePieceProcessor (model_file = "bpe.model" )
4039
4140 device = torch .device ("mps" if torch .backends .mps .is_available () else "cpu" )
4241
@@ -45,6 +44,10 @@ def run(response, answer, params: Params) -> Result:
4544 MODEL_DIR .mkdir (parents = True , exist_ok = True )
4645 MODEL_PATH = MODEL_DIR / "bengio_model.pt"
4746 MODEL_CONFIG_PATH = MODEL_DIR / "bengio_model_config.json"
47+ BPE_PATH = MODEL_DIR / "bpe.model"
48+ if not BPE_PATH .exists ():
49+ raise FileNotFoundError (f"Missing SentencePiece model at { BPE_PATH } " )
50+ sp = spm .SentencePieceProcessor (model_file = str (BPE_PATH ))
4851
4952 with open (MODEL_CONFIG_PATH ) as f :
5053 config = json .load (f )
@@ -61,16 +64,8 @@ def run(response, answer, params: Params) -> Result:
6164
6265 model .load_state_dict (torch .load (MODEL_PATH , map_location = device ))
6366 model .eval ()
67+ result = []
68+ completion = response if isinstance (response , str ) else "the general"
69+ result .append (complete (completion , steps = 20 , model = model , config = config , sp = sp , device = device ))
6470
65- completions = [
66- sp .encode ("the cat sat" , out_type = str ),
67- sp .encode ("the cat sat" , out_type = str ),
68- sp .encode ("the cat sat" , out_type = str ),
69- sp .encode ("the man saw" , out_type = str ),
70- sp .encode ("in the general" , out_type = str )
71- ]
72- for prompt in completions :
73- result = complete (prompt , steps = 20 , model = model , config = config , sp = sp , device = device )
74- print (f"Prompt: { ' ' .join (prompt )} \n Completion: { result } \n " )
75-
76- return Result (is_correct = True , feedback_items = [("general" , "Model loaded successfully for inference." )])
71+ return Result (is_correct = True , feedback_items = [("general" , '' .join (result ))])
0 commit comments