import logging
from typing import Iterable, List, Optional
import more_itertools
from ._agent_context import AgentContext
from ._agent_context_store import AgentContextStore
from ._budget import Budget
from ._cancellation_token import CancellationToken
from ._chat_history import ChatHistory
from ._chat_message import ChatMessage
from ._composite_message_collection import CompositeMessageCollection
from ._context_base import ContextBase
from ._execution_context import ExecutionContext
from ._message_collection import MessageCollection
from ._message_list import MessageList
from ._monitored import Monitored
# Module-level logger; ChainContext.should_stop uses it to report why a chain stops.
logger = logging.getLogger(name=__name__)
class ChainContext(ContextBase, MessageCollection):
    """
    Represents the data context for a :class:`council.chains.Chain`.

    Messages are exposed through layered views, ordered oldest first:

    - :attr:`current`: the messages this context was seeded with, followed by the messages added
      through :meth:`append`, :meth:`extend` or :meth:`merge`
    - :attr:`messages`: the :class:`ChatHistory`, then this chain's messages from previous
      iterations, then :attr:`current`
    """

    def __init__(
        self,
        store: AgentContextStore,
        execution_context: ExecutionContext,
        name: str,
        budget: Budget,
        messages: Optional[Iterable[ChatMessage]] = None,
    ):
        """
        Initializes the context.

        Args:
            store: the store holding the iterations; shared with the originating agent context
            execution_context: the execution context for this chain
            name: the chain name, used as the key for this chain's messages in the store
            budget: the budget for this chain
            messages: optional messages seeding the current iteration (used by :meth:`fork_for`)
        """
        super().__init__(store, execution_context, budget)
        self._name = name
        # messages added through this instance (append/extend/merge)
        self._current_messages = MessageList()
        # messages this instance was seeded with, e.g. the parent's messages when forked
        self._previous_messages = MessageList(messages)
        self._current_iteration_messages = CompositeMessageCollection([self._previous_messages, self._current_messages])
        # every recorded iteration of this chain except the last one; the last entry is presumably
        # the current iteration (from_agent_context calls ensure_chain_exists on it beforehand)
        self._previous_iteration_messages = CompositeMessageCollection(
            list(self._store.chain_iterations(self._name))[:-1]
        )
        self._all_iteration_messages = CompositeMessageCollection(
            [self._previous_iteration_messages, self._current_iteration_messages]
        )
        # NOTE(review): `chat_history` is not defined in this class; presumably provided by ContextBase
        self._all_messages = CompositeMessageCollection([self.chat_history, self._all_iteration_messages])

    @property
    def cancellation_token(self) -> CancellationToken:
        """
        Returns the cancellation token of the underlying store.
        """
        return self._store.cancellation_token

    @property
    def budget(self) -> Budget:
        """
        Returns the budget of this context.
        """
        return self._budget

    @property
    def messages(self) -> Iterable[ChatMessage]:
        """
        Returns all the visible messages for the chain, ordered by execution iteration (oldest first).

        This contains:
            - the :class:`ChatHistory`
            - all messages from previous iterations
            - visible messages from the current iteration
        """
        return self._all_messages.messages

    @property
    def reversed(self) -> Iterable[ChatMessage]:
        """
        Similar to :attr:`messages`, but in reverse order (most recent first).
        """
        return self._all_messages.reversed

    @property
    def chain_histories(self) -> Iterable[MessageCollection]:
        """
        Returns the collections of all messages generated by the current chain, grouped by execution iteration.

        Iterations in which this chain has no entry are skipped.
        """
        for item in self._store.iterations:
            chain = item.chains.get(self._name)
            if chain is not None:
                yield chain

    @property
    def current(self) -> MessageCollection:
        """
        Returns the :class:`MessageCollection` for the current execution of a :class:`.Chain`.

        Returns:
            MessageCollection: the seeded messages followed by the messages added to this context
        """
        return self._current_iteration_messages

    @staticmethod
    def from_agent_context(
        context: AgentContext, monitored: Monitored, name: str, budget: Optional[Budget] = None
    ) -> "ChainContext":
        """
        Creates a new instance from an :class:`AgentContext`, adjusting the execution context appropriately.

        Ensures the chain `name` exists in the current iteration of the agent's store before constructing.

        Args:
            context: the agent context to derive from
            monitored: the object the new execution context is created for
            name: the chain name
            budget: the budget for the chain; `Budget.default()` when omitted

        Returns:
            ChainContext: the new chain context
        """
        context._store.current_iteration.ensure_chain_exists(name)
        return ChainContext(
            context._store,
            context._execution_context.new_for(monitored),
            name,
            budget if budget is not None else Budget.default(),
        )

    def fork_for(self, monitored: Monitored, budget: Optional[Budget] = None) -> "ChainContext":
        """
        Forks the context for the given object, adjusting the execution context appropriately.

        The fork shares the store and chain name, and is seeded with this context's seeded and
        current messages. Use :meth:`merge` to bring the fork's new messages back.

        Args:
            monitored: the object the new execution context is created for
            budget: the budget for the fork; this context's budget when omitted

        Returns:
            ChainContext: the forked context
        """
        return ChainContext(
            self._store,
            self._execution_context.new_for(monitored),
            self._name,
            budget if budget is not None else self._budget,
            more_itertools.flatten([self._previous_messages.messages, self._current_messages.messages]),
        )

    def should_stop(self) -> bool:
        """
        Returns whether the execution of the chain should be stopped.

        Returns:
            bool: `True` if either the budget is expired or the cancellation token is set, `False` otherwise
        """
        if self._budget.is_expired():
            logger.debug('message="stopping" reason="budget expired"')
            return True
        if self.cancellation_token.cancelled:
            logger.debug('message="stopping" reason="cancellation token is set"')
            return True
        return False

    def merge(self, contexts: List["ChainContext"]) -> None:
        """
        Merges the messages added to each of the given contexts into this context's current messages.

        Nothing is written to the store here: messages added with :meth:`append` on a context sharing
        this store (such as one from :meth:`fork_for`) were already recorded there.

        Args:
            contexts: the contexts to merge, in order
        """
        for context in contexts:
            self._current_messages.add_messages(context._current_messages.messages)

    def append(self, message: ChatMessage) -> None:
        """
        Adds the message to the context and records it for this chain in the store's current iteration.

        Args:
            message: the message to add
        """
        self._current_messages.add_message(message)
        self._store.current_iteration.append_to_chain(self._name, message, self._execution_context.entry)

    def extend(self, messages: Iterable[ChatMessage]) -> None:
        """
        Adds many messages to the context, one at a time via :meth:`append`.

        Args:
            messages: the messages to add
        """
        for message in messages:
            self.append(message)

    @staticmethod
    def from_chat_history(history: ChatHistory, budget: Optional[Budget] = None) -> "ChainContext":
        """
        Helper function that creates a new instance from a :class:`ChatHistory`.
        For test purposes only.

        Args:
            history: the chat history
            budget: the budget for the chain; `Budget.default()` when omitted

        Returns:
            ChainContext: a context for a chain named "mock chain"
        """
        from ..mocks import MockMonitored

        context = AgentContext.from_chat_history(history)
        context.new_iteration()
        return ChainContext.from_agent_context(context, MockMonitored("mock chain"), "mock chain", budget)

    @staticmethod
    def from_user_message(message: str, budget: Optional[Budget] = None) -> "ChainContext":
        """
        Creates a new instance from a user message. The :class:`ChatHistory` contains only the user message.

        Args:
            message: the user message
            budget: the budget for the chain; `Budget.default()` when omitted

        Returns:
            ChainContext: the new chain context
        """
        return ChainContext.from_chat_history(ChatHistory.from_user_message(message), budget)

    @staticmethod
    def empty() -> "ChainContext":
        """
        Helper function that creates a new instance with an empty :class:`ChatHistory`.
        For test purposes only.
        """
        return ChainContext.from_chat_history(ChatHistory())