from __future__ import annotations
import json
import random
from collections import defaultdict
from typing import Any, Counter, DefaultDict, Dict, List, Mapping, Optional, Sequence
import yaml
from council.llm import LLMMessage, LLMMessageRole
from council.utils import DataObject, DataObjectSpecBase
class LLMDatasetMessage:
    """
    Represents a single chat message in a conversation.

    Attributes:
        role: Author of the message (e.g. user / assistant).
        content: Message text with leading/trailing whitespace stripped.
    """

    def __init__(self, role: LLMMessageRole, content: str):
        self.role = role
        # Normalize surrounding whitespace (e.g. the trailing newline of a YAML block scalar).
        self.content = content.strip()

    @classmethod
    def from_dict(cls, values: Dict[str, str]) -> LLMDatasetMessage:
        """
        Build a message from a ``{"role": ..., "content": ...}`` mapping.

        Raises:
            ValueError: if ``role`` or ``content`` is missing, or ``content`` is not a string.
                ``LLMMessageRole(role)`` is expected to reject unknown roles as well.
        """
        role = values.get("role")
        content = values.get("content")
        if role is None or content is None:
            raise ValueError("Both 'role' and 'content' must be defined for a message")
        if not isinstance(content, str):
            # e.g. `content: 42` in YAML; fail with a clear error instead of AttributeError on strip().
            raise ValueError(f"Message 'content' must be a string, got {type(content).__name__}")
        return cls(LLMMessageRole(role), content)

    @classmethod
    def from_llm_message(cls, message: LLMMessage) -> LLMDatasetMessage:
        """Build a dataset message from an ``LLMMessage``, keeping its role and content."""
        return cls(role=message.role, content=message.content)

    def to_dict(self) -> Dict[str, str]:
        """
        Serialize to a plain ``{"role": str, "content": str}`` dict.

        ``role`` is emitted as a plain string rather than the role object, so the result
        can be handed to any serializer (e.g. ``yaml.safe_dump``).
        """
        # NOTE(review): assumes LLMMessageRole is an Enum whose `.value` is the wire string
        # ("user", "assistant", ...); plain strings pass through unchanged.
        role = getattr(self.role, "value", self.role)
        return {"role": str(role), "content": self.content}
class LLMDatasetConversation:
    """
    Represents a conversation between user and assistant with optional labels.

    Labels are free-form key/value pairs attached to the conversation; they are
    serialized by ``to_dict`` but are not part of the message list.
    """

    def __init__(self, messages: Sequence[LLMDatasetMessage], labels: Optional[Mapping[str, str]] = None):
        # Copy inputs so later mutation by the caller does not affect this conversation.
        self.messages = list(messages)
        self.labels: Dict[str, str] = dict(labels) if labels is not None else {}

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMDatasetConversation:
        """
        Build a conversation from a ``{"messages": [...], "labels": {...}}`` mapping.

        Raises:
            ValueError: if ``messages`` is missing or empty, or any message is invalid.
        """
        messages = values.get("messages", [])
        if not messages:
            raise ValueError("Conversation must contain at least one message")
        parsed_messages = [LLMDatasetMessage.from_dict(message) for message in messages]
        return cls(parsed_messages, values.get("labels"))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict; the ``labels`` key is included only when labels are non-empty."""
        result: Dict[str, Any] = {"messages": [message.to_dict() for message in self.messages]}
        if self.labels:
            # Copy so callers mutating the result cannot alter this conversation's labels.
            result["labels"] = dict(self.labels)
        return result

    @staticmethod
    def get_message_pair(*, user: str, assistant: str) -> List[Dict[str, str]]:
        """Return a user/assistant pair as raw message dicts (the shape ``from_dict`` accepts)."""
        return [{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]
class LLMDatasetSpec(DataObjectSpecBase):
    """
    Spec section of an LLMDataset: the conversations plus an optional system prompt
    that applies to every conversation.
    """

    def __init__(self, conversations: List[LLMDatasetConversation], system_prompt: Optional[str] = None) -> None:
        self.conversations = conversations
        self.system_prompt: Optional[str] = None
        if system_prompt is not None:
            # Stored stripped, so serialization round-trips a normalized prompt.
            self.system_prompt = system_prompt.strip()

    @classmethod
    def from_dict(cls, values: Mapping[str, Any]) -> LLMDatasetSpec:
        """
        Parse the spec mapping.

        Raises:
            ValueError: if ``conversations`` is missing or empty (or a conversation is invalid).
        """
        raw_conversations = values.get("conversations", [])
        if not raw_conversations:
            raise ValueError("Dataset must contain at least one conversation")
        return LLMDatasetSpec(
            [LLMDatasetConversation.from_dict(raw) for raw in raw_conversations],
            values.get("system_prompt"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize back to a mapping; ``system_prompt`` is omitted when unset."""
        serialized = [conversation.to_dict() for conversation in self.conversations]
        payload: Dict[str, Any] = {"conversations": serialized}
        if self.system_prompt is not None:
            payload["system_prompt"] = self.system_prompt
        return payload

    def __str__(self):
        suffix = "" if self.system_prompt is None else " with system prompt"
        return f"{len(self.conversations)} conversation(s){suffix}"
