Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions bindings/chatllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str
self._chatllm_multimedia_msg_prepare = self._lib.chatllm_multimedia_msg_prepare
self._chatllm_multimedia_msg_append = self._lib.chatllm_multimedia_msg_append
self._chatllm_user_input_multimedia_msg = self._lib.chatllm_user_input_multimedia_msg
self._chatllm_destroy = self._lib.chatllm_destroy

self._chatllm_async_user_input = self._lib.chatllm_async_user_input
self._chatllm_async_ai_continue = self._lib.chatllm_async_ai_continue
Expand Down Expand Up @@ -149,9 +150,13 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str
self._chatllm_async_user_input.argtypes = [c_void_p, c_char_p]

self._chatllm_user_input_multimedia_msg.restype = c_int
self._chatllm_destroy = self._lib.chatllm_destroy
self._chatllm_user_input_multimedia_msg.argtypes = [c_void_p]
self._chatllm_destroy = self._lib.chatllm_destroy
self._chatllm_async_user_input_multimedia_msg.restype = c_int
self._chatllm_async_user_input_multimedia_msg.argtypes = [c_void_p]
self._chatllm_destroy.restype = c_int
self._chatllm_destroy.argtypes = [c_void_p]

self._chatllm_tool_input.restype = c_int
self._chatllm_tool_input.argtypes = [c_void_p, c_char_p]
Expand Down Expand Up @@ -268,6 +273,10 @@ def start(self, obj: c_void_p, callback_obj: Any) -> int:
def set_ai_prefix(self, obj: c_void_p, prefix: str) -> int:
return self._chatllm_set_ai_prefix(obj, c_char_p(prefix.encode()))

def destroy(self, obj: c_void_p) -> int:
if obj is None: return 0
return self._chatllm_destroy(obj)

def _input_multimedia_msg(self, obj: c_void_p, user_input: List[dict | str]) -> int:
self._chatllm_multimedia_msg_prepare(obj)
for x in user_input:
Expand Down Expand Up @@ -299,6 +308,7 @@ def chat(self, obj: c_void_p, user_input: str | List[dict | str]) -> int:
elif isinstance(user_input, list):
self._input_multimedia_msg(obj, user_input)
return self._chatllm_user_input_multimedia_msg(obj)
self._chatllm_destroy = self._lib.chatllm_destroy

def async_chat(self, obj: c_void_p, user_input: str | List[dict | str]) -> int:
if isinstance(user_input, str):
Expand Down Expand Up @@ -562,6 +572,13 @@ def save_session(self, file_name: str) -> str:
def load_session(self, file_name: str) -> str:
return self._lib.load_session(self._chat, file_name)

def destroy(self) -> int:
if hasattr(self, "_chat") and self._chat:
if self.is_generating: self.abort()
self._lib.destroy(self._chat)
self._chat = None
return 0

def callback_print_reference(self, s: str) -> None:
self.references.append(s)

Expand Down
1 change: 1 addition & 0 deletions bindings/libchatllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ DLL_DECL int API_CALL chatllm_user_input(struct chatllm_obj *obj, const char *ut
* @param[in] obj model object
* @return 0 if succeeded
*/
DLL_DECL int API_CALL chatllm_destroy(struct chatllm_obj *obj);
DLL_DECL int API_CALL chatllm_user_input_multimedia_msg(struct chatllm_obj *obj);

/**
Expand Down
2 changes: 1 addition & 1 deletion scripts/binding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys, os

this_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
this_dir = os.path.dirname(os.path.abspath(__file__))
PATH_APP = os.path.abspath(os.path.join(this_dir, '..'))
PATH_BINDS = os.path.join(PATH_APP, 'bindings')
PATH_SCRIPTS = os.path.join(PATH_APP, 'scripts')
Expand Down
13 changes: 10 additions & 3 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1865,7 +1865,7 @@ int chatllm_user_input(struct chatllm_obj *obj, const char *utf8_str)

if (!chat->pipeline->is_loaded()) return -2;

if ( (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat)
if ( (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat && chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::ASR)
&& (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::ASR))
return -3;

Expand All @@ -1874,6 +1874,13 @@ int chatllm_user_input(struct chatllm_obj *obj, const char *utf8_str)
return chatllm_generate(obj);
}

int chatllm_destroy(struct chatllm_obj *obj)
{
Chat *chat = reinterpret_cast<Chat *>(obj);
delete chat;
return 0;
}

int chatllm_user_input_multimedia_msg(struct chatllm_obj *obj)
{
int r = 0;
Expand All @@ -1882,7 +1889,7 @@ int chatllm_user_input_multimedia_msg(struct chatllm_obj *obj)

if (!streamer->is_prompt) return -1;

if (chat->pipeline->is_loaded() && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat))
if (chat->pipeline->is_loaded() && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat && chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::ASR))
return -1;

chat->history.push_back(chat->content_scratch, role_user);
Expand All @@ -1907,7 +1914,7 @@ int chatllm_ai_continue(struct chatllm_obj *obj, const char *utf8_str)

if (!streamer->is_prompt) return -1;

if (chat->pipeline->is_loaded() && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat))
if (chat->pipeline->is_loaded() && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat && chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::ASR))
return -1;

if (chat->history.size() < 1) return -2;
Expand Down