Source code for council.prompt.llm_prompt_config_object

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence

import yaml
from council.utils import DataObject, DataObjectSpecBase


class LLMPromptTemplate:
    """A single prompt template bound to a specific model and/or a model family.

    A template is compatible with a model name when it equals `model` exactly,
    or when it starts with `model_family`.
    """

    def __init__(self, template: str, model: Optional[str], model_family: Optional[str]) -> None:
        """
        Args:
            template: The prompt template text.
            model: Exact model name this template targets, or None.
            model_family: Model-name prefix this template targets, or None.

        Raises:
            ValueError: if neither `model` nor `model_family` is given, or if both are
                given and `model` does not start with `model_family`.
        """
        self._template = template
        self._model = model
        self._model_family = model_family

        if self._model is None and self._model_family is None:
            raise ValueError("At least one of `model` or `model-family` must be defined")

        # When both are set they must agree: a template cannot target a model outside its own family.
        if self._model is not None and self._model_family is not None:
            if not self._model.startswith(self._model_family):
                raise ValueError(
                    f"model `{self._model}` and model-family `{self._model_family}` are not compliant. "
                    "Please use separate prompt templates"
                )

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMPromptTemplate:
        """Build a template from a mapping with key `template` plus `model` and/or `model-family`.

        Raises:
            ValueError: if `template` is missing, or the model constraints are invalid.
        """
        template = values.get("template")
        if template is None:
            raise ValueError("`template` must be defined")

        # Use `cls` so subclasses get instances of their own type.
        return cls(template, values.get("model"), values.get("model-family"))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the mapping shape accepted by `from_dict`; unset keys are omitted."""
        result: Dict[str, Any] = {"template": self._template}
        if self._model is not None:
            result["model"] = self._model
        if self._model_family is not None:
            result["model-family"] = self._model_family
        return result

    @property
    def template(self) -> str:
        """The prompt template text."""
        return self._template

    @property
    def model(self) -> Optional[str]:
        """Exact model name this template targets, if any."""
        return self._model

    @property
    def model_family(self) -> Optional[str]:
        """Model-name prefix this template targets, if any."""
        return self._model_family

    def is_compatible(self, model: str) -> bool:
        """Return True if `model` equals this template's model or starts with its model family."""
        if self._model is not None and self._model == model:
            return True

        if self._model_family is not None and model.startswith(self._model_family):
            return True
        return False


class LLMPromptConfigSpec(DataObjectSpecBase):
    """Spec holding the system prompt templates and the (optional) user prompt templates."""

    def __init__(self, system: Sequence[LLMPromptTemplate], user: Optional[Sequence[LLMPromptTemplate]] = None) -> None:
        """
        Args:
            system: System prompt templates.
            user: User prompt templates; None is normalized to an empty list.
        """
        self.system_prompts = list(system)
        self.user_prompts = list(user or [])

    @classmethod
    def from_dict(cls, values: Mapping[str, Any]) -> LLMPromptConfigSpec:
        """Build the spec from a mapping with a required `system` list and an optional `user` list.

        Raises:
            ValueError: if `system` is missing or empty, or any template entry is invalid.
        """
        system_prompts = values.get("system", [])
        user_prompts = values.get("user")
        if not system_prompts:
            raise ValueError("System prompt(s) must be defined")

        system = [LLMPromptTemplate.from_dict(p) for p in system_prompts]

        user: Optional[List[LLMPromptTemplate]] = None
        if user_prompts is not None:
            user = [LLMPromptTemplate.from_dict(p) for p in user_prompts]
        return cls(system, user)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the mapping shape read by `from_dict`; `user` is omitted when there are no user prompts."""
        result: Dict[str, Any] = {"system": [self._template_to_dict(p) for p in self.system_prompts]}
        if self.user_prompts:
            result["user"] = [self._template_to_dict(p) for p in self.user_prompts]
        return result

    @staticmethod
    def _template_to_dict(prompt: LLMPromptTemplate) -> Dict[str, Any]:
        """Serialize one template to the mapping shape read by `LLMPromptTemplate.from_dict`."""
        # Same-module access to the template's private fields; unset keys are omitted.
        result: Dict[str, Any] = {"template": prompt.template}
        if prompt._model is not None:
            result["model"] = prompt._model
        if prompt._model_family is not None:
            result["model-family"] = prompt._model_family
        return result

    def __str__(self) -> str:
        msg = f"{len(self.system_prompts)} system prompt(s)"
        # `user_prompts` is always a list, so test for emptiness rather than None.
        if self.user_prompts:
            msg += f"; {len(self.user_prompts)} user prompt(s)"
        return msg


class LLMPromptConfigObject(DataObject[LLMPromptConfigSpec]):
    """
    Helper class to instantiate a LLMPrompt from a YAML file
    """

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMPromptConfigObject:
        """Build the object from an already-parsed mapping, using `LLMPromptConfigSpec` for the spec."""
        return super()._from_dict(LLMPromptConfigSpec, values)

    @classmethod
    def from_yaml(cls, filename: str) -> LLMPromptConfigObject:
        """Load `filename` as YAML, check that its kind is `LLMPrompt`, and build the object."""
        # safe_load: the file is treated as data only; arbitrary Python object tags are rejected.
        with open(filename, "r", encoding="utf-8") as f:
            values = yaml.safe_load(f)
        cls._check_kind(values, "LLMPrompt")
        return cls.from_dict(values)

    @property
    def has_user_prompt_template(self) -> bool:
        """Return True, if user prompt template was specified in yaml file."""
        return bool(self.spec.user_prompts)

    def get_system_prompt_template(self, model: str) -> str:
        """Return system prompt template for a given model."""
        return self._get_prompt_template(self.spec.system_prompts, model)

    def get_user_prompt_template(self, model: str) -> str:
        """
        Return user prompt template for a given model.
        Raises ValueError if no user prompt template was provided.
        """
        if not self.has_user_prompt_template:
            raise ValueError("No user prompt template provided")
        return self._get_prompt_template(self.spec.user_prompts, model)

    @staticmethod
    def _get_prompt_template(prompts: List[LLMPromptTemplate], model: str) -> str:
        """
        Get the first prompt compatible to the given `model` (or `default` prompt).

        Args:
            prompts (List[LLMPromptTemplate]): List of prompts to search from
            model (str): Model name to match against each prompt's model / model-family

        Returns:
            str: prompt template

        Raises:
            ValueError: if both prompt template for a given model and default prompt template are not provided
        """
        match = next((prompt for prompt in prompts if prompt.is_compatible(model)), None)
        if match is None:
            # Fall back to a template declared for the `default` model (or model-family).
            match = next((prompt for prompt in prompts if prompt.is_compatible("default")), None)
        if match is None:
            # Raised outside any `except` block, so no StopIteration is chained onto the error.
            raise ValueError(f"No prompt template for a given model `{model}` nor a default one")
        return match.template