Source code for council.controllers.llm_controller

import logging
from typing import List, Tuple

from council.contexts import AgentContext, ScoredChatMessage
from council.chains import Chain
from council.llm import LLMMessage, LLMBase
from council.utils import Option
from council.runners import Budget

from .controller_base import ControllerBase
from .execution_unit import ExecutionUnit

logger = logging.getLogger(__name__)


[docs]class LLMController(ControllerBase): """ A controller that uses an LLM to decide the execution plan """ _llm: LLMBase
[docs] def __init__(self, llm: LLMBase, response_threshold: float = 0, top_k_execution_plan: int = 10000): """ Initialize a new instance Parameters: llm (LLMBase): the instance of LLM to use response_threshold (float): a minimum threshold to select a response from its score top_k_execution_plan (int): maximum number of execution plan returned """ self._llm = llm self._response_threshold = response_threshold self._top_k = top_k_execution_plan
[docs] def select_responses(self, context: AgentContext) -> List[ScoredChatMessage]: return context.evaluationHistory[-1]
[docs] def get_plan(self, context: AgentContext, chains: List[Chain], budget: Budget) -> List[ExecutionUnit]: answer_choices = "\n ".join([f"name: {c.name}, description: {c.description}" for c in chains]) task_description = [ "You are an assistant responsible to identify the intent of the user. ", "Categories are given as a name and a category (name: {name}, description: {description})", answer_choices, "Instructions:" "# score categories out of 10 using there description", "# For each category, you will answer with {name};{score}", "# Each response is provided on a new line", "# When no category is relevant, you will answer exactly with 'unknown'", ] messages = [ LLMMessage.system_message("\n".join(task_description)), LLMMessage.user_message( f"what are most relevant categories for: {context.chatHistory.try_last_user_message.unwrap().message}" ), ] llm_result = self._llm.post_chat_request(messages) response = llm_result.first_choice logger.debug(f"llm response: {response}") parsed = [self.parse_line(line, chains) for line in response.splitlines()] filtered = [r.unwrap() for r in parsed if r.is_some() and r.unwrap()[1] > self._response_threshold] if (filtered is None) or (len(filtered) == 0): return [] filtered.sort(key=lambda item: item[1], reverse=True) result = [ExecutionUnit(r[0], budget) for r in filtered if r is not None] return result[: self._top_k]
@staticmethod def parse_line(line: str, chains: List[Chain]) -> Option[Tuple[Chain, int]]: result: Option[Tuple[Chain, int]] = Option.none() name = "" try: (name, score) = line.split(";", 2) name = name.casefold() chain = next(filter(lambda item: item.name.casefold() == name, chains)) result = Option.some((chain, int(score))) except StopIteration: logger.warning(f'message="not chain found with name `{name}`"') finally: return result