Source code for council.controllers.llm_controller

import logging
from typing import List, Optional, Tuple

from council.chains import Chain
from council.contexts import AgentContext, LLMContext, Monitored
from council.llm import LLMBase, LLMMessage
from council.utils import Option
from .controller_base import ControllerBase
from .execution_unit import ExecutionUnit

logger = logging.getLogger(__name__)


class LLMController(ControllerBase):
    """
    A controller that uses an LLM to decide the execution plan.

    The LLM is asked to score every chain (category) against the last user
    message. Chains scored strictly above ``response_threshold`` become
    execution units, ordered by descending score and optionally capped to
    ``top_k``.
    """

    _llm: Monitored[LLMBase]

    def __init__(
        self,
        chains: List[Chain],
        llm: LLMBase,
        response_threshold: float = 0.0,
        top_k: Optional[int] = None,
    ):
        """
        Initialize a new instance of an LLMController

        Parameters:
            chains (List[Chain]): the chains the controller can select from
            llm (LLMBase): the instance of LLM to use
            response_threshold (float): a minimum threshold to select a response from its score;
                only chains scored strictly above this value are kept
            top_k (int): maximum number of execution plan returned; ``None`` or a
                non-positive value means no limit
        """
        super().__init__(chains=chains)
        self._llm = self.new_monitor("llm", llm)
        self._response_threshold = response_threshold
        self._top_k = top_k
        # The system prompt only depends on the chains, so build it once.
        self._llm_system_message = self._build_system_message()

    @property
    def llm(self) -> LLMBase:
        """
        the LLM used by the controller
        """
        return self._llm.inner

    def _execute(self, context: AgentContext) -> List[ExecutionUnit]:
        """
        Ask the LLM to score the chains and turn the answer into execution units.

        Returns an empty list when no line of the answer yields a chain scored
        above the threshold.
        """
        response = self._call_llm(context)
        # Strip each line so indented answers (e.g. "  Name: ...") are not ignored.
        candidate_lines = (raw.strip() for raw in response.strip().splitlines())
        parsed = [
            self._parse_line(line, self._chains)
            for line in candidate_lines
            if line.lower().startswith("name:")
        ]
        scored = [option.unwrap() for option in parsed if option.is_some()]
        selected = [(chain, score) for chain, score in scored if score > self._response_threshold]
        if not selected:
            return []

        selected.sort(key=lambda item: item[1], reverse=True)
        if self._top_k is not None and self._top_k > 0:
            selected = selected[: self._top_k]

        return [ExecutionUnit(chain, context.budget, name=f"{chain.name};{score}") for chain, score in selected]

    def _call_llm(self, context: AgentContext) -> str:
        """
        Send the scoring prompt to the LLM and return its first choice as text.
        """
        messages = self._build_llm_messages(context)
        llm_result = self.llm.post_chat_request(LLMContext.from_context(context, self._llm), messages)
        response = llm_result.first_choice
        logger.debug("llm response: %s", response)
        return response

    def _build_llm_messages(self, context: AgentContext) -> List[LLMMessage]:
        """
        Build the chat messages: the cached system prompt plus a user message
        quoting the last user message of the chat history.

        The last user message is read via ``try_last_user_message.unwrap()``;
        presumably this fails when the history has no user message — TODO confirm.
        """
        last_user_message = context.chat_history.try_last_user_message.unwrap().message
        return [
            self._llm_system_message,
            # The original implicit concatenation lacked a space ("categoriesfor:").
            LLMMessage.user_message(f"What are most relevant categories for:\n {last_user_message}"),
        ]

    def _build_system_message(self) -> LLMMessage:
        """
        Build the system prompt listing every chain and the expected answer format.
        """
        # One category per line, with no stray indentation on continuation lines.
        answer_choices = "\n".join(f"name: {c.name}, description: {c.description}" for c in self._chains)
        task_description = [
            "# Role:",
            "You are an assistant responsible to identify the intent of the user against a list of categories.",
            "Categories are given as a name and a description formatted precisely as:",
            "name: {name}, description: {description}",
            answer_choices,
            "# INSTRUCTIONS:",
            "# Score how relevant a category is from 0 to 10 using their description",
            "# For each category, your scores will be formatted precisely as:",
            "Name: {name};Score: {score as int};{short justification}",
            "# When no category is relevant, you will answer exactly with 'unknown'",
        ]
        return LLMMessage.system_message("\n".join(task_description))

    @staticmethod
    def _parse_line(line: str, chains: List[Chain]) -> Option[Tuple[Chain, int]]:
        """
        Parse one ``Name: {name};Score: {score};{justification}`` answer line.

        Matching of the chain name is case-insensitive. The justification is
        optional and may itself contain ``;``.

        Returns:
            ``Option.some((chain, score))`` on success, otherwise ``Option.none()``
            (a warning is logged for an unknown chain, a malformed line or a
            non-integer score).
        """
        body = line.strip().lower().removeprefix("name:")
        # At most 3 fields: a ';' inside the justification must not break parsing.
        parts = body.split(";", 2)
        if len(parts) < 2:
            logger.warning(f'message="cannot parse line `{line}`"')
            return Option.none()

        name = parts[0].strip().casefold()
        score_text = parts[1].replace("score:", "").strip()

        chain = next((item for item in chains if item.name.casefold() == name), None)
        if chain is None:
            logger.warning(f'message="no chain found with name `{name}`"')
            return Option.none()

        try:
            score = int(score_text)
        except ValueError:
            logger.warning(f'message="invalid score `{score_text}` for chain `{name}`"')
            return Option.none()

        return Option.some((chain, score))