Compare commits

..

2 Commits

View File

@ -6,6 +6,7 @@ from openai import OpenAI
from nemubot import context from nemubot import context
from nemubot.hooks import hook from nemubot.hooks import hook
from nemubot.tools import web
from nemubot.module.more import Response from nemubot.module.more import Response
@ -14,18 +15,24 @@ from nemubot.module.more import Response
CLIENT = None CLIENT = None
MODEL = "gpt-4" MODEL = "gpt-4"
ENDPOINT = None
def load(context): def load(context):
global CLIENT global CLIENT, ENDPOINT, MODEL
if not context.config or "apikey" not in context.config: if not context.config or ("apikey" not in context.config and "endpoint" not in context.config):
raise ImportError ("You need a OpenAI API key in order to use " raise ImportError ("You need a OpenAI API key in order to use "
"this module. Add it to the module configuration: " "this module. Add it to the module configuration: "
"\n<module name=\"openai\" " "\n<module name=\"openai\" "
"apikey=\"XXXXXX-XXXXXXXXXX\" endpoint=\"https://...\" model=\"gpt-4\" />") "apikey=\"XXXXXX-XXXXXXXXXX\" endpoint=\"https://...\" model=\"gpt-4\" />")
CLIENT = OpenAI( kwargs = {
base_url=context.config["endpoint"], "api_key": context.config["apikey"] or "",
api_key=context.config["apikey"], }
)
if "endpoint" in context.config:
ENDPOINT = context.config["endpoint"]
kwargs["base_url"] = ENDPOINT
CLIENT = OpenAI(**kwargs)
if "model" in context.config: if "model" in context.config:
MODEL = context.config["model"] MODEL = context.config["model"]
@ -33,6 +40,32 @@ def load(context):
# MODULE INTERFACE #################################################### # MODULE INTERFACE ####################################################
@hook.command("list_models",
help="list available LLM")
def cmd_listllm(msg):
llms = web.getJSON(ENDPOINT + "/models", timeout=6)
return Response(message=[m for m in map(lambda i: i["id"], llms["data"])], title="Here is the available models", channel=msg.channel)
@hook.command("set_model",
help="Set the model to use when talking to nemubot")
def cmd_setllm(msg):
if len(msg.args) != 1:
raise IMException("Indicate 1 model to use")
wanted_model = msg.args[0]
llms = web.getJSON(ENDPOINT + "/models", timeout=6)
for model in llms["data"]:
if wanted_model == model["id"]:
break
else:
raise IMException("Unable to set such model: unknown")
MODEL = wanted_model
return Response("New model in use: " + wanted_model, channel=msg.channel)
@hook.ask() @hook.ask()
def parseask(msg): def parseask(msg):
chat_completion = CLIENT.chat.completions.create( chat_completion = CLIENT.chat.completions.create(