forked from microsoft/monitors4codegen
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmonitor.py
More file actions
79 lines (67 loc) · 2.99 KB
/
monitor.py
File metadata and controls
79 lines (67 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Provides the definition of a monitor as per the Monitor-Guided Decoding framework
"""
from typing import List, Tuple
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
from monitors4codegen.multilspy import LanguageServer
from monitors4codegen.multilspy.multilspy_config import Language
from dataclasses import dataclass
from monitors4codegen.multilspy.multilspy_utils import TextUtils
@dataclass
class MonitorFileBuffer:
    """
    Tracks the evolving state of the prompt file into which monitor-guided
    decoding inserts generated text.
    """

    lsp: LanguageServer  # language server that has the prompt file open
    file_path: str  # path of the prompt file within the language server
    prompt_lc: Tuple[int, int]  # (line, column) where the original prompt ends
    current_lc: Tuple[int, int]  # (line, column) of the current insertion cursor
    language: Language  # programming language of the prompt file
    gen_text: str = ""  # all text generated so far

    def append_text(self, text: str):
        """
        Inserts ``text`` into the prompt file at the current cursor position,
        advances the cursor to the end of the inserted text, and accumulates
        ``text`` into ``gen_text``.
        """
        line, col = self.current_lc
        # Flat character offset of the cursor before the insert, used for the
        # postcondition check below.
        offset_before = TextUtils.get_index_from_line_col(
            self.lsp.get_open_file_text(self.file_path), line, col
        )
        moved_to = self.lsp.insert_text_at_position(self.file_path, line, col, text)
        self.current_lc = (moved_to["line"], moved_to["character"])
        self.gen_text += text
        # Postcondition sanity check: the cursor must have advanced by exactly
        # len(text) characters in the open file.
        offset_after = TextUtils.get_index_from_line_col(
            self.lsp.get_open_file_text(self.file_path), self.current_lc[0], self.current_lc[1]
        )
        assert offset_before + len(text) == offset_after
class Monitor:
    """
    Base class defining a monitor as per the Monitor-Guided Decoding framework.

    Concrete monitors override the hooks below; each stub here raises
    NotImplementedError.
    """

    def __init__(
        self,
        tokenizer: TokenizerWrapper,
        monitor_file_buffer: MonitorFileBuffer,
        responsible_for_file_buffer_state: bool = True,
    ) -> None:
        self.tokenizer = tokenizer
        self.monitor_file_buffer = monitor_file_buffer
        self.responsible_for_file_buffer_state = responsible_for_file_buffer_state

    async def pre(self) -> None:
        """
        When the monitor is in its uninitialized state s0, decides whether
        static analysis should be performed at this point; if so, invokes it
        and updates the monitor state.
        """
        raise NotImplementedError()

    async def maskgen(self, input_ids: List[int]) -> List[int]:
        """
        Given input_ids — the token ids generated so far (or the initial
        input on the first call) — returns the token ids to mask for the next
        token generation. Invoked by the end user at every token decode.
        """
        raise NotImplementedError()

    def a_phi(self):
        """
        Implementation of the static analysis; returns its result.
        Invoked primarily by pre().
        """
        raise NotImplementedError()

    def update(self, generated_token: str):
        """
        Updates the monitor state given the newly generated token.
        """
        raise NotImplementedError()