diff --git a/bindings/chatllm.py b/bindings/chatllm.py index 506d210..e2a5499 100644 --- a/bindings/chatllm.py +++ b/bindings/chatllm.py @@ -115,6 +115,7 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str self._chatllm_multimedia_msg_prepare = self._lib.chatllm_multimedia_msg_prepare self._chatllm_multimedia_msg_append = self._lib.chatllm_multimedia_msg_append self._chatllm_user_input_multimedia_msg = self._lib.chatllm_user_input_multimedia_msg + self._chatllm_destroy = self._lib.chatllm_destroy self._chatllm_async_user_input = self._lib.chatllm_async_user_input self._chatllm_async_ai_continue = self._lib.chatllm_async_ai_continue @@ -149,9 +150,13 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str self._chatllm_async_user_input.argtypes = [c_void_p, c_char_p] self._chatllm_user_input_multimedia_msg.restype = c_int + self._chatllm_destroy = self._lib.chatllm_destroy self._chatllm_user_input_multimedia_msg.argtypes = [c_void_p] + self._chatllm_destroy = self._lib.chatllm_destroy self._chatllm_async_user_input_multimedia_msg.restype = c_int self._chatllm_async_user_input_multimedia_msg.argtypes = [c_void_p] + self._chatllm_destroy.restype = c_int + self._chatllm_destroy.argtypes = [c_void_p] self._chatllm_tool_input.restype = c_int self._chatllm_tool_input.argtypes = [c_void_p, c_char_p] @@ -268,6 +273,10 @@ def start(self, obj: c_void_p, callback_obj: Any) -> int: def set_ai_prefix(self, obj: c_void_p, prefix: str) -> int: return self._chatllm_set_ai_prefix(obj, c_char_p(prefix.encode())) + def destroy(self, obj: c_void_p) -> int: + if obj is None: return 0 + return self._chatllm_destroy(obj) + def _input_multimedia_msg(self, obj: c_void_p, user_input: List[dict | str]) -> int: self._chatllm_multimedia_msg_prepare(obj) for x in user_input: @@ -299,6 +308,7 @@ def chat(self, obj: c_void_p, user_input: str | List[dict | str]) -> int: elif isinstance(user_input, list): self._input_multimedia_msg(obj, user_input) return self._chatllm_user_input_multimedia_msg(obj) + self._chatllm_destroy = self._lib.chatllm_destroy def async_chat(self, obj: c_void_p, user_input: str | List[dict | str]) -> int: if isinstance(user_input, str): @@ -562,6 +572,13 @@ def save_session(self, file_name: str) -> str: def load_session(self, file_name: str) -> str: return self._lib.load_session(self._chat, file_name) + def destroy(self) -> int: + if hasattr(self, "_chat") and self._chat: + if self.is_generating: self.abort() + self._lib.destroy(self._chat) + self._chat = None + return 0 + def callback_print_reference(self, s: str) -> None: self.references.append(s) diff --git a/bindings/libchatllm.h b/bindings/libchatllm.h index 2a900d7..f988b93 100644 --- a/bindings/libchatllm.h +++ b/bindings/libchatllm.h @@ -248,6 +248,7 @@ DLL_DECL int API_CALL chatllm_user_input(struct chatllm_obj *obj, const char *ut * @param[in] obj model object * @return 0 if succeeded */ +DLL_DECL int API_CALL chatllm_destroy(struct chatllm_obj *obj); DLL_DECL int API_CALL chatllm_user_input_multimedia_msg(struct chatllm_obj *obj); /** diff --git a/scripts/binding.py b/scripts/binding.py index 3aee5bb..1fd98bd 100644 --- a/scripts/binding.py +++ b/scripts/binding.py @@ -1,6 +1,6 @@ import sys, os -this_dir = os.path.dirname(os.path.abspath(sys.argv[0])) +this_dir = os.path.dirname(os.path.abspath(__file__)) PATH_APP = os.path.abspath(os.path.join(this_dir, '..')) PATH_BINDS = os.path.join(PATH_APP, 'bindings') PATH_SCRIPTS = os.path.join(PATH_APP, 'scripts') diff --git a/src/main.cpp b/src/main.cpp index d9b48cb..e8b9f04 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1865,7 +1865,7 @@ int chatllm_user_input(struct chatllm_obj *obj, const char *utf8_str) if (!chat->pipeline->is_loaded()) return -2; - if ( (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat) + if ( (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat && chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::ASR) && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::ASR)) return -3; @@ -1874,6 +1874,13 @@ int chatllm_user_input(struct chatllm_obj *obj, const char *utf8_str) return chatllm_generate(obj); } +int chatllm_destroy(struct chatllm_obj *obj) +{ + Chat *chat = reinterpret_cast(obj); + delete chat; + return 0; +} + int chatllm_user_input_multimedia_msg(struct chatllm_obj *obj) { int r = 0; @@ -1882,7 +1889,7 @@ int chatllm_user_input_multimedia_msg(struct chatllm_obj *obj) if (!streamer->is_prompt) return -1; - if (chat->pipeline->is_loaded() && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat)) + if (chat->pipeline->is_loaded() && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat && chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::ASR)) return -1; chat->history.push_back(chat->content_scratch, role_user); @@ -1907,7 +1914,7 @@ int chatllm_ai_continue(struct chatllm_obj *obj, const char *utf8_str) if (!streamer->is_prompt) return -1; - if (chat->pipeline->is_loaded() && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat)) + if (chat->pipeline->is_loaded() && (chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::Chat && chat->pipeline->model->get_purpose() != chatllm::ModelPurpose::ASR)) return -1; if (chat->history.size() < 1) return -2;