# Source code for council.llm.anthropic_llm

from __future__ import annotations

from typing import Any, Sequence

from anthropic import Anthropic, APITimeoutError, APIStatusError
from anthropic._types import NOT_GIVEN

from council.contexts import LLMContext, Consumption
from council.llm import (
    LLMBase,
    LLMMessage,
    LLMMessageRole,
    LLMResult,
    LLMCallTimeoutException,
    LLMCallException,
    AnthropicLLMConfiguration,
    LLMException,
    LLMessageTokenCounterBase,
)

# Turn markers read from attributes of the Anthropic client class.
# AnthropicLLM._to_anthropic_messages prefixes each turn of the flattened text
# prompt with one of them, and ends the prompt with the assistant marker.
_HUMAN_TURN = Anthropic.HUMAN_PROMPT
_ASSISTANT_TURN = Anthropic.AI_PROMPT


class AnthropicTokenCounter(LLMessageTokenCounterBase):
    """Token counter that delegates to an Anthropic client's `count_tokens`."""

    def __init__(self, client: Anthropic):
        """
        Initialize a new instance.

        Args:
            client(Anthropic): client whose `count_tokens` method is used for counting
        """
        self._client = client

    def count_messages_token(self, messages: Sequence[LLMMessage]) -> int:
        """
        Return the sum of the token counts of each message's content.

        Args:
            messages(Sequence[LLMMessage]): messages to count; only `content` is read

        Returns:
            int: total over all messages, 0 when `messages` is empty
        """
        count_tokens = self._client.count_tokens
        return sum(count_tokens(message.content) for message in messages)


class AnthropicLLM(LLMBase):
    """
    LLM implementation backed by the Anthropic completions endpoint.

    Notes:
        More details: https://docs.anthropic.com/claude/docs
        and https://docs.anthropic.com/claude/reference/complete_post
    """

    def __init__(self, config: AnthropicLLMConfiguration):
        """
        Initialize a new instance.

        Args:
            config(AnthropicLLMConfiguration): configuration for the instance
        """
        super().__init__()
        self._config = config
        # The client is created with max_retries=0, i.e. no retries are requested from it.
        self._client = Anthropic(api_key=config.api_key.value, max_retries=0)

    def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
        """
        Send the messages as a single text prompt and wrap the completion in an LLMResult.

        Args:
            context(LLMContext): accepted for interface compatibility; not read here
            messages(Sequence[LLMMessage]): conversation to flatten into the prompt
            **kwargs: accepted for interface compatibility; not read here

        Returns:
            LLMResult: one choice (the completion text) plus the consumptions
            computed by `to_consumptions`

        Raises:
            LLMException: when `messages` is empty
            LLMCallTimeoutException: when the client raises APITimeoutError
            LLMCallException: when the client raises APIStatusError
        """
        cfg = self._config
        prompt_text = self._to_anthropic_messages(messages)
        try:
            # Optional sampling settings fall back to NOT_GIVEN when unset.
            request = {
                "prompt": prompt_text,
                "model": cfg.model.unwrap(),
                "max_tokens_to_sample": cfg.max_tokens.unwrap(),
                "timeout": cfg.timeout.value,
                "temperature": cfg.temperature.unwrap_or(NOT_GIVEN),
                "top_k": cfg.top_k.unwrap_or(NOT_GIVEN),
                "top_p": cfg.top_p.unwrap_or(NOT_GIVEN),
            }
            answer = self._client.completions.create(**request).completion
            usage = self.to_consumptions(prompt_text, answer)
            return LLMResult(choices=[answer], consumptions=usage)
        except APITimeoutError as timeout_error:
            raise LLMCallTimeoutException(cfg.timeout.value) from timeout_error
        except APIStatusError as status_error:
            raise LLMCallException(code=status_error.status_code, error=status_error.message) from status_error

    def to_consumptions(self, prompt: str, response: str) -> Sequence[Consumption]:
        """
        Build the consumption records for one call.

        Args:
            prompt(str): text sent to the model
            response(str): text returned by the model

        Returns:
            Sequence[Consumption]: four records, in order: one "call" unit, then
            the prompt, completion and total token counts, each keyed by the
            configured model name
        """
        model_name = self._config.model.unwrap()
        count_tokens = self._client.count_tokens
        prompt_count = count_tokens(prompt)
        completion_count = count_tokens(response)
        total_count = prompt_count + completion_count

        call_record = Consumption(1, "call", f"{model_name}")
        prompt_record = Consumption(prompt_count, "token", f"{model_name}:prompt_tokens")
        completion_record = Consumption(completion_count, "token", f"{model_name}:completion_tokens")
        total_record = Consumption(total_count, "token", f"{model_name}:total_tokens")
        return [call_record, prompt_record, completion_record, total_record]

    @staticmethod
    def _to_anthropic_messages(messages: Sequence[LLMMessage]) -> str:
        """
        Flatten the messages into one prompt string made of prefixed turns.

        The first message always opens a human turn. When it has the System role
        and another message follows, the two are merged into that opening turn,
        separated by a newline. Every later message gets the human prefix if it
        has the User role and the assistant prefix otherwise. The prompt ends
        with a bare assistant prefix.

        Args:
            messages(Sequence[LLMMessage]): messages to convert

        Returns:
            str: the concatenated prompt

        Raises:
            LLMException: when `messages` is empty
        """
        total = len(messages)
        if total == 0:
            raise LLMException("No message to process.")

        head = messages[0]
        if head.is_of_role(LLMMessageRole.System) and total > 1:
            opening = f"{_HUMAN_TURN} {head.content}\n{messages[1].content}"
            consumed = 2
        else:
            opening = f"{_HUMAN_TURN} {head.content}"
            consumed = 1

        turns = [opening]
        for message in messages[consumed:]:
            if message.is_of_role(LLMMessageRole.User):
                marker = _HUMAN_TURN
            else:
                marker = _ASSISTANT_TURN
            turns.append(f"{marker} {message.content}")
        turns.append(_ASSISTANT_TURN)

        return "".join(turns)

    @staticmethod
    def from_env() -> AnthropicLLM:
        """
        Helper function that creates a new instance from the configuration
        returned by `AnthropicLLMConfiguration.from_env()`.

        Returns:
            AnthropicLLM
        """
        configuration = AnthropicLLMConfiguration.from_env()
        return AnthropicLLM(configuration)