class LLMDatasetObject(DataObject[LLMDatasetSpec]):
    """
    Helper class to instantiate a LLMDataset from a YAML file.

    LLMDataset represents a dataset to be used for fine-tuning / batch API or managing few shot examples.
    Contains a list of conversations between user and assistant and optional system prompt;
    if specified, it will be a system prompt for every conversation in the dataset.
    """

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMDatasetObject:
        """Build the object from an already-parsed mapping, delegating to ``DataObject._from_dict``."""
        return super()._from_dict(LLMDatasetSpec, values)

    @classmethod
    def from_yaml(cls, filename: str) -> LLMDatasetObject:
        """
        Load a dataset from a YAML file.

        The file is parsed with ``yaml.safe_load`` and ``_check_kind`` verifies that its kind is
        ``LLMDataset`` before the spec is parsed.
        """
        with open(filename, "r", encoding="utf-8") as f:
            values = yaml.safe_load(f)
        cls._check_kind(values, "LLMDataset")
        return cls.from_dict(values)

    @property
    def system_prompt(self) -> Optional[str]:
        """Return system prompt if any."""
        return self.spec.system_prompt

    @property
    def conversations(self) -> List[LLMDatasetConversation]:
        """Return all raw conversations in the dataset."""
        return self.spec.conversations

    def count_labels(self) -> DefaultDict[str, Counter]:
        """
        Count occurrences of each label value grouped by label key.

        Returns:
            A dictionary where keys are label names and values are Counters of label values.
            Conversations without labels contribute nothing.
        """
        label_counters: DefaultDict[str, Counter] = defaultdict(Counter)
        for conversation in self.conversations:
            if conversation.labels:
                for label_key, label_value in conversation.labels.items():
                    label_counters[label_key][label_value] += 1
        return label_counters

    def to_jsonl_messages(self) -> List[Dict[str, List[Dict[str, str]]]]:
        """
        Convert the dataset to JSONL format with OpenAI messages structure.

        Each conversation becomes one ``{"messages": [...]}`` entry; when a system prompt is set,
        it is prepended as the first message of every entry.

        Returns:
            A list of dictionaries containing messages, one per conversation.
        """
        jsonl_lines: List[Dict[str, List[Dict[str, str]]]] = []
        for conversation in self.conversations:
            messages: List[Dict[str, str]] = []
            if self.system_prompt is not None:
                # Fresh dict per entry so mutating one line's system message cannot affect the others.
                messages.append({"role": "system", "content": self.system_prompt})
            messages.extend(msg.to_dict() for msg in conversation.messages)
            jsonl_lines.append({"messages": messages})
        return jsonl_lines

    def save_jsonl_messages(
        self, path: str, random_seed: Optional[int] = None, val_split: Optional[float] = None
    ) -> None:
        """
        Save the dataset as JSONL messages file(s), optionally splitting into training and validation sets.

        JSONL file then can be used for fine-tuning.
        See https://platform.openai.com/docs/guides/fine-tuning.

        Args:
            path: Base path for saving the file(s); a trailing ``.jsonl`` is optional.
            random_seed: If provided, will be used to shuffle dataset before saving (default: None).
                The global ``random`` state is left untouched.
            val_split: If provided, fraction (0 to 1) of data to use for validation and create separate
                files for train and val. If None, saves all data to a single file (default: None).

        Raises:
            ValueError: if ``val_split`` is not within ``[0, 1]``.

        Examples:
            # Save all data into a single `my_dataset.jsonl` file
            dataset.save_jsonl_messages("my_dataset.jsonl")  # Creates my_dataset.jsonl

            # Split into train/val sets (80/20 split) and saves into `my_dataset_train.jsonl` and `my_dataset_val.jsonl`
            dataset.save_jsonl_messages("my_dataset.jsonl", random_seed=42, val_split=0.2)
        """
        # Values outside [0, 1] (or NaN) would yield a negative/oversized split index and silently
        # produce nonsense train/val files.
        if val_split is not None and not 0.0 <= val_split <= 1.0:
            raise ValueError(f"val_split must be between 0 and 1, got {val_split}")

        jsonl_lines = self.to_jsonl_messages()
        if random_seed is not None:
            # A dedicated generator yields the same permutation as seeding the global RNG,
            # without clobbering the caller's global `random` state.
            random.Random(random_seed).shuffle(jsonl_lines)

        suffix = ".jsonl"
        base_path = path[: -len(suffix)] if path.endswith(suffix) else path
        if val_split is None:
            self._save_jsonl(f"{base_path}{suffix}", jsonl_lines)
            return

        split_index = int(len(jsonl_lines) * (1 - val_split))
        self._save_jsonl(f"{base_path}_train{suffix}", jsonl_lines[:split_index])
        self._save_jsonl(f"{base_path}_val{suffix}", jsonl_lines[split_index:])

    def save_jsonl_requests(self, path: str, model: str, url: str = "/v1/chat/completions") -> None:
        """
        Save the dataset as JSONL request file, which can be used for batch API.

        See https://platform.openai.com/docs/guides/batch.
        Each request gets ``custom_id`` ``request-<index>`` following the dataset order.

        Args:
            path: Path to the output file
            model: OpenAI model name
            url: OpenAI API URL (default: "/v1/chat/completions")

        Examples:
            dataset.save_jsonl_requests("my_batch.jsonl", "gpt-4o-mini")
        """
        messages_lines = self.to_jsonl_messages()
        request_lines = [
            {
                "custom_id": f"request-{i}",
                "method": "POST",
                "url": url,
                "body": {"model": model, "messages": message_line["messages"]},
            }
            for i, message_line in enumerate(messages_lines)
        ]
        self._save_jsonl(path, request_lines)

    @staticmethod
    def _save_jsonl(filename: str, lines: List[Dict[str, Any]]) -> None:
        """Helper method to save lines to JSONL file (one JSON object per line, UTF-8)."""
        with open(filename, "w", encoding="utf-8") as f:
            for line in lines:
                f.write(json.dumps(line) + "\n")

    @staticmethod
    def read_jsonl(path: str) -> List[Dict[str, Any]]:
        """Helper method to read JSONL file into list of dictionaries; blank lines are skipped."""
        data = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    data.append(json.loads(line))
        return data
