Source code for council.llm.llm_function.llm_pipeline

from __future__ import annotations

from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Final, Generic, List, Optional, Protocol, Sequence, Type, TypeVar, Union, cast

from council.llm.base import LLMBase, LLMMessage

from .llm_function import LLMFunction
from .llm_middleware import LLMMiddlewareChain
from .llm_response_parser import BaseModelResponseParser


class ProcessorException(Exception):
    """
    Exception raised during the execution of a Processor.
    Contains information about the input that caused the exception, the exception message,
    and optionally the name of a previous processor that should handle the exception.
    """

    def __init__(self, *, input: str, message: str, transfer_to: Optional[str] = None):
        """
        :param input: the text that caused the failure (LLMProcessor replays it as an assistant message).
        :param message: human-readable description of the failure.
        :param transfer_to: optional name of the processor that should handle this exception.
        """
        # Forward the message to Exception so str(exc), exc.args and tracebacks are informative;
        # without this, a re-raised ProcessorException shows an empty message.
        super().__init__(message)
        self.input = input
        self.message = message
        self.transfer_to = transfer_to


class LLMProcessorInput(Protocol):
    """Structural type for objects an LLMProcessor can turn into a prompt."""

    def to_prompt(self) -> str:
        """Return the textual prompt representation of this object."""
        ...


# Unconstrained input/output types used by ProcessorBase and the pipeline processors.
T_Input = TypeVar("T_Input")
T_Output = TypeVar("T_Output")

# input must implement to_prompt()
T_LLMInput = TypeVar("T_LLMInput", bound=LLMProcessorInput)
# output must implement from_response() and to_response_template()
# (LLMProcessor.execute calls both on the output type).
T_LLMOutput = TypeVar("T_LLMOutput", bound=BaseModelResponseParser)


class LLMProcessorRecord:
    """One prompt/response exchange made by an LLMProcessor, plus any exception later reported against it."""

    def __init__(self, *, input: str, output: str, exception: Optional[ProcessorException] = None):
        self.input, self.output, self.exception = input, output, exception

    def set_exception(self, exception: ProcessorException) -> None:
        """Attach (or replace) the exception associated with this record."""
        self.exception = exception

    def to_dict(self) -> Dict[str, str]:
        """
        Serialize the record to a plain dict.

        The "exception" key (holding the exception's message) is present only when an exception is set.
        """
        exc = self.exception
        extra: Dict[str, str] = {} if exc is None else {"exception": exc.message}
        return {"input": self.input, "output": self.output, **extra}


class ProcessorBase(Generic[T_Input, T_Output], ABC):
    """
    Abstract step that transforms an input object into an output object.

    Several processors can be chained together by a pipeline processor (see PipelineProcessorBase).
    """

    @abstractmethod
    def execute(self, obj: T_Input, exception: Optional[ProcessorException] = None) -> T_Output:
        """
        Transform ``obj`` into the output type.

        :param obj: object to transform.
        :param exception: optional failure reported on a previous attempt; pipelines pass it when
            asking this processor to try again.
        """
class LLMProcessor(ProcessorBase[T_LLMInput, T_LLMOutput]):
    """
    ProcessorBase that uses an LLM to convert an input object into an output object.
    Keeps track of records processed by this instance.

    The prompt is the input object's ``to_prompt()`` followed by the output type's
    ``to_response_template()``; the LLM response is parsed with the output type's ``from_response``.
    """

    def __init__(
        self, llm: Union[LLMBase, LLMMiddlewareChain], output_obj_type: Type[T_LLMOutput], name: Optional[str] = None
    ) -> None:
        """
        :param llm: LLM to call; a bare LLMBase is wrapped in an LLMMiddlewareChain.
        :param output_obj_type: response-parser type used to build the response template and parse responses.
        :param name: processor name, defaulting to the output type's class name; backtracking pipelines
            compare it against ``ProcessorException.transfer_to``.
        """
        self._llm_middleware = llm if isinstance(llm, LLMMiddlewareChain) else LLMMiddlewareChain(llm)
        self._output_obj_type = output_obj_type
        self.name = name or output_obj_type.__name__
        self._records: List[LLMProcessorRecord] = []
        self.PROMPT_TEMPLATE: Final[str] = "\n".join(["{input_obj_prompt}", "", "{response_template}"])

    @property
    def records(self) -> List[LLMProcessorRecord]:
        """List of all records processed by this instance."""
        return self._records

    @property
    def records_with_exceptions(self) -> Sequence[LLMProcessorRecord]:
        """List of records processed by this instance that resulted in an exception."""
        return [record for record in self.records if record.exception is not None]

    def add_record(self, *, input_prompt: str, produced_output: str) -> None:
        """Append a record of one prompt/response exchange."""
        self._records.append(LLMProcessorRecord(input=input_prompt, output=produced_output))

    def last_record(self) -> LLMProcessorRecord:
        """
        Return the most recent record.

        :raises ValueError: if no record has been produced yet.
        """
        if not self.records:
            raise ValueError("No records found.")
        return self.records[-1]

    def execute(self, obj: T_LLMInput, exception: Optional[ProcessorException] = None) -> T_LLMOutput:
        """
        Build a prompt for ``obj``, call the LLM and parse its response.

        :param obj: input object providing ``to_prompt()``.
        :param exception: failure reported against a previous output. When given, it is attached to the
            latest record (if any), and the conversation is extended with ``exception.input`` as an assistant
            message and ``exception.message`` as a user message so the LLM can correct its answer.
        :return: the parsed response.
        """
        system_prompt = self.PROMPT_TEMPLATE.format(
            input_obj_prompt=obj.to_prompt(),
            response_template=self._output_obj_type.to_response_template(),
        )
        messages = [LLMMessage.system_message(system_prompt)]

        if exception is not None:
            # A processor retrying its own failed first attempt has no record yet; calling last_record()
            # there would raise ValueError, which escapes the pipeline's ProcessorException handling.
            if self.records:
                self.last_record().set_exception(exception)
            messages.extend([LLMMessage.assistant_message(exception.input), LLMMessage.user_message(exception.message)])

        llm_func: LLMFunction[T_LLMOutput] = LLMFunction(
            llm=self._llm_middleware,
            response_parser=self._output_obj_type.from_response,
            messages=messages,
        )
        llm_func_response = llm_func.execute_with_llm_response()
        self.add_record(input_prompt=system_prompt, produced_output=llm_func_response.llm_response.value)
        return llm_func_response.response
