Source code for council.controllers.llm_controller

import logging
from typing import List, Tuple

from council.contexts import AgentContext, ScoredChatMessage
from council.chains import Chain
from council.llm import LLMMessage, LLMBase
from council.utils import Option
from council.runners import Budget

from .controller_base import ControllerBase
from .execution_unit import ExecutionUnit

logger = logging.getLogger(__name__)


class LLMController(ControllerBase):
    """
    A controller that uses an LLM to decide the execution plan
    """

    _llm: LLMBase

    def __init__(self, llm: LLMBase, response_threshold: float = 0, top_k_execution_plan: int = 10000):
        """
        Initialize a new instance

        Parameters:
            llm (LLMBase): the instance of LLM to use
            response_threshold (float): a minimum threshold to select a response from its score;
                only scores strictly greater than this value are kept
            top_k_execution_plan (int): maximum number of execution plan returned
        """
        self._llm = llm
        self._response_threshold = response_threshold
        self._top_k = top_k_execution_plan

    def select_responses(self, context: AgentContext) -> List[ScoredChatMessage]:
        """
        Return the last entry of the context's evaluation history.

        Parameters:
            context (AgentContext): the context whose `evaluationHistory` is read

        Returns:
            List[ScoredChatMessage]: `context.evaluationHistory[-1]`
        """
        return context.evaluationHistory[-1]

    def get_plan(self, context: AgentContext, chains: List[Chain], budget: Budget) -> List[ExecutionUnit]:
        """
        Ask the LLM to score each chain against the last user message and build the execution plan.

        Parameters:
            context (AgentContext): the context providing the chat history
            chains (List[Chain]): the candidate chains, presented to the LLM by name and description
            budget (Budget): the budget handed to every returned execution unit

        Returns:
            List[ExecutionUnit]: one unit per chain whose score is strictly greater than the
                response threshold, ordered by decreasing score and truncated to the top-k.
                Empty when no response line could be parsed or none passes the threshold.
        """
        messages = self._build_llm_messages(chains, context)
        llm_result = self._llm.post_chat_request(messages)
        response = llm_result.first_choice
        logger.debug(f"llm response: {response}")

        # Only lines starting with "name:" (case-insensitive) are score lines; anything else,
        # including the literal 'unknown' answer, is ignored.
        parsed = [self._parse_line(line, chains) for line in response.splitlines() if line.lower().startswith("name:")]
        candidates = [item.unwrap() for item in parsed if item.is_some()]
        filtered = [item for item in candidates if item[1] > self._response_threshold]
        if not filtered:
            return []

        # list.sort is stable: chains with equal scores keep the order of the LLM response.
        filtered.sort(key=lambda item: item[1], reverse=True)
        result = [ExecutionUnit(chain, budget) for chain, _score in filtered]
        return result[: self._top_k]

    @staticmethod
    def _build_llm_messages(chains, context):
        """
        Build the system and user messages sent to the LLM.

        Parameters:
            chains: the candidate chains; each one is listed by `name` and `description`
            context: the context; the message of its last user message is embedded in the question

        Returns:
            a list of two messages: the system prompt describing the scoring task, then the user question
        """
        answer_choices = "\n ".join([f"name: {c.name}, description: {c.description}" for c in chains])
        task_description = [
            "# Role:",
            "You are an assistant responsible to identify the intent of the user against a list of categories.",
            "Categories are given as a name and a description formatted precisely as:",
            "name: {name}, description: {description})",
            answer_choices,
            "# Instructions:",
            "# Score how relevant a category is from 0 to 10 using their description",
            "# For each category, your scores will be formatted precisely as:",
            "Name: {name};Score: {score as int};{short justification}",
            "# When no category is relevant, you will answer exactly with 'unknown'",
        ]

        messages = [
            LLMMessage.system_message("\n".join(task_description)),
            LLMMessage.user_message(
                f"What are most relevant categories for:\n {context.chatHistory.try_last_user_message.unwrap().message}"
            ),
        ]
        return messages

    @staticmethod
    def _parse_line(line: str, chains: List[Chain]) -> Option[Tuple[Chain, int]]:
        """
        Parse one response line formatted as `Name: {name};Score: {score};{justification}`.

        Parameters:
            line (str): a single line of the LLM response
            chains (List[Chain]): the chains to match the parsed name against (case-insensitive)

        Returns:
            Option[Tuple[Chain, int]]: the matching chain and its integer score, or none when
                the line is malformed, the score is not an integer, or no chain has that name.
        """
        name: str = ""
        line = line.lower().removeprefix("name:")
        try:
            # maxsplit=2 yields at most three parts, so a justification that itself contains
            # ';' stays in one piece instead of breaking the three-way unpacking.
            name, score, _justification = line.split(";", 2)
            name = name.strip().casefold()
            chain = next(filter(lambda item: item.name.casefold() == name, chains))
            score = score.replace("score:", "").strip()
            return Option.some((chain, int(score)))
        except StopIteration:
            logger.warning(f'message="no chain found with name `{name}`"')
        except ValueError:
            # Fewer than three ';'-separated fields, or a score that is not an integer:
            # skip the line (best effort) but leave a trace instead of failing silently.
            logger.warning(f'message="unable to parse llm response line `{line}`"')
        return Option.none()