Source code for mesa_llm.module_llm

import logging
import os

from dotenv import load_dotenv
from litellm import acompletion, completion, litellm
from litellm.exceptions import (
    APIConnectionError,
    NotFoundError,
    RateLimitError,
    Timeout,
)
from tenacity import AsyncRetrying, retry, retry_if_exception_type, wait_exponential

# Exception types handed to tenacity's ``retry_if_exception_type`` by
# ``ModuleLLM.generate`` and ``ModuleLLM.agenerate``: a call that raises one of
# these is attempted again with exponential back-off. Any other exception
# propagates to the caller immediately.
RETRYABLE_EXCEPTIONS = (
    APIConnectionError,
    Timeout,
    RateLimitError,
)

# Runs at import time. Presumably this makes provider keys kept in a local
# ``.env`` file visible through ``os.environ`` before ``ModuleLLM.__init__``
# looks up ``{PROVIDER}_API_KEY`` there.
load_dotenv()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ModuleLLM:
    """
    A module that provides a simple interface for using LLMs

    Note : Currently supports OpenAI, Anthropic, xAI, Huggingface, Ollama, OpenRouter, NovitaAI, Gemini

    Attributes:
        llm_model: Model identifier in the form "{provider}/{model}".
        api_base: Optional API base URL forwarded to litellm when set.
        system_prompt: Default system prompt used when a call does not supply one.
        api_key: Value of the ``{PROVIDER}_API_KEY`` environment variable, or
            ``None`` for Ollama providers (which do not require a key).
    """

    # Provider prefix (lower-case) -> page describing that provider's rate
    # limits. Used to enrich rate-limit error messages.
    _RATE_LIMIT_DOCS = {
        "anthropic": "https://platform.claude.com/docs/en/api/rate-limits",
        "gemini": "https://ai.google.dev/gemini-api/docs/rate-limits",
        "novita": "https://novita.ai/docs/guides/llm-rate-limits",
        "openai": "https://developers.openai.com/api/docs/guides/rate-limits",
        "openrouter": "https://openrouter.ai/docs/api/reference/limits",
        "xai": "https://docs.x.ai/developers/rate-limits",
    }

    # Errors whose text starts with this prefix are reported to the caller as
    # an invalid-model ValueError, just like NotFoundError.
    _UNMAPPED_MODEL_PREFIX = "This model isn't mapped yet."

    def __init__(
        self,
        llm_model: str,
        api_base: str | None = None,
        system_prompt: str | None = None,
    ):
        """
        Initialize the LLM module

        Args:
            llm_model: The model to use for the LLM in the format
                "{provider}/{model}" (for example, "openai/gpt-4o").
            api_base: The API base to use if the LLM provider is Ollama
            system_prompt: The system prompt to use for the LLM

        Raises:
            ValueError: If llm_model is not in the expected "{provider}/{model}"
                format (including an empty provider or model part), or if the
                provider API key is missing.
        """
        self.api_base = api_base
        self.llm_model = llm_model
        self.system_prompt = system_prompt
        # Always defined, so reading ``api_key`` is safe for key-less providers.
        self.api_key: str | None = None

        # Reject non-strings, a missing "/", and an empty provider or model
        # part: all of them violate the documented "{provider}/{model}" format.
        provider_part, separator, model_part = (
            llm_model.partition("/") if isinstance(llm_model, str) else ("", "", "")
        )
        if not separator or not provider_part.strip() or not model_part.strip():
            raise ValueError(
                f"Invalid model format '{llm_model}'. "
                "Expected '{provider}/{model}', e.g. 'openai/gpt-4o'."
            )

        provider = provider_part.upper()

        if provider in ("OLLAMA", "OLLAMA_CHAT"):
            if self.api_base is None:
                self.api_base = "http://localhost:11434"
                logger.warning(
                    "Using default Ollama API base: %s. If inference is not working, you may need to set the API base to the correct URL.",
                    self.api_base,
                )
        else:
            try:
                self.api_key = os.environ[f"{provider}_API_KEY"]
            except KeyError as err:
                raise ValueError(
                    f"No API key found for {provider}. Please set the {provider}_API_KEY environment variable (e.g., in your .env file)."
                ) from err

        # Best-effort capability check: if litellm cannot describe the model,
        # the check is skipped (debug log only) rather than failing construction.
        try:
            litellm.get_model_info(model=self.llm_model)
        except Exception as exc:
            logger.debug(
                "Skipping function-calling capability check for unmapped model %s: %s",
                self.llm_model,
                exc,
            )
        else:
            if not litellm.supports_function_calling(model=self.llm_model):
                logger.warning(
                    "%s does not support function calling. This model may not be able to use tools. "
                    "Please check the model documentation at https://docs.litellm.ai/docs/providers for more information.",
                    self.llm_model,
                )

    def _build_invalid_model_error(self, error: Exception) -> ValueError:
        """
        Build the ValueError raised when the provider rejects the model name.

        Args:
            error: The underlying exception; its text is embedded in the message.

        Returns:
            A ValueError naming the model, the provider prefix and the current
            api_base. The caller is responsible for raising it.
        """
        provider = self.llm_model.split("/", 1)[0].lower()
        return ValueError(
            f"Invalid or unsupported model '{self.llm_model}' for provider "
            f"'{provider}'. Details: {error}. "
            "Please verify the model name and provider prefix, and if you are "
            f"using a custom endpoint set the correct api_base (current: {self.api_base})."
        )

    def _build_messages(
        self,
        prompt: str | list[str] | None = None,
        system_prompt: str | None = None,
    ) -> list[dict]:
        """
        Format the prompt messages for the LLM of the form : {"role": ..., "content": ...}

        Args:
            prompt: The prompt to generate a response for (str, list of strings, or None)
            system_prompt: Optional system prompt scoped to this call only.

        Returns:
            The messages for the LLM: one system message followed by one user
            message per prompt string.

        Raises:
            TypeError: If prompt is neither str, list nor None, or if a list
                prompt contains a non-str element.
        """
        # Always include a system message. Default to empty string if no system
        # prompt to support Ollama.
        system_content = (
            self.system_prompt if system_prompt is None else system_prompt
        ) or ""
        messages = [{"role": "system", "content": system_content}]

        if prompt is None:
            return messages

        if isinstance(prompt, str):
            messages.append({"role": "user", "content": prompt})
            return messages

        if not isinstance(prompt, list):
            raise TypeError(
                f"Invalid prompt type '{type(prompt).__name__}'. "
                "Expected str, list[str], or None."
            )

        # Validate the whole list before adding anything, reporting the first
        # offending element.
        for index, value in enumerate(prompt):
            if not isinstance(value, str):
                raise TypeError(
                    f"Invalid prompt list element at index {index}: "
                    f"type '{type(value).__name__}'. Expected str."
                )
        messages.extend({"role": "user", "content": text} for text in prompt)
        return messages

    def _build_rate_limit_error(self, error: RateLimitError) -> RateLimitError:
        """
        Build a friendlier RateLimitError from the one raised by litellm.

        The new message names the model, keeps the original detail text, adds
        a retry hint and, for known providers, a link to the rate-limit docs.

        Args:
            error: The RateLimitError raised by the completion call.

        Returns:
            A new RateLimitError carrying the rewritten message and the
            provider/model/response/retry metadata copied from ``error``.
        """
        provider = self.llm_model.split("/", 1)[0].lower()
        docs_url = self._RATE_LIMIT_DOCS.get(provider)

        # Fall back to str(error) if the exception carries no usable message.
        raw_message = getattr(error, "message", None) or str(error)
        detail = str(raw_message).removeprefix("litellm.RateLimitError: ").strip()

        message_parts = [f"Rate limit exceeded for model '{self.llm_model}'."]
        if detail:
            message_parts.append(detail)
        message_parts.append(
            "Please wait a few minutes and try again, or switch to a different model."
        )
        if docs_url:
            message_parts.append(f"To check your quota visit: {docs_url}")

        return RateLimitError(
            message=" ".join(message_parts),
            llm_provider=error.llm_provider,
            model=error.model,
            response=error.response,
            litellm_debug_info=error.litellm_debug_info,
            max_retries=error.max_retries,
            num_retries=error.num_retries,
        )

    def _build_completion_kwargs(
        self,
        messages: list[dict],
        tool_schema: list[dict] | None,
        tool_choice: str,
        response_format: dict | object | None,
    ) -> dict:
        """Assemble the keyword arguments shared by the sync and async completion calls."""
        completion_kwargs = {
            "model": self.llm_model,
            "messages": messages,
            "tools": tool_schema,
            # A tool choice is only meaningful when tools are actually supplied.
            "tool_choice": tool_choice if tool_schema else None,
            "response_format": response_format,
        }
        if self.api_base:
            completion_kwargs["api_base"] = self.api_base
        return completion_kwargs

    def _translate_completion_error(self, error: Exception) -> Exception | None:
        """Return the user-facing replacement for ``error``, or None to re-raise it unchanged."""
        if isinstance(error, RateLimitError):
            return self._build_rate_limit_error(error)
        if isinstance(error, NotFoundError) or str(error).startswith(
            self._UNMAPPED_MODEL_PREFIX
        ):
            return self._build_invalid_model_error(error)
        return None

    # NOTE(review): no ``stop=`` is configured, so tenacity keeps retrying the
    # exceptions in RETRYABLE_EXCEPTIONS indefinitely (back-off capped at 60 s).
    # A persistent rate limit would therefore never surface to the caller.
    # Confirm that this is intended.
    @retry(
        wait=wait_exponential(multiplier=1, min=1, max=60),
        retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
        reraise=True,
    )
    def generate(
        self,
        prompt: str | list[str] | None = None,
        tool_schema: list[dict] | None = None,
        tool_choice: str = "auto",
        response_format: dict | object | None = None,
        system_prompt: str | None = None,
    ) -> str:
        """
        Generate a response from the LLM using litellm based on the prompt

        Args:
            prompt: The prompt to generate a response for (str, list of strings, or None)
            tool_schema: The schema of the tools to use
            tool_choice: The choice of tool to use (ignored when no tool_schema is given)
            response_format: The format of the response
            system_prompt: Optional system prompt scoped to this call only.

        Returns:
            The response from the LLM, exactly as returned by litellm's
            ``completion``. NOTE(review): the ``str`` annotation does not
            match this; callers receive the litellm response object.

        Raises:
            TypeError: If prompt has an unsupported type.
            RateLimitError: With an enriched message when the provider
                rate-limits the call.
            ValueError: If the provider reports the model as unknown or unmapped.
        """
        messages = self._build_messages(prompt, system_prompt=system_prompt)
        completion_kwargs = self._build_completion_kwargs(
            messages, tool_schema, tool_choice, response_format
        )
        try:
            return completion(**completion_kwargs)
        except Exception as error:
            translated = self._translate_completion_error(error)
            if translated is None:
                raise
            raise translated from error

    async def agenerate(
        self,
        prompt: str | list[str] | None = None,
        tool_schema: list[dict] | None = None,
        tool_choice: str = "auto",
        response_format: dict | object | None = None,
        system_prompt: str | None = None,
    ) -> str:
        """
        Asynchronous version of generate() method for parallel LLM calls.

        Takes the same arguments, applies the same retry policy and raises the
        same translated errors as generate(), but awaits litellm's
        ``acompletion``.

        Returns:
            The response from the LLM, exactly as returned by ``acompletion``.
        """
        messages = self._build_messages(prompt, system_prompt=system_prompt)
        # The arguments do not change between attempts, so build them once.
        completion_kwargs = self._build_completion_kwargs(
            messages, tool_schema, tool_choice, response_format
        )

        async for attempt in AsyncRetrying(
            wait=wait_exponential(multiplier=1, min=1, max=60),
            retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
            reraise=True,
        ):
            with attempt:
                try:
                    response = await acompletion(**completion_kwargs)
                except Exception as error:
                    translated = self._translate_completion_error(error)
                    if translated is None:
                        raise
                    raise translated from error

        return response