class LLMDatasetValidationException(Exception):
    """Exception raised for validation errors in LLMDatasetObject."""


class LLMDatasetValidator:
    """
    Helper class to validate the content of LLMDatasetObject.
    """

    @staticmethod
    def validate_for_batch_api(dataset: LLMDatasetObject) -> None:
        """
        Validate dataset for batch API.

        Every conversation must be non-empty and end with a user message.
        Prints a confirmation when all conversations pass.

        Raises:
            LLMDatasetValidationException: if a conversation is empty or does not end with a user message.
        """
        for idx, conversation in enumerate(dataset.conversations, start=1):
            prefix = f"Conversation #{idx}:"
            if not conversation.messages:
                # The conversation constructor accepts an empty list; report it instead of
                # letting messages[-1] raise a bare IndexError.
                raise LLMDatasetValidationException(f"{prefix} must contain at least one message")
            if conversation.messages[-1].role != "user":
                raise LLMDatasetValidationException(f"{prefix} must end with a user message")
        print("All conversations end with a user message.")

    @staticmethod
    def validate_for_fine_tuning(dataset: LLMDatasetObject) -> None:
        """
        Validate dataset for fine-tuning.

        Every conversation must be non-empty and follow the pattern
        user -> assistant -> user -> assistant -> ...
        Prints a confirmation when all conversations pass.

        Raises:
            LLMDatasetValidationException: if a conversation is empty, has an odd number of messages,
                or breaks the alternating user/assistant pattern. Message numbers in the error
                are 1-based, like conversation numbers.
        """
        for idx, conversation in enumerate(dataset.conversations, start=1):
            prefix = f"Conversation #{idx}:"
            messages = conversation.messages
            if not messages:
                # An empty conversation would otherwise pass both checks below vacuously.
                raise LLMDatasetValidationException(f"{prefix} There must be at least one message")
            if len(messages) % 2 != 0:
                raise LLMDatasetValidationException(f"{prefix} There must be an even number of messages")
            for i in range(0, len(messages), 2):
                # i is the 0-based index of the user turn; report positions 1-based.
                if messages[i].role != "user":
                    raise LLMDatasetValidationException(f"{prefix} Message #{i + 1} must be a user message")
                if messages[i + 1].role != "assistant":
                    raise LLMDatasetValidationException(f"{prefix} Message #{i + 2} must be an assistant message")
        print("All conversations have an even number of messages with alternating user/assistant roles.")