diff --git a/bindings/chatllm.py b/bindings/chatllm.py index 506d210..9e79d8b 100644 --- a/bindings/chatllm.py +++ b/bindings/chatllm.py @@ -98,6 +98,7 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str self._chatllm_append_param = self._lib.chatllm_append_param self._chatllm_start = self._lib.chatllm_start self._chatllm_set_ai_prefix = self._lib.chatllm_set_ai_prefix + self._chatllm_set_additional_args = self._lib.chatllm_set_additional_args self._chatllm_ai_continue = self._lib.chatllm_ai_continue self._chatllm_user_input = self._lib.chatllm_user_input self._chatllm_tool_input = self._lib.chatllm_tool_input @@ -132,7 +133,11 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str self._chatllm_start.argtypes = [c_void_p, self._PRINTFUNC, self._ENDFUNC, c_void_p] self._chatllm_set_ai_prefix.restype = c_int + self._chatllm_set_additional_args = self._lib.chatllm_set_additional_args self._chatllm_set_ai_prefix.argtypes = [c_void_p, c_char_p] + self._chatllm_set_additional_args.restype = c_int + self._chatllm_set_additional_args.argtypes = [c_void_p, c_char_p] + self._chatllm_set_additional_args = self._lib.chatllm_set_additional_args self._chatllm_ai_continue.restype = c_int self._chatllm_ai_continue.argtypes = [c_void_p, c_char_p] @@ -265,8 +270,18 @@ def start(self, obj: c_void_p, callback_obj: Any) -> int: id = self.alloc_id_for_obj(callback_obj) return self._chatllm_start(obj, self._cb_print, self._cb_end, c_void_p(id)) + def set_additional_args(self, obj: c_void_p, key_value: Union[str, bytes]) -> int: + if isinstance(key_value, str): + key_value = key_value.encode() + return self._chatllm_set_additional_args(obj, c_char_p(key_value)) + + def set_additional_args(self, key: str, value: str) -> int: + key_value = f"{key}={value}" + return self._lib.set_additional_args(self._chat, key_value) + def set_ai_prefix(self, obj: c_void_p, prefix: str) -> int: return self._chatllm_set_ai_prefix(obj, c_char_p(prefix.encode())) + self._chatllm_set_additional_args = self._lib.chatllm_set_additional_args def _input_multimedia_msg(self, obj: c_void_p, user_input: List[dict | str]) -> int: self._chatllm_multimedia_msg_prepare(obj) @@ -449,6 +464,15 @@ def __init__(self, lib: LibChatLLM, param: Union[None, str, List[str]], auto_sta def append_param(self, param: Union[str, List[str]]) -> None: self._lib.append_param(self._chat, param) + def set_additional_args(self, obj: c_void_p, key_value: Union[str, bytes]) -> int: + if isinstance(key_value, str): + key_value = key_value.encode() + return self._chatllm_set_additional_args(obj, c_char_p(key_value)) + + def set_additional_args(self, key: str, value: str) -> int: + key_value = f"{key}={value}" + return self._lib.set_additional_args(self._chat, key_value) + def set_ai_prefix(self, prefix: str) -> int: return self._lib.set_ai_prefix(self._chat, prefix) diff --git a/bindings/libchatllm.h b/bindings/libchatllm.h index 2a900d7..d9295e8 100644 --- a/bindings/libchatllm.h +++ b/bindings/libchatllm.h @@ -259,6 +259,7 @@ DLL_DECL int API_CALL chatllm_user_input_multimedia_msg(struct chatllm_obj *obj) * @param[in] utf8_str prefix * @return 0 if succeeded */ +DLL_DECL int API_CALL chatllm_set_additional_args(struct chatllm_obj *obj, const char *utf8_str); DLL_DECL int API_CALL chatllm_set_ai_prefix(struct chatllm_obj *obj, const char *utf8_str); /** diff --git a/src/main.cpp b/src/main.cpp index d9b48cb..954e11b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1799,6 +1799,25 @@ int chatllm_async_start(struct chatllm_obj *obj, f_chatllm_print f_print, f_chat ASYNC_FUN_BODY(chatllm_start(obj, f_print, f_end, user_data)); } +int chatllm_set_additional_args(struct chatllm_obj *obj, const char *utf8_str) +{ + Chat *chat = reinterpret_cast(obj); + if (!chat->pipeline || !chat->pipeline->is_loaded()) { + return -1; + } + std::string str(utf8_str); + size_t eq_pos = str.find("="); + if (eq_pos == std::string::npos) { + return -2; + } + std::string key = str.substr(0, eq_pos); + std::string value = str.substr(eq_pos + 1); + std::map args; + args[key] = value; + chat->pipeline->set_additional_args(args); + return 0; +} + int chatllm_set_ai_prefix(struct chatllm_obj *obj, const char *utf8_str) { Chat *chat = reinterpret_cast(obj);