Skip to content

Commit a81932a

Browse files
Peter Johnson
authored and committed
fix bengio model
1 parent b09234f commit a81932a

4 files changed

Lines changed: 8018 additions & 15 deletions

File tree

evaluation_function/models/bengio_infer.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def complete(prompt, steps=10,model=None,config=None,sp=None,device=None):
2323
import random
2424
import torch
2525
with torch.no_grad():
26-
words = prompt[:]
26+
words = sp.encode(prompt, out_type=str)
2727
for _ in range(steps):
2828
dist = predict_next(words, topk=5, model=model, config=config, sp=sp, device=device)
2929
words_probs = [(word, prob) for word, prob in dist]
@@ -36,7 +36,6 @@ def run(response, answer, params: Params) -> Result:
3636
print("Loading Bengio-style Neural N-gram Language Model for inference...")
3737
import torch
3838
import sentencepiece as spm
39-
sp = spm.SentencePieceProcessor(model_file="bpe.model")
4039

4140
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
4241

@@ -45,6 +44,10 @@ def run(response, answer, params: Params) -> Result:
4544
MODEL_DIR.mkdir(parents=True, exist_ok=True)
4645
MODEL_PATH = MODEL_DIR / "bengio_model.pt"
4746
MODEL_CONFIG_PATH = MODEL_DIR / "bengio_model_config.json"
47+
BPE_PATH = MODEL_DIR / "bpe.model"
48+
if not BPE_PATH.exists():
49+
raise FileNotFoundError(f"Missing SentencePiece model at {BPE_PATH}")
50+
sp = spm.SentencePieceProcessor(model_file=str(BPE_PATH))
4851

4952
with open(MODEL_CONFIG_PATH) as f:
5053
config = json.load(f)
@@ -61,16 +64,8 @@ def run(response, answer, params: Params) -> Result:
6164

6265
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
6366
model.eval()
67+
result=[]
68+
completion = response if isinstance(response, str) else "the general"
69+
result.append(complete(completion, steps=20, model=model, config=config, sp=sp, device=device))
6470

65-
completions = [
66-
sp.encode("the cat sat", out_type=str),
67-
sp.encode("the cat sat", out_type=str),
68-
sp.encode("the cat sat", out_type=str),
69-
sp.encode("the man saw", out_type=str),
70-
sp.encode("in the general", out_type=str)
71-
]
72-
for prompt in completions:
73-
result = complete(prompt, steps=20, model=model, config=config, sp=sp, device=device)
74-
print(f"Prompt: {' '.join(prompt)}\nCompletion: {result}\n")
75-
76-
return Result(is_correct=True, feedback_items=[("general", "Model loaded successfully for inference.")])
71+
return Result(is_correct=True, feedback_items=[("general", ''.join(result))])
361 KB
Binary file not shown.

0 commit comments

Comments (0)