# Source code for council.llm.openai_llm

from __future__ import annotations

from typing import Any, Optional

import httpx
from httpx import HTTPStatusError, TimeoutException

from . import LLMCallException, LLMCallTimeoutException, OpenAIChatCompletionsModel, OpenAITokenCounter
from .llm_config_object import LLMConfigObject, LLMProviders
from .openai_chat_gpt_configuration import OpenAIChatGPTConfiguration


class OpenAIChatCompletionsModelProvider:
    """
    Posts chat completion requests to the API host given in the configuration.

    The host, API key and timeout are all read from the supplied
    configuration object; the key is sent as a ``Bearer`` token.
    """

    # Path of the chat completions endpoint, appended to the configured host.
    _CHAT_COMPLETIONS_PATH = "/v1/chat/completions"

    def __init__(self, config: OpenAIChatGPTConfiguration, name: Optional[str] = None) -> None:
        """
        Store the configuration and precompute the request headers.

        Args:
            config: configuration providing ``api_key``, ``api_host`` and ``timeout``.
            name: optional LLM name, passed as ``llm_name`` to raised exceptions.
        """
        self.config = config
        bearer = f"Bearer {config.api_key.unwrap()}"
        self._headers = {"Authorization": bearer, "Content-Type": "application/json"}
        self._name = name

    def _chat_completions_uri(self) -> str:
        """
        Build the chat completions URL from the configured host.

        Trailing slashes on the host are dropped so that a host such as
        ``https://example.com/`` does not yield ``//v1/chat/completions``.
        """
        host = self.config.api_host.unwrap().rstrip("/")
        return host + self._CHAT_COMPLETIONS_PATH

    def post_request(self, payload: dict[str, Any]) -> httpx.Response:
        """
        Posts a request to the OpenAI chat completions endpoint.

        Args:
            payload: JSON-serializable request body.

        Returns:
            The ``httpx.Response`` returned by the POST call.

        Raises:
            LLMCallTimeoutException: if httpx raises a ``TimeoutException``.
            LLMCallException: if httpx raises an ``HTTPStatusError``.
        """
        uri = self._chat_completions_uri()

        # The host is resolved above and the timeout here, on every call, so
        # changes made to the configuration after construction are honoured.
        timeout = self.config.timeout.unwrap()
        try:
            # A short-lived client per request; the context manager closes it.
            with httpx.Client(timeout=timeout) as client:
                return client.post(url=uri, headers=self._headers, json=payload)
        except TimeoutException as e:
            raise LLMCallTimeoutException(timeout=timeout, llm_name=self._name) from e
        except HTTPStatusError as e:
            # NOTE(review): raise_for_status() is never called here, so this
            # branch looks unreachable and error responses are presumably
            # returned to the caller as-is — confirm against the caller.
            raise LLMCallException(code=e.response.status_code, error=e.response.text, llm_name=self._name) from e


class OpenAILLM(OpenAIChatCompletionsModel):
    """
    Represents an OpenAI large language model hosted on OpenAI.
    """

    def __init__(self, config: OpenAIChatGPTConfiguration, name: Optional[str] = None):
        """
        Wire the model to a request provider and a token counter.

        Args:
            config: configuration used for the provider and, through
                ``config.model``, for selecting the token counter.
            name: optional LLM name; defaults to the class name when omitted or empty.
        """
        # Fall back to the concrete class name so subclasses get their own default.
        name = name or self.__class__.__name__
        super().__init__(
            config,
            OpenAIChatCompletionsModelProvider(config, name).post_request,
            token_counter=OpenAITokenCounter.from_model(config.model.unwrap_or("")),
            name=name,
        )

    @staticmethod
    def from_env(model: Optional[str] = None, api_host: Optional[str] = None) -> OpenAILLM:
        """
        Create an instance from ``OpenAIChatGPTConfiguration.from_env``.

        Args:
            model: optional model name forwarded to the configuration loader.
            api_host: optional API host forwarded to the configuration loader.

        Returns:
            A new ``OpenAILLM`` using the default name.
        """
        config: OpenAIChatGPTConfiguration = OpenAIChatGPTConfiguration.from_env(model=model, api_host=api_host)
        return OpenAILLM(config)

    @staticmethod
    def from_config(config_object: LLMConfigObject) -> OpenAILLM:
        """
        Create an instance from an ``LLMConfigObject``.

        Args:
            config_object: configuration whose ``spec`` describes the provider
                and whose ``metadata.name`` becomes the LLM name.

        Returns:
            A new ``OpenAILLM`` built from the spec.

        Raises:
            ValueError: if the spec's provider is not of kind ``LLMProviders.OpenAI``.
        """
        provider = config_object.spec.provider
        if not provider.is_of_kind(LLMProviders.OpenAI):
            raise ValueError(f"Invalid LLM provider, actual {provider}, expected {LLMProviders.OpenAI}")

        config = OpenAIChatGPTConfiguration.from_spec(config_object.spec)
        return OpenAILLM(config=config, name=config_object.metadata.name)