import time
import discord
from langfuse import get_client, openai
from openai.types.completion_usage import CompletionUsage
from discord.ext import commands
from utils.embeds import LLMPerformanceEmbed, InfoEmbed
from utils.config import LLM as LLMConfig

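# langfuse's openai module is a drop-in wrapper around the OpenAI SDK, so every
# completion created through this client is traced automatically.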
client = openai.AsyncOpenAI(
    base_url=LLMConfig.get("BASE_URL"),
    api_key=LLMConfig.get("API_KEY"),
)

langfuse_client = get_client()

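# How often streamed edits are flushed to Discord, the hard message length cap,
# and the split threshold, kept below the cap to leave headroom for the
# trailing streaming cursor and whitespace trimming.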
UPDATE_INTERVAL_SECONDS = 1
DISCORD_CHAR_LIMIT = 2000
SAFE_SPLIT_LIMIT = 1980


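# Autocomplete for the "model" option: the bot owner can pick any model the
# backend reports; everyone else only sees the configured default. Discord
# allows at most 25 autocomplete choices.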
async def get_models(ctx: discord.AutocompleteContext):
    text = ctx.options["model"]
    if await ctx.bot.is_owner(ctx.interaction.user):
        models_list = await client.models.list()
        models = [m.id for m in models_list.data]
    else:
        models = [LLMConfig.get("DEFAULT_MODEL")]
    return [m for m in models if text in m][:25]


def find_best_split_position(text: str, max_len: int) -> int:
    """
    Find the best position to split text, preferring paragraph breaks, then
    line breaks, then word boundaries. Searches backwards from max_len.
    """
    if len(text) <= max_len:
        return len(text)

    # 1. Try to find a paragraph break (double newline)
    try:
        # Search backwards from the max_len position
        pos = text.rindex("\n\n", 0, max_len)
        return pos
    except ValueError:
        pass  # Not found

    # 2. If no paragraph break, try a line break (single newline)
    try:
        pos = text.rindex("\n", 0, max_len)
        return pos
    except ValueError:
        pass  # Not found

    # 3. If no newline, try to find the last space to not break a word
    try:
        pos = text.rindex(" ", 0, max_len)
        return pos
    except ValueError:
        pass  # Not found

    # 4. If all else fails, force a hard cut at the safe limit
    return max_len


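# Cog exposing the /llm command group; installable in guilds and as a user app.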
class LLM(commands.Cog):
    def __init__(self, bot: discord.Bot):
        self.bot = bot
        self.response_queue = {}

    llm = discord.SlashCommandGroup(
        "llm",
        integration_types={
            discord.IntegrationType.user_install,
            discord.IntegrationType.guild_install,
        },
    )

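    # /llm list: show every model the configured backend exposes.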
    @llm.command(description="List all models")
    async def list(self, ctx: discord.ApplicationContext):
        res = await client.models.list()
        embed = InfoEmbed(
            self.bot.user,
            "\n".join([f"- {m.id}" for m in res.data]),
        )
        await ctx.respond(embed=embed)

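    # /llm chat: stream a completion into the channel, splitting the reply
    # across multiple messages whenever it approaches Discord's length limit.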
    @llm.command(
        description="Chat with a model",
    )
    @discord.option(
        "prompt",
        type=discord.SlashCommandOptionType.string,
    )
    @discord.option(
        "model",
        type=discord.SlashCommandOptionType.string,
        autocomplete=get_models,
        required=False,
        default=LLMConfig.get("DEFAULT_MODEL"),
    )
    async def chat(
        self,
        ctx: discord.ApplicationContext,
        prompt: str,
        model: str,
    ):
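        # Acknowledge the interaction immediately; streaming the reply will take
        # far longer than Discord's three-second response window.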
        await ctx.defer()

        user_id = str(ctx.author.id)
        with langfuse_client.start_as_current_span(
            name="discord-ask-command",
            input=prompt,
        ) as root_span:
            root_span.update_trace(
                user_id=user_id,
                metadata={
                    "discord_username": ctx.author.name,
                    "channel_id": str(ctx.channel.id),
                    "guild_id": str(ctx.guild.id) if ctx.guild else "DM",
                },
            )

            response_messages = []
            current_message_content = ""
            full_response = ""
            start_time, first_token_time, end_time = None, None, None

            initial_message = await ctx.respond(f"🧠 Thinking with `{model}`...")
            response_messages.append(initial_message)
            last_update_time = time.time()

            start_time = time.time()
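            # Stream the completion and edit the Discord message in place as
            # tokens arrive; include_usage asks the API to report token counts
            # on the final chunk of the stream.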
            try:
                stream = await client.chat.completions.create(
                    model=model,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a helpful assistant on Discord, skilled in formatting your output with markdown.",
                        },
                        {"role": "user", "content": prompt},
                    ],
                    stream=True,
                    stream_options={"include_usage": True},
                )

                usage = None

                async for chunk in stream:
                    # With include_usage, the final chunk reports usage only and
                    # arrives with an empty choices list, so guard the access.
                    if chunk.usage:
                        usage = chunk.usage
                    if not chunk.choices:
                        continue
                    content = chunk.choices[0].delta.content
                    if first_token_time is None and content:
                        first_token_time = time.time()

                    if content:
                        current_message_content += content
                        full_response += content  # Keep a full copy for logging
                        # Split into a fresh message once the current one nears the limit.
                        if len(current_message_content) > SAFE_SPLIT_LIMIT:
                            split_pos = find_best_split_position(
                                current_message_content, SAFE_SPLIT_LIMIT
                            )
                            text_to_send, carry_over_text = (
                                current_message_content[:split_pos],
                                current_message_content[split_pos:],
                            )
                            await response_messages[-1].edit(
                                content=text_to_send.strip()
                            )
                            response_messages.append(await ctx.send("..."))
                            current_message_content = carry_over_text.lstrip()
                            last_update_time = time.time()

                    # Periodically flush the partial response with a cursor marker.
                    if time.time() - last_update_time >= UPDATE_INTERVAL_SECONDS:
                        if current_message_content:
                            await response_messages[-1].edit(
                                content=current_message_content + " █"
                            )
                            last_update_time = time.time()

                end_time = time.time()

            except Exception as e:
                end_time = time.time()  # Log end time even on failure
                error_message = f"An unexpected error occurred: {e}"
                print(f"Error during stream for prompt '{prompt}': {error_message}")

                if response_messages:
                    await response_messages[-1].edit(content=error_message)
                return

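            # Derive time-to-first-token and tokens-per-second from the stream
            # timestamps; the first token is excluded from the throughput figure.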
            ttft, tps, completion_tokens = 0.0, 0.0, 0
            if usage:
                completion_tokens = usage.completion_tokens
            if start_time and first_token_time:
                ttft = first_token_time - start_time
            if first_token_time and end_time:
                generation_time = end_time - first_token_time
                if generation_time > 0 and completion_tokens > 1:
                    tps = (completion_tokens - 1) / generation_time

            stats_text = (
                "\n\n"
                f"-# {model} • {tps:.2f} tps • TTFT: {ttft:.2f}s • Tokens: {completion_tokens}"
            )
            final_content = current_message_content.strip()
            # Handle the final message edit
            if final_content:
                # Check if appending the stats would exceed Discord's character limit
                if len(final_content) + len(stats_text) > DISCORD_CHAR_LIMIT:
                    # If it's too long, edit the last message with just the content...
                    await response_messages[-1].edit(content=final_content)
                    # ...and send the stats in a new, separate message.
                    await ctx.send(stats_text.strip())
                else:
                    # If it fits, combine them and edit the last message.
                    final_combined_content = final_content + stats_text
                    await response_messages[-1].edit(content=final_combined_content)
            else:
                # The response was empty: drop a dangling placeholder message,
                # or note that nothing was generated at all.
                if (
                    len(response_messages) > 1
                    and response_messages[-1].content == "..."
                ):
                    await response_messages[-1].delete()
                else:
                    # Edit the very first message if there was no output at all
                    await response_messages[0].edit(
                        content="*No response was generated.*"
                    )

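            # Attach the complete assistant output to the trace before the span closes.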
            root_span.update(output=full_response)