from __future__ import annotations
import time
from threading import Lock
from typing import Any, Callable, List, Optional, Protocol, Sequence
from council.contexts import LLMContext
from .llm_base import LLMBase, LLMMessage, LLMResult
from .llm_exception import LLMOutOfRetriesException
[docs]
class LLMRequest:
    """
    A chat request for an LLM.

    Bundles the context the request runs in, the messages to send, and any extra
    keyword arguments supplied at construction time.
    """

    def __init__(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> None:
        self._context, self._messages, self._kwargs = context, messages, kwargs

    @property
    def context(self) -> LLMContext:
        """Context the request is executed in."""
        return self._context

    @property
    def messages(self) -> Sequence[LLMMessage]:
        """Messages to send, in the order given."""
        return self._messages

    @property
    def kwargs(self) -> Any:
        """Extra keyword arguments captured when the request was created."""
        return self._kwargs

    @staticmethod
    def default(messages: Sequence[LLMMessage], **kwargs: Any) -> LLMRequest:
        """Creates a default LLMRequest with an empty context."""
        empty_context = LLMContext.empty()
        return LLMRequest(empty_context, messages, **kwargs)
[docs]
class LLMResponse:
    """
    Outcome of executing an LLMRequest.

    Holds the request that was executed, the LLM result (``None`` when no result is
    available) and the call duration in seconds (``-1.0`` for responses built by
    :meth:`empty`).
    """

    def __init__(self, request: LLMRequest, result: Optional[LLMResult], duration: float) -> None:
        self._request = request
        self._result = result
        self._duration = duration

    @property
    def request(self) -> LLMRequest:
        """The request this response was produced for."""
        return self._request

    @property
    def result(self) -> Optional[LLMResult]:
        """The LLM result, or ``None`` when no result is available."""
        return self._result

    @property
    def value(self) -> str:
        """
        The first choice of the result, or ``""`` when there is no result.

        Note: this is a property, so it cannot take arguments; the previous
        ``default`` getter parameter was unreachable and has been removed.
        """
        return self._result.first_choice if self._result is not None else ""

    @property
    def duration(self) -> float:
        """Duration of the call in seconds (``-1.0`` for an empty response)."""
        return self._duration

    @staticmethod
    def empty(request: LLMRequest) -> LLMResponse:
        """Creates an empty LLMResponse for a given request."""
        # -1.0 marks "no call was timed" since no result exists.
        return LLMResponse(request, None, -1.0)
# Signature of the "next handler" a middleware receives: it takes an LLMRequest and returns its LLMResponse.
ExecuteLLMRequest = Callable[[LLMRequest], LLMResponse]
[docs]
class LLMMiddleware(Protocol):
    """
    Structural interface for LLM middleware.

    A middleware sits between the client and the LLM. It receives the target LLM,
    a callable that executes the request further down the chain, and the request
    itself, and may intercept or modify the request and the response around that call.
    """

    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        """Handle ``request``, typically by delegating to ``execute`` and returning its response."""
[docs]
class LLMMiddlewareChain:
    """Manages a chain of LLM middlewares and executes requests through them."""

    def __init__(self, llm: LLMBase, middlewares: Optional[Sequence[LLMMiddleware]] = None) -> None:
        self._llm = llm
        # Copy so later add_middleware calls never mutate the caller's sequence.
        self._middlewares: list[LLMMiddleware] = list(middlewares) if middlewares else []

    def add_middleware(self, middleware: LLMMiddleware) -> None:
        """Add middleware to the end (innermost position) of the chain."""
        self._middlewares.append(middleware)

    def execute(self, request: LLMRequest) -> LLMResponse:
        """
        Execute ``request`` through the middleware chain and finally against the LLM.

        Middlewares run in registration order: the first one added is the outermost.
        Each middleware may pass a different request down the chain; the innermost
        handler sends exactly the request it receives.
        """

        def execute_request(r: LLMRequest) -> LLMResponse:
            # Use ``r`` (the request handed down by the last middleware), not the
            # outer ``request``, so middleware modifications are honoured.
            start = time.perf_counter()  # monotonic clock for elapsed-time measurement
            result = self._llm.post_chat_request(r.context, r.messages, **r.kwargs)
            return LLMResponse(r, result, time.perf_counter() - start)

        handler: ExecuteLLMRequest = execute_request
        # Wrap from the innermost outwards so the first registered middleware runs first.
        for middleware in reversed(self._middlewares):
            handler = self._wrap_middleware(middleware, handler)
        return handler(request)

    @property
    def llm(self) -> LLMBase:
        """The LLM that requests are ultimately sent to."""
        return self._llm

    def _wrap_middleware(self, middleware: LLMMiddleware, handler: ExecuteLLMRequest) -> ExecuteLLMRequest:
        """Bind ``middleware`` to the next ``handler`` (a separate function avoids late-binding closures)."""

        def wrapped(request: LLMRequest) -> LLMResponse:
            return middleware(self._llm, handler, request)

        return wrapped
[docs]
class LLMLoggingMiddleware:
    """Middleware that reports each LLM call through the logger of the request's context."""

    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        context = request.context
        message_count = len(request.messages)
        context.logger.info(f"Sending request with {message_count} message(s) to {llm.configuration.model_name()}")

        response = execute(request)

        # Guard clause: a response without a result only warrants a warning.
        if response.result is None:
            context.logger.warning("No response")
            return response

        context.logger.info(f"Response: `{response.result.first_choice}` in {response.duration} seconds")
        return response
[docs]
class LLMFileLoggingMiddleware:
    """Middleware that appends every LLM request and response to a log file."""

    def __init__(self, log_file: str, component_name: str) -> None:
        """Initialize the middleware with the path to the log_file."""
        self.log_file = log_file  # path of the file appended to on every call
        self.component_name = component_name  # label included in every log entry
        self._lock = Lock()  # serializes writes from concurrent threads

    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        self._log_llm_request(request)
        response = execute(request)
        self._log_llm_response(response)
        return response

    def _log_llm_request(self, request: LLMRequest) -> None:
        """Log the formatted request messages, separated by blank lines."""
        formatted = [message.format() for message in request.messages]
        joined = "\n\n".join(formatted)
        self._log(f"LLM input for {self.component_name}:\n{joined}")

    def _log_llm_response(self, response: LLMResponse) -> None:
        """Log the first choice and duration, or note that no output is available."""
        result = response.result
        if result is not None:
            self._log(
                f"LLM output for {self.component_name} Duration: {response.duration:.2f} Output:\n"
                f"{result.first_choice}"
            )
        else:
            self._log(f"LLM output for {self.component_name} is not available")

    def _log(self, content: str) -> None:
        """Append `content`, preceded by a newline, to the log file."""
        # Hold the lock across open + write so entries from concurrent threads don't interleave.
        with self._lock, open(self.log_file, "a", encoding="utf-8") as handle:
            handle.write(f"\n{content}")
[docs]
class LLMRetryMiddleware:
    """
    Middleware for implementing retry logic for LLM requests.

    Attempts the request up to ``retries`` times, sleeping ``delay`` seconds between
    failed attempts. Only exceptions of type ``exception_to_check`` (default:
    ``Exception``) trigger a retry; any other exception propagates immediately.
    When every attempt fails, or when ``retries <= 0`` so that no attempt is made,
    ``LLMOutOfRetriesException`` is raised with all collected exceptions.
    """

    def __init__(self, retries: int, delay: float, exception_to_check: Optional[type[Exception]] = None) -> None:
        self._retries = retries
        self._delay = delay
        self._exception_to_check = exception_to_check or Exception

    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        failures: List[Exception] = []
        # One failure is recorded per attempt, so len(failures) doubles as the attempt counter.
        while len(failures) < self._retries:
            try:
                return execute(request)
            except Exception as error:
                if not isinstance(error, self._exception_to_check):
                    raise
                failures.append(error)
                # No pause after the final attempt; we are about to give up.
                if len(failures) < self._retries:
                    time.sleep(self._delay)
        raise LLMOutOfRetriesException(llm_name=llm.model_name, retry_count=len(failures), exceptions=failures)