diff --git a/modules/openai.py b/modules/openai.py index 2e0529b..b9b6e21 100644 --- a/modules/openai.py +++ b/modules/openai.py @@ -6,6 +6,7 @@ from openai import OpenAI from nemubot import context from nemubot.hooks import hook +from nemubot.tools import web from nemubot.module.more import Response @@ -14,18 +15,24 @@ from nemubot.module.more import Response CLIENT = None MODEL = "gpt-4" +ENDPOINT = None def load(context): - global CLIENT - if not context.config or "apikey" not in context.config: + global CLIENT, ENDPOINT, MODEL + if not context.config or ("apikey" not in context.config and "endpoint" not in context.config): raise ImportError ("You need a OpenAI API key in order to use " "this module. Add it to the module configuration: " "\n") - CLIENT = OpenAI( - base_url=context.config["endpoint"], - api_key=context.config["apikey"], - ) + kwargs = { + "api_key": context.config["apikey"] or "", + } + + if "endpoint" in context.config: + ENDPOINT = context.config["endpoint"] + kwargs["base_url"] = ENDPOINT + + CLIENT = OpenAI(**kwargs) if "model" in context.config: MODEL = context.config["model"] @@ -33,6 +40,32 @@ def load(context): # MODULE INTERFACE #################################################### +@hook.command("list_models", + help="list available LLM") +def cmd_listllm(msg): + llms = web.getJSON(ENDPOINT + "/models", timeout=6) + return Response(message=[m for m in map(lambda i: i["id"], llms["data"])], title="Here is the available models", channel=msg.channel) + + +@hook.command("set_model", + help="Set the model to use when talking to nemubot") +def cmd_setllm(msg): + if len(msg.args) != 1: + raise IMException("Indicate 1 model to use") + + wanted_model = msg.args[0] + + llms = web.getJSON(ENDPOINT + "/models", timeout=6) + for model in llms["data"]: + if wanted_model == model["id"]: + break + else: + raise IMException("Unable to set such model: unknown") + + MODEL = wanted_model + return Response("New model in use: " + wanted_model, channel=msg.channel) + + @hook.ask() def parseask(msg): chat_completion = CLIENT.chat.completions.create(