class PipelineProcessorBase(Generic[T_Input, T_Output], ABC):
    """
    Abstract pipeline that runs an ordered sequence of Processors.

    Subclasses decide how processors are chained and how failures are handled.
    """

    def __init__(self, processors: Sequence[ProcessorBase]):
        # Copy into a list owned by the pipeline, independent of the caller's sequence object.
        self.processors = [*processors]

    @abstractmethod
    def execute(self, obj: T_Input) -> T_Output:
        """Run the pipeline on ``obj`` and return the final output."""
class NaivePipelineProcessor(PipelineProcessorBase[T_Input, T_Output]):
    """
    PipelineProcessor that runs its processors one after another, feeding each output into the next.

    No error handling is done here: any exception propagates to the caller, so each processor must
    deal with its own errors.

    .. mermaid::

        flowchart LR
            A(Processor A) --> B(Processor B)
            A --> A
            B --> B
    """

    def execute(self, obj: T_Input) -> T_Output:
        value: T_Input = obj
        for stage in self.processors:
            value = stage.execute(value)
        return cast(T_Output, value)
class BacktrackingPipelineProcessor(PipelineProcessorBase[T_Input, T_Output]):
    """
    PipelineProcessor that executes processors in a linear order backtracking errors.
    If a processor fails with a ProcessorException, the pipeline backtracks to the previous processor
    (or to the processor named in the exception's ``transfer_to``) and tries again.

    .. mermaid::

        flowchart LR
            A(Processor A) --> B(Processor B)
            A --> A
            B --> B
            B --> A
    """

    def __init__(self, processors: Sequence[ProcessorBase], max_backtracks: int = 3):
        """
        :param processors: processors to run, in order.
        :param max_backtracks: maximum number of ProcessorExceptions absorbed before the last one is re-raised.
        """
        super().__init__(processors)
        self.max_backtracks = max_backtracks

    @staticmethod
    def should_handle_exception(processor: ProcessorBase, exception: Optional[ProcessorException] = None) -> bool:
        """
        Decide whether ``processor`` should run given the pending ``exception``.

        With no pending exception every processor runs. Otherwise only LLMProcessors handle exceptions,
        and only the one named by ``exception.transfer_to`` when that is set.
        """
        if exception is None:
            return True
        if not isinstance(processor, LLMProcessor):
            return False
        if exception.transfer_to is None:
            return True
        return exception.transfer_to == processor.name

    def execute(self, obj: T_Input) -> T_Output:
        """
        Run the processors in order, backtracking on ProcessorException.

        :raises ProcessorException: when no processor at or before the failing one can handle the
            exception, or when more than ``max_backtracks`` exceptions occur.
        """
        if not self.processors:
            return cast(T_Output, obj)

        # inputs[i] is set right before processor i first runs; backtracking only moves to indices
        # that have already run, so only the initial input needs to be (deep-)copied up front.
        inputs: Dict[int, T_Input] = {0: deepcopy(obj)}
        previous_exception: Optional[ProcessorException] = None
        current_obj = obj
        backtrack_count = 0
        index = 0

        while index < len(self.processors):
            processor = self.processors[index]

            if not self.should_handle_exception(processor, previous_exception):
                if index == 0:
                    # Raised outside the try below on purpose: nobody can handle the pending exception,
                    # so the pipeline must abort instead of retrying with this message as LLM feedback.
                    raise ProcessorException(
                        input=str(inputs[0]),
                        message="Cannot backtrack from first processor",
                    ) from previous_exception
                index -= 1
                continue

            try:
                current_obj = processor.execute(inputs[index], previous_exception)
            except ProcessorException as e:
                if backtrack_count >= self.max_backtracks:
                    raise  # out of retry
                index = max(0, index - 1)
                backtrack_count += 1
                previous_exception = e
                continue

            index += 1
            previous_exception = None  # reset exception after successful execution
            if index < len(self.processors):
                inputs[index] = current_obj

        return cast(T_Output, current_obj)