# Source code for council.llm.base.llm_config_object

from __future__ import annotations

import os
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

import yaml
from council.utils import DataObject, DataObjectSpecBase
from council.utils.parameter import Undefined


class LLMProviders(str, Enum):
    """Closed set of LLM providers.

    Every member is also a ``str``; its value is the provider's spec key
    name (for instance ``"openAISpec"``).
    """

    OpenAI = "openAISpec"
    Azure = "azureSpec"
    Anthropic = "anthropicSpec"
    Gemini = "googleGeminiSpec"
    Ollama = "ollamaSpec"
    Groq = "groqSpec"

    @staticmethod
    def all() -> List[LLMProviders]:
        """Return a new list holding every provider, in declaration order."""
        return [*LLMProviders.__members__.values()]


class LLMProvider:
    """A named provider entry plus the settings found under its spec key.

    Built from a mapping that holds ``name``, ``description`` and one
    provider spec key (one of the :class:`LLMProviders` values).
    """

    def __init__(self, name: str, description: str, specs: Dict[str, Any], kind: LLMProviders) -> None:
        """Store the provider identity and its raw settings.

        :param name: provider name; used in error messages and ``__str__``.
        :param description: free-text description.
        :param specs: settings read from under the provider spec key.
        :param kind: the provider these settings belong to.
        """
        self.name = name
        self.description = description
        self._specs = specs
        self._kind = kind

    @property
    def kind(self) -> LLMProviders:
        """The provider these settings belong to."""
        return self._kind

    def is_of_kind(self, kind: LLMProviders) -> bool:
        """Return True when this entry's provider equals ``kind``."""
        return self._kind == kind

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMProvider:
        """Build a provider from a mapping.

        ``name`` and ``description`` default to ``""``. Provider spec keys
        are probed in :class:`LLMProviders` declaration order and the first
        one holding a non-None value wins.

        :raises ValueError: if no known provider spec key has a value.
        """
        name = values.get("name", "")
        description = values.get("description", "")

        for provider in LLMProviders.all():
            spec = values.get(provider.value)
            if spec is not None:
                return cls(name, description, spec, provider)

        raise ValueError("Unsupported model provider")

    def to_dict(self) -> Dict[str, Any]:
        """Return the mapping form: name, description and the provider spec.

        The spec is stored under the provider's plain string key (not the
        enum member) so the result holds only built-in key types and can be
        handed to serialisers that reject enum instances.
        """
        result: Dict[str, Any] = {"name": self.name, "description": self.description}

        for provider in LLMProviders.all():
            if self.is_of_kind(provider):
                result[provider.value] = self._specs
                break

        return result

    def must_get_value(self, key: str) -> Any:
        """Return the value for ``key``; raise if it resolves to None."""
        return self.get_value(key=key, required=True)

    def get_value(self, key: str, required: bool = False, default: Optional[Any] = Undefined()) -> Optional[Any]:
        """Look up a setting, resolving environment-variable indirection.

        A setting may be a plain value, or a mapping with ``fromEnvVar``
        (name of an environment variable) and an optional ``default`` used
        when that variable is not set. A mapping without ``fromEnvVar`` is
        returned unchanged.

        :param key: setting name inside the provider spec.
        :param required: raise when the setting resolves to None.
        :param default: returned whenever the setting resolves to None,
            including when the key is present but neither the environment
            variable nor the mapping's own ``default`` supplies a value.
            When given, it takes precedence over ``required``.
        :raises Exception: if ``required`` is true, no ``default`` was given
            and the setting resolves to None.
        """
        has_default = not isinstance(default, Undefined)

        maybe_value = self._specs.get(key, None)
        if maybe_value is None and has_default:
            return default

        if isinstance(maybe_value, dict):
            env_var_name: Optional[str] = maybe_value.get("fromEnvVar", None)
            if env_var_name is not None:
                default_value: Optional[str] = maybe_value.get("default", None)
                maybe_value = os.environ.get(env_var_name, default_value)
                # The caller's default is the last fallback once the
                # environment variable and the spec-level default both miss.
                if maybe_value is None and has_default:
                    return default

        if maybe_value is None and required:
            raise Exception(f"LLMProvider {self.name} - A required key {key} is missing.")
        return maybe_value

    def __str__(self) -> str:
        return f"{self._kind}: {self.name} ({self.description})"


class LLMConfigSpec(DataObjectSpecBase):
    """Spec section of an LLM configuration: provider, optional fallback, parameters."""

    def __init__(
        self, description: str, provider: LLMProvider, fallback: Optional[LLMProvider], parameters: Dict[str, Any]
    ) -> None:
        """Store the spec fields.

        :param description: free-text description of the configuration.
        :param provider: primary provider.
        :param fallback: optional fallback provider, exposed as ``fallback_provider``.
        :param parameters: parameters mapping, stored as given.
        """
        self.description = description
        self.provider = provider
        self.parameters = parameters
        self.fallback_provider = fallback

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMConfigSpec:
        """Build a spec from a mapping.

        Reads ``description`` (default ``""``), ``parameters`` (default
        ``{}``), the mandatory ``provider`` and the optional
        ``fallbackProvider``.

        :raises ValueError: if ``provider`` is missing or None, or if a
            provider mapping names no supported provider.
        """
        # Validate up front: LLMProvider.from_dict never returns None, so a
        # check after the call could not detect a missing provider.
        provider_spec: Optional[Dict[str, Any]] = values.get("provider", None)
        if provider_spec is None:
            raise ValueError("provider needs to be defined.")

        description = values.get("description", "")
        parameters = values.get("parameters", {})
        fallback_spec: Optional[Dict[str, Any]] = values.get("fallbackProvider", None)
        fallback = LLMProvider.from_dict(fallback_spec) if fallback_spec is not None else None
        provider = LLMProvider.from_dict(provider_spec)

        return cls(description, provider, fallback, parameters)

    def to_dict(self) -> Dict[str, Any]:
        """Return the mapping form of the spec.

        Providers are emitted as nested mappings and the fallback uses the
        ``fallbackProvider`` key, i.e. the same layout ``from_dict`` reads.
        """
        result: Dict[str, Any] = {
            "description": self.description,
            "provider": self.provider.to_dict(),
            "parameters": self.parameters,
        }
        if self.fallback_provider is not None:
            result["fallbackProvider"] = self.fallback_provider.to_dict()
        return result

    def __str__(self) -> str:
        return f"{self.description}"

    def check_provider(self, provider: LLMProviders) -> None:
        """Raise ``ValueError`` unless the primary provider is of kind ``provider``."""
        if not self.provider.is_of_kind(provider):
            raise ValueError(f"Invalid LLM provider, actual {self.provider}, expected {provider}")


class LLMConfigObject(DataObject[LLMConfigSpec]):
    """
    Helper class to instantiate an LLM from a YAML file
    """

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMConfigObject:
        """Build the object from an already-parsed mapping.

        Delegates to the base class ``_from_dict`` with :class:`LLMConfigSpec`
        as the spec type.
        """
        return super()._from_dict(LLMConfigSpec, values)

    @classmethod
    def from_yaml(cls, filename: str) -> LLMConfigObject:
        """Load the object from a UTF-8 encoded YAML file.

        :param filename: path of the YAML document to read.
        """
        # Only the parse needs the open handle; close the file before validating.
        with open(filename, "r", encoding="utf-8") as f:
            values = yaml.safe_load(f)

        # NOTE(review): _check_kind is defined on the base class; presumably it
        # rejects documents whose kind is not "LLMConfig" - confirm.
        cls._check_kind(values, "LLMConfig")
        return cls.from_dict(values)