Source code for council.llm.llm_middleware

from __future__ import annotations

import hashlib
import json
import time
from collections import OrderedDict
from threading import Lock
from typing import Any, Callable, List, Optional, Protocol, Sequence

from council.contexts import Consumption, LLMContext

from .llm_base import LLMBase, LLMMessage, LLMResult, T_Configuration
from .llm_exception import LLMOutOfRetriesException


[docs] class LLMRequest: def __init__(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> None: self._context = context self._messages = messages self._kwargs = kwargs @property def context(self) -> LLMContext: return self._context @property def messages(self) -> Sequence[LLMMessage]: return self._messages @property def kwargs(self) -> Any: return self._kwargs
[docs] @staticmethod def default(messages: Sequence[LLMMessage], **kwargs: Any) -> LLMRequest: """Creates a default LLMRequest with an empty context.""" return LLMRequest(LLMContext.empty(), messages, **kwargs)
[docs] class LLMResponse: def __init__(self, request: LLMRequest, result: Optional[LLMResult], duration: float) -> None: self._request = request self._result = result self._duration = duration @property def result(self) -> Optional[LLMResult]: return self._result @property def value(self, default: str = "") -> str: return self._result.first_choice if self._result is not None else default @property def duration(self) -> float: return self._duration
[docs] @staticmethod def empty(request: LLMRequest) -> LLMResponse: """Creates an empty LLMResponse for a given request.""" return LLMResponse(request, None, -1.0)
ExecuteLLMRequest = Callable[[LLMRequest], LLMResponse]
[docs] class LLMMiddleware(Protocol): """ Protocol for defining LLM middleware. Middleware can intercept and modify requests and responses between the client and the LLM, introducing custom logic. """ def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse: ...
[docs] class LLMMiddlewareChain: """Manages a chain of LLM middlewares and executes requests through them.""" def __init__(self, llm: LLMBase, middlewares: Optional[Sequence[LLMMiddleware]] = None) -> None: self._llm = llm self._middlewares: list[LLMMiddleware] = list(middlewares) if middlewares else []
[docs] def add_middleware(self, middleware: LLMMiddleware) -> None: """Add middleware to a chain.""" self._middlewares.append(middleware)
[docs] def execute(self, request: LLMRequest) -> LLMResponse: """Execute middleware chain.""" def execute_request(r: LLMRequest) -> LLMResponse: start = time.time() result = self._llm.post_chat_request(r.context, request.messages, **r.kwargs) return LLMResponse(request, result, time.time() - start) handler: ExecuteLLMRequest = execute_request for middleware in reversed(self._middlewares): handler = self._wrap_middleware(middleware, handler) return handler(request)
@property def llm(self) -> LLMBase: return self._llm def _wrap_middleware(self, middleware: LLMMiddleware, handler: ExecuteLLMRequest) -> ExecuteLLMRequest: def wrapped(request: LLMRequest) -> LLMResponse: return middleware(self._llm, handler, request) return wrapped
[docs] class LLMLoggingMiddleware: """Middleware for logging LLM requests, responses and consumptions.""" def __init__(self, log_consumptions: bool = False) -> None: self.log_consumptions = log_consumptions def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse: request.context.logger.info( f"Sending request with {len(request.messages)} message(s) to {llm.configuration.model_name()}" ) response = execute(request) if response.result is not None: request.context.logger.info(f"Response: `{response.result.first_choice}` in {response.duration} seconds") if self.log_consumptions: for consumption in response.result.consumptions: request.context.logger.info(f"{consumption}") else: request.context.logger.warning("No response") return response
[docs] class LLMFileLoggingMiddleware: """Middleware for logging LLM requests, responses and consumptions into a file."""
[docs] def __init__(self, log_file: str, component_name: str, log_consumptions: bool = False) -> None: """Initialize the middleware with the path to the log_file, component name and whether to log consumptions""" self.log_file = log_file self.component_name = component_name self.log_consumptions = log_consumptions self._lock = Lock()
def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse: self._log_llm_request(request) response = execute(request) self._log_llm_response(response) return response def _log_llm_request(self, request: LLMRequest) -> None: messages_str = "\n\n".join(message.format() for message in request.messages) self._log(f"LLM input for {self.component_name}:\n{messages_str}") def _log_llm_response(self, response: LLMResponse) -> None: if response.result is None: self._log(f"LLM output for {self.component_name} is not available") return self._log( f"LLM output for {self.component_name} Duration: {response.duration:.2f} Output:\n" f"{response.result.first_choice}" ) if self.log_consumptions: for consumption in response.result.consumptions: self._log(f"Consumption for {self.component_name}: {consumption}") def _log(self, content: str) -> None: """Append `content` to a current log file""" with self._lock: # ensure each write is done atomically in case of multi-threading with open(self.log_file, "a", encoding="utf-8") as file: file.write(f"\n{content}")
[docs] class LLMRetryMiddleware: """ Middleware for implementing retry logic for LLM requests. Attempts to retry failed requests a specified number of times with a delay between attempts. """ def __init__(self, retries: int, delay: float, exception_to_check: Optional[type[Exception]] = None) -> None: self._retries = retries self._delay = delay self._exception_to_check = exception_to_check if exception_to_check else Exception def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse: attempt = 0 exceptions: List[Exception] = [] while attempt < self._retries: try: return execute(request) except Exception as e: if not isinstance(e, self._exception_to_check): raise exceptions.append(e) attempt += 1 if attempt >= self._retries: break time.sleep(self._delay) raise LLMOutOfRetriesException(llm_name=llm.model_name, retry_count=attempt, exceptions=exceptions)
class CacheEntry: """Represents a cached LLM response.""" def __init__(self, response: LLMResponse, ttl: float) -> None: self.response = self._rebuild_response(response) self.timestamp = time.time() self.ttl = ttl @staticmethod def _rebuild_response(response: LLMResponse) -> LLMResponse: """Modify the response so it has zero duration and consumptions in 'cached_' units""" if response.result is not None: cached_consumptions: List[Consumption] = [ Consumption(value=consumption.value, unit=f"cached_{consumption.unit}", kind=consumption.kind) for consumption in response.result.consumptions ] cached_result = LLMResult(response.result.choices, cached_consumptions, response.result.raw_response) else: cached_result = None cached_response = LLMResponse(response._request, cached_result, 0) return cached_response @property def is_expired(self) -> bool: return time.time() - self.timestamp >= self.ttl def renew_lifetime(self) -> None: """Update the timestamp to extend the lifetime.""" self.timestamp = time.time()
[docs] class LLMCachingMiddleware: """Middleware that caches LLM responses to avoid duplicate calls."""
[docs] def __init__(self, ttl: float = 300.0, cache_limit_size: int = 10) -> None: """ Initialize the caching middleware. Args: ttl: Sliding window time-to-live in seconds for cache entries (default: 5 mins) cache_limit_size: Cache limit size in cached entries (default: 10) """ self._cache: OrderedDict[str, CacheEntry] = OrderedDict() self._ttl = ttl self._cache_limit_size = cache_limit_size
def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse: self._remove_expired() key = self.get_hash(request, llm.configuration) if key in self._cache: entry = self._cache[key] entry.renew_lifetime() # renew on hit self._cache.move_to_end(key) # move to most recent return entry.response response = execute(request) # cache the new response self._cache[key] = CacheEntry(response=response, ttl=self._ttl) self._cache.move_to_end(key) self._enforce_cache_limit() return response
[docs] @staticmethod def get_hash(request: LLMRequest, configuration: T_Configuration) -> str: """Convert the request and LLM configuration to a hash with hashlib.sha256.""" serialized = json.dumps( { "configuration": {key: str(value) for key, value in configuration.__dict__.items()}, # TODO: request.context is not serialized "messages": [m.normalize() for m in request.messages], "kwargs": request.kwargs, }, sort_keys=True, ) return hashlib.sha256(serialized.encode()).hexdigest()
def _remove_expired(self) -> None: """Remove all expired cache entries.""" expired_keys = [key for key, entry in self._cache.items() if entry.is_expired] for key in expired_keys: del self._cache[key] def _enforce_cache_limit(self) -> None: """Remove oldest entries if cache size exceeds limit.""" while len(self._cache) > self._cache_limit_size: self._cache.popitem(last=False) # remove the first (oldest) item
[docs] def clear_cache(self) -> None: """Clear all cached entries.""" self._cache.clear()