"""Source code for council.contexts._chain_context."""

import logging
from typing import Iterable, List, Optional

import more_itertools

from ._agent_context import AgentContext
from ._agent_context_store import AgentContextStore
from ._budget import Budget
from ._cancellation_token import CancellationToken
from ._chat_history import ChatHistory
from ._chat_message import ChatMessage
from ._composite_message_collection import CompositeMessageCollection
from ._context_base import ContextBase
from ._execution_context import ExecutionContext
from ._message_collection import MessageCollection
from ._message_list import MessageList
from ._monitored import Monitored

logger = logging.getLogger(__name__)


class ChainContext(ContextBase, MessageCollection):
    """
    Represents the data context for a :class:`council.chains.Chain`
    """

    def __init__(
        self,
        store: AgentContextStore,
        execution_context: ExecutionContext,
        name: str,
        budget: Budget,
        messages: Optional[Iterable[ChatMessage]] = None,
    ):
        """
        initializes a new chain context

        Parameters:
            store (AgentContextStore): the store holding the agent's iterations and chat history
            execution_context (ExecutionContext): the execution context this chain context runs in
            name (str): the chain name, used to look up this chain's iterations in the store
            budget (Budget): the budget for this context
            messages (Optional[Iterable[ChatMessage]]): messages seeded as "previous" messages of the
                current iteration (e.g. inherited from a parent context by :meth:`fork_for`)
        """
        super().__init__(store, execution_context, budget)
        self._name = name
        # messages appended through this context instance
        self._current_messages = MessageList()
        # messages this instance was seeded with at construction time
        self._previous_messages = MessageList(messages)
        self._current_iteration_messages = CompositeMessageCollection([self._previous_messages, self._current_messages])
        # all recorded iterations of this chain except the last one; the last entry is presumably the
        # current iteration (see `ensure_chain_exists` in `from_agent_context`) -- TODO confirm
        self._previous_iteration_messages = CompositeMessageCollection(
            list(self._store.chain_iterations(self._name))[:-1]
        )
        self._all_iteration_messages = CompositeMessageCollection(
            [self._previous_iteration_messages, self._current_iteration_messages]
        )
        # chat history first, then previous iterations, then the current iteration
        self._all_messages = CompositeMessageCollection([self.chat_history, self._all_iteration_messages])

    @property
    def cancellation_token(self) -> CancellationToken:
        """
        returns the cancellation token
        """
        return self._store.cancellation_token

    @property
    def budget(self) -> Budget:
        """
        returns the budget
        """
        return self._budget

    @property
    def messages(self) -> Iterable[ChatMessage]:
        """
        returns all the visible messages for the chain, ordered by execution iteration (oldest first).
        This contains:
        - the :class:`ChatHistory`
        - all messages from previous iterations
        - visible messages from current iteration
        """
        return self._all_messages.messages

    @property
    def reversed(self) -> Iterable[ChatMessage]:
        """
        similar to :meth:`messages`, but in reverse order (most recent first).
        """
        return self._all_messages.reversed

    @property
    def chain_histories(self) -> Iterable[MessageCollection]:
        """
        returns the collections of all messages generated by the current chain, grouped by execution iteration.
        Iterations in which this chain has no entry are skipped.
        """
        for item in self._store.iterations:
            chain = item.chains.get(self._name)
            if chain is not None:
                yield chain

    @property
    def current(self) -> MessageCollection:
        """
        Returns the :class:`MessageCollection` for the current execution of a :class:`.Chain`

        Returns:
            MessageCollection: a collection of messages
        """
        return self._current_iteration_messages

    @staticmethod
    def from_agent_context(
        context: AgentContext, monitored: Monitored, name: str, budget: Optional[Budget] = None
    ) -> "ChainContext":
        """
        creates a new instance from an :class:`AgentContext`, adjusting the execution context appropriately.
        Makes sure the chain `name` exists in the store's current iteration.

        Parameters:
            context (AgentContext): the agent context to derive from
            monitored (Monitored): the object the new execution context is created for
            name (str): the chain name
            budget (Optional[Budget]): the budget to use; `Budget.default()` when `None`

        Returns:
            ChainContext: the new chain context
        """
        context._store.current_iteration.ensure_chain_exists(name)
        # explicit None check: an explicitly passed budget is always honored
        effective_budget = budget if budget is not None else Budget.default()
        return ChainContext(
            context._store, context._execution_context.new_for(monitored), name, effective_budget
        )

    def fork_for(self, monitored: Monitored, budget: Optional[Budget] = None) -> "ChainContext":
        """
        forks the context for the given object, adjusting the execution context appropriately.
        The fork is seeded with this context's previous and current messages; messages appended
        to the fork are kept in the fork's own current messages.

        Parameters:
            monitored (Monitored): the object the new execution context is created for
            budget (Optional[Budget]): the budget to use; this context's budget when `None`

        Returns:
            ChainContext: the forked context
        """
        # explicit None check: an explicitly passed budget is always honored
        effective_budget = budget if budget is not None else self._budget
        return ChainContext(
            self._store,
            self._execution_context.new_for(monitored),
            self._name,
            effective_budget,
            more_itertools.flatten([self._previous_messages.messages, self._current_messages.messages]),
        )

    def should_stop(self) -> bool:
        """
        returns `True` if the execution of the chain should be stopped. `False` otherwise.

        Returns:
            bool: `True` if either the budget is expired or the cancellation token is set
        """
        if self._budget.is_expired():
            logger.debug('message="stopping" reason="budget expired"')
            return True
        if self.cancellation_token.cancelled:
            logger.debug('message="stopping" reason="cancellation token is set"')
            return True
        return False

    def merge(self, contexts: List["ChainContext"]) -> None:
        """
        merges the messages appended to each of the given contexts (such as forks created with
        :meth:`fork_for`) into this context's current messages.

        Only the in-memory message list of this context is updated; nothing is written to the store here.

        Parameters:
            contexts (List[ChainContext]): the contexts whose current messages are merged, in order
        """
        for context in contexts:
            self._current_messages.add_messages(context._current_messages.messages)

    def append(self, message: ChatMessage) -> None:
        """
        adds the message to the context and records it for this chain in the store's current iteration

        Parameters:
            message (ChatMessage): the message to add
        """
        self._current_messages.add_message(message)
        self._store.current_iteration.append_to_chain(self._name, message, self._execution_context.entry)

    def extend(self, messages: Iterable[ChatMessage]) -> None:
        """
        adds many messages to the context, one at a time via :meth:`append`

        Parameters:
            messages (Iterable[ChatMessage]): the messages to add
        """
        for message in messages:
            self.append(message)

    @staticmethod
    def from_chat_history(history: ChatHistory, budget: Optional[Budget] = None) -> "ChainContext":
        """
        helper function that creates a new instance from a :class:`ChatHistory`.
        For test purpose only.
        """
        # local import: mocks are only needed by this test helper
        from ..mocks import MockMonitored

        context = AgentContext.from_chat_history(history)
        context.new_iteration()
        return ChainContext.from_agent_context(context, MockMonitored("mock chain"), "mock chain", budget)

    @staticmethod
    def from_user_message(message: str, budget: Optional[Budget] = None) -> "ChainContext":
        """
        creates a new instance from a user message.
        The :class:`ChatHistory` contains only the user message
        """
        return ChainContext.from_chat_history(ChatHistory.from_user_message(message), budget)

    @staticmethod
    def empty() -> "ChainContext":
        """
        helper function that creates a new empty instance.
        For test purpose only.
        """
        return ChainContext.from_chat_history(ChatHistory())