Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions bindings/chatllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str
self._chatllm_append_param = self._lib.chatllm_append_param
self._chatllm_start = self._lib.chatllm_start
self._chatllm_set_ai_prefix = self._lib.chatllm_set_ai_prefix
self._chatllm_set_additional_args = self._lib.chatllm_set_additional_args
self._chatllm_ai_continue = self._lib.chatllm_ai_continue
self._chatllm_user_input = self._lib.chatllm_user_input
self._chatllm_tool_input = self._lib.chatllm_tool_input
Expand Down Expand Up @@ -132,7 +133,11 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str
self._chatllm_start.argtypes = [c_void_p, self._PRINTFUNC, self._ENDFUNC, c_void_p]

self._chatllm_set_ai_prefix.restype = c_int
self._chatllm_set_additional_args = self._lib.chatllm_set_additional_args
self._chatllm_set_ai_prefix.argtypes = [c_void_p, c_char_p]
self._chatllm_set_additional_args.restype = c_int
self._chatllm_set_additional_args.argtypes = [c_void_p, c_char_p]
self._chatllm_set_additional_args = self._lib.chatllm_set_additional_args

self._chatllm_ai_continue.restype = c_int
self._chatllm_ai_continue.argtypes = [c_void_p, c_char_p]
Expand Down Expand Up @@ -265,8 +270,18 @@ def start(self, obj: c_void_p, callback_obj: Any) -> int:
id = self.alloc_id_for_obj(callback_obj)
return self._chatllm_start(obj, self._cb_print, self._cb_end, c_void_p(id))

def set_additional_args(self, obj: c_void_p, key_value: Union[str, bytes]) -> int:
if isinstance(key_value, str):
key_value = key_value.encode()
return self._chatllm_set_additional_args(obj, c_char_p(key_value))

def set_additional_args(self, key: str, value: str) -> int:
key_value = f"{key}={value}"
return self._lib.set_additional_args(self._chat, key_value)

def set_ai_prefix(self, obj: c_void_p, prefix: str) -> int:
return self._chatllm_set_ai_prefix(obj, c_char_p(prefix.encode()))
self._chatllm_set_additional_args = self._lib.chatllm_set_additional_args

def _input_multimedia_msg(self, obj: c_void_p, user_input: List[dict | str]) -> int:
self._chatllm_multimedia_msg_prepare(obj)
Expand Down Expand Up @@ -449,6 +464,15 @@ def __init__(self, lib: LibChatLLM, param: Union[None, str, List[str]], auto_sta
def append_param(self, param: Union[str, List[str]]) -> None:
self._lib.append_param(self._chat, param)

def set_additional_args(self, obj: c_void_p, key_value: Union[str, bytes]) -> int:
if isinstance(key_value, str):
key_value = key_value.encode()
return self._chatllm_set_additional_args(obj, c_char_p(key_value))

def set_additional_args(self, key: str, value: str) -> int:
key_value = f"{key}={value}"
return self._lib.set_additional_args(self._chat, key_value)

def set_ai_prefix(self, prefix: str) -> int:
return self._lib.set_ai_prefix(self._chat, prefix)

Expand Down
1 change: 1 addition & 0 deletions bindings/libchatllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ DLL_DECL int API_CALL chatllm_user_input_multimedia_msg(struct chatllm_obj *obj)
* @param[in] utf8_str prefix
* @return 0 if succeeded
*/
DLL_DECL int API_CALL chatllm_set_additional_args(struct chatllm_obj *obj, const char *utf8_str);
DLL_DECL int API_CALL chatllm_set_ai_prefix(struct chatllm_obj *obj, const char *utf8_str);

/**
Expand Down
19 changes: 19 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,25 @@ int chatllm_async_start(struct chatllm_obj *obj, f_chatllm_print f_print, f_chat
ASYNC_FUN_BODY(chatllm_start(obj, f_print, f_end, user_data));
}

int chatllm_set_additional_args(struct chatllm_obj *obj, const char *utf8_str)
{
Chat *chat = reinterpret_cast<Chat *>(obj);
if (!chat->pipeline || !chat->pipeline->is_loaded()) {
return -1;
}
std::string str(utf8_str);
size_t eq_pos = str.find("=");
if (eq_pos == std::string::npos) {
return -2;
}
std::string key = str.substr(0, eq_pos);
std::string value = str.substr(eq_pos + 1);
std::map<std::string, std::string> args;
args[key] = value;
chat->pipeline->set_additional_args(args);
return 0;
}

int chatllm_set_ai_prefix(struct chatllm_obj *obj, const char *utf8_str)
{
Chat *chat = reinterpret_cast<Chat *>(obj);
Expand Down