Streaming Inference Models
To mimic human behavior during user interaction, it felt right to ignore compute costs entirely and go all in on what I know best, even if that meant building a Python package to support a single inference type.
The Python script below is an example of my ideal data class for collecting data from chat streams. It runs a background transformer model, a BART zero-shot classifier, to infer my emotions and linguistic features, and it also supports Tom's speech linguistics or serves as the baseline for his chat actions.
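Under the hood that classification is a single Hugging Face zero-shot call. A minimal sketch, assuming the facebook/bart-large-mnli checkpoint (any NLI-tuned BART works) and illustrative labels rather than the project's actual label set:

# Minimal zero-shot sketch; checkpoint and labels here are assumptions for illustration.
from transformers import pipeline

clf = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = clf(
    "Why does this keep crashing on me?",
    candidate_labels=["questioning", "venting", "farewell"],
)
print(result["labels"][0], round(result["scores"][0], 3))  # top label and its score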
An extension to this would be building an MLOps workflow over the collected data and Tom's policy net to support Deep Q-learning, integrating conditional contexts into the text generated by Tom's base model (i.e., the base or initial LLM, such as Meta Llama or Gemini 3n Multimodal).
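As a rough shape of that extension, here is a hypothetical sketch (TomPolicyNet, ACTIONS, and the batch layout are placeholders, not project code) of a tiny Q-network over flattened Observation features with a standard DQN-style update; the real workflow would sit behind an MLOps pipeline fed by the parquet logs that ChatSession writes. The full collection module follows after the sketch.

""" sketch: a hypothetical dqn_policy.py, not part of utils/chat.py """
import torch
import torch.nn as nn

# Hypothetical discrete action set for Tom's chat policy.
ACTIONS = ["ask_followup", "empathize", "inform", "wrap_up"]

class TomPolicyNet(nn.Module):
    """Tiny Q-network: flattened Observation features -> one Q-value per action."""
    def __init__(self, state_dim: int, n_actions: int = len(ACTIONS)):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64), nn.ReLU(),
            nn.Linear(64, n_actions),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)

def dqn_step(policy, target, batch, optimizer, gamma: float = 0.99) -> float:
    """One Deep Q-learning update: fit Q(s, a) toward r + gamma * max_a' Q_target(s', a')."""
    states, actions, rewards, next_states = batch  # tensors: [B, D], [B] int64, [B], [B, D]
    q_sa = policy(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        q_next = target(next_states).max(dim=1).values
    loss = nn.functional.mse_loss(q_sa, rewards + gamma * q_next)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()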
""" utils/chat.py """
from __future__ import annotations
import os
from uuid import uuid4
from pathlib import Path
from typing import Dict, List
import numpy as np
import pandas as pd
from datetime import datetime
from dataclasses import dataclass, field
from transformers import pipeline
call_file = Path(__file__).resolve().parent.parent
if call_file != Path('.').absolute():
import sys
sys.path.append(str(call_file))
from utils._config import Constants
CHAT_CONST = Constants.load_constants()
SPEECH_INTENT = CHAT_CONST['intent']
QUESTION_TYPES = CHAT_CONST['question_types']
EMOTION_LABEL_LEADS = CHAT_CONST['emotions']['lead_labels']
EMOTION_LABEL_FULL = CHAT_CONST['emotions']['full_labels']
FLAGGER = CHAT_CONST["flagger"]
class SessionHelper:
"""
A helper class for classifying speech and emotion in chat messages using a zero-shot classification model.
This class loads a transformer-based classification pipeline (e.g., Hugging Face zero-shot model)
and applies it to analyze intents, question types, emotional expressions, and other semantic features
within a chat message.
Attributes:
clf_id (str): The Hugging Face model ID loaded from environment variable `CLF_ID`.
classifier (transformers.pipeline): A Hugging Face zero-shot classification pipeline.
"""
def __init__(self):
"""
Initializes the SessionHelper by loading the classification pipeline from the environment variable CLF_ID.
Raises:
OSError: If CLF_ID is not found in environment variables.
"""
self.clf_id = os.getenv("CLF_ID", None)
if not self.clf_id:
raise OSError("No classification model stored in environment.")
else:
self.classifier = pipeline("zero-shot-classification", model=self.clf_id)
def classify(self, inputs: str, candidate_labels: list):
"""
Performs zero-shot classification on the input text.
Args:
inputs (str): The input message text to classify.
candidate_labels (list): A list of label strings to classify the input against.
Returns:
tuple[dict, str]: A tuple containing:
- A dictionary mapping candidate labels to their probability scores.
- The label with the highest score in lowercase.
Raises:
ValueError: If the classification result is not as expected.
"""
try:
            # NOTE: with multi_label=True (and no normalization), the pipeline does not sort labels by score, so labels and scores come back unordered; the default single-label call below returns them sorted highest to lowest
result = self.classifier(inputs, candidate_labels)
if not isinstance(result, dict):
raise ValueError("Unexpected result format from classifier (not a dict).")
if set(result.keys()) != {"labels", "scores", "sequence"}:
raise ValueError(f"Result missing expected keys. Got: {set(result.keys())}")
if result["scores"] != sorted(result["scores"], reverse=True):
raise ValueError("Classifier scores are not sorted in descending order.")
probs = dict(zip(result['labels'], result['scores']))
return probs, result["labels"][0].lower()
except Exception as e:
raise e
def __call__(self, message: ChatMessage):
"""
Classifies and annotates a ChatMessage instance with speech and emotional features.
Args:
message (ChatMessage): The message object to annotate with classification data.
Returns:
ChatMessage: The updated message object containing classification metadata.
Raises:
ValueError: If classification fails or yields unexpected outputs.
"""
message.intent_prob, message.intent = self.classify(message.content, SPEECH_INTENT)
print(f'Intent: {message.intent}')
if message.intent == 'questioning':
# Checks for the type of speech
message.question = True
message.question_type_prob, message.question_type = self.classify(message.content, QUESTION_TYPES)
if message.question_type is None:
raise ValueError(f'Problem classifying question type: {message}')
if message.question_type == "question about human":
_, entity_detect = self.classify(message.content, Constants.ENTITY.value)
if entity_detect == 'others':
message.flag['unknown_entity_mentioned'] += 1
elif message.intent == 'farewell':
message.stop = True
else:
if message.intent not in SPEECH_INTENT:
raise ValueError(f'What the fuck kind of msg intent did the classifier output: {message.intent}')
# Emotional question
message.emote_prob, message.emote = self.classify(message.content, EMOTION_LABEL_LEADS)
print(f'Emotion: {message.emote}')
        if message.emote in ('emotional', 'expressive'):
message.emote_parent_prob, message.emote_parent = self.classify(message.content, list(EMOTION_LABEL_FULL.keys()))
if message.emote_parent is None:
raise ValueError(f'Problem classifying parent emote: {message}')
child_labels = EMOTION_LABEL_FULL[message.emote_parent]
message.emote_child_prob, message.emote_child = self.classify(message.content, child_labels)
print(f'Parent Emotion: {message.emote_parent}')
print(f'Child Emote: {message.emote_child}')
return message
@staticmethod
def create_flagger():
"""
Creates a flagger dictionary from a global FLAGGER list.
Returns:
dict: A dictionary with flag names as keys and 0 as default value.
"""
return {item['flag_name']: 0 for item in FLAGGER}
@dataclass
class ChatMessage:
"""
Represents a single message exchanged in a chat session, along with enriched features.
This class captures metadata about the message including its role (e.g., user, assistant),
content, timestamp, and various annotations for intent and emotion that can be used in
downstream conversational analytics or agent modeling.
Attributes:
role (str): The role of the speaker (e.g., 'user', 'assistant').
content (str): The textual content of the message.
timestamp (str): ISO-formatted timestamp of when the message was created. Defaults to current time.
stop (bool): Indicator if this message signals a stop condition. Defaults to False.
intent (Optional[str]): Detected intent category for the message.
intent_prob (Optional[Dict[str, float]]): Probabilities across various intent types.
question (bool): Indicates whether the message is a question.
question_type (Optional[str]): Classified question type (e.g., 'why', 'how').
question_type_prob (Optional[Dict[str, float]]): Probabilities across question types.
emote (str): Primary emotion detected in the message (default is 'neutral').
emote_parent (Optional[str]): Parent-level emotion category.
emote_child (Optional[str]): Fine-grained emotion category.
emote_prob (Optional[Dict[str, float]]): Probabilities across emotion labels.
emote_parent_prob (Optional[Dict[str, float]]): Probabilities for parent emotion categories.
emote_child_prob (Optional[Dict[str, float]]): Probabilities for child emotion categories.
flag (dict): Dictionary of flags or markers associated with this message, created via `create_flagger`.
"""
role: str
content: str
    # default_factory so each message gets its own creation time (a bare default is evaluated once at import)
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
# speech based features
stop: bool = field(default=False)
intent: str | None = field(default=None)
intent_prob: Dict[str, float] | None = field(default=None)
question: bool = field(default=False)
question_type: str | None = field(default=None)
question_type_prob: Dict[str, float] | None = field(default=None)
# human based features
emote: str = field(default='neutral')
emote_parent: str | None = field(default=None)
emote_child: str | None = field(default=None)
emote_prob: Dict[str, float] | None = field(default=None)
emote_parent_prob: Dict[str, float] | None = field(default=None)
emote_child_prob: Dict[str, float] | None = field(default=None)
# case examples
flag: dict = field(default_factory=SessionHelper.create_flagger)
def __post_init__(self):
"""Validates the role upon initialization to ensure it matches expected chat roles.
Raises:
ValueError: If the provided role is not within the defined chat roles.
"""
if self.role not in Constants.CHAT_ROLES.value:
raise ValueError(f"Invalid input role: {self.role}")
@dataclass
class ChatSession:
"""
Manages a session of chat messages and facilitates inference of speech and emotional features
for user messages.
Attributes:
_messages (List[ChatMessage]): A list of ChatMessage instances in the session.
session_id (str): A unique identifier for the chat session.
infer (SessionHelper): An instance of the SessionHelper to perform classification.
session_flagger (dict): A flag counter for detected session-level semantic features.
"""
_messages: List[ChatMessage] = field(default_factory=list, repr=False)
    # default_factory so every session gets a fresh id (a bare default would reuse one id for all sessions)
    session_id: str = field(default_factory=lambda: uuid4().hex)
infer: SessionHelper = field(default_factory=SessionHelper, repr=False)
session_flagger: dict = field(default_factory=SessionHelper.create_flagger, repr=False)
def __iter__(self):
"""
Iterator over the session's messages.
Yields:
dict: A dictionary containing 'role' and 'content' fields for each message.
"""
for item in self._messages:
yield {
'role': item.role,
'content': item.content
}
def __getitem__(self, index):
"""
Accesses a specific message from the session by index.
Args:
index (int): Index of the message.
Returns:
dict: A dictionary with 'role' and 'content' of the message.
Raises:
IndexError: If the index is out of bounds.
"""
if -len(self._messages) <= index < len(self._messages):
msg = self._messages[index]
return {
'role': msg.role,
'content': msg.content
}
raise IndexError
@property
def messages(self):
"""
Returns all messages in the session as a DataFrame.
Includes all inferred features and session metadata.
Returns:
pd.DataFrame: A DataFrame of chat messages with session ID.
"""
data = pd.DataFrame(self._messages)
data['session_id'] = self.session_id
return data
def add_message(self, role: str, content: str):
"""
Adds a message to the session and performs inference if it's a user message.
Args:
role (str): The role of the message sender (e.g., 'user', 'assistant').
content (str): The content of the message.
Raises:
Exception: If the message creation or classification fails.
"""
try:
chat_message = ChatMessage(role=role, content=content)
if role == 'user':
message = self.infer(chat_message)
print(message)
for key_flag, val in message.flag.items():
if val != 0:
self.session_flagger[key_flag] += 1
self._messages.append(message)
else:
self._messages.append(chat_message)
except Exception as e:
raise e
@property
def file_name(self):
"""
Returns the file path to save the current chat session as a Parquet file.
The path is relative to `memory/history`.
Returns:
Path: Path object to the `.parquet` file for this session.
Raises:
RuntimeError: If not run from the expected project root.
FileNotFoundError: If the expected directory doesn't exist.
"""
if Path('.').parent.absolute().stem != Constants.ROOT_DIR.value:
raise RuntimeError(
f"Expected to be run from inside '{Constants.ROOT_DIR.value}', but found '{Path('.').parent.absolute().stem}'."
)
dir_path = Path("memory/history")
# TODO: SWAP TO OLLAMA MODEL directory
if not dir_path.exists():
dir_path.mkdir(parents=True, exist_ok=True)
return dir_path / f"{self.session_id}.parquet"
@dataclass(slots=True)
class Observation:
"""
Represents the observation features extracted from a chat session.
This class holds the results of various analyses performed on a user's
messages, such as engagement, expressiveness, and intent classification.
"""
engage: float | None = field(
metadata={"description": "The user's engagement score, representing their activity level in the conversation. `None` if not available."}
)
express: float | None = field(
metadata={"description": "The user's expressiveness score, indicating how expressive their responses were. `None` if not available."}
)
intents: Dict[str, float] | None = field(
metadata={"description": "A dictionary of detected intents with their associated probabilities. `None` if no intents detected."}
)
class State:
"""
State data fetcher and helper for deriving user features from an agent's state space.
This class provides utility methods to compute behavioral features such as
engagement, expressiveness, and intents from chat logs or user dialogue data.
It also produces an `Observation` object summarizing user behavior.
Attributes:
emote_cols (set): Expected columns for emotion probability distributions.
"""
@staticmethod
def compute_engagement(user_df: pd.DataFrame) -> float | None:
"""
Computes the user's engagement level as the mean time between messages.
Args:
user_df (pd.DataFrame): Filtered dataframe containing only user messages.
Returns:
float | None: Mean time between user messages in seconds, or None if empty.
"""
return user_df['time_delta_sec'].mean() if not user_df.empty else None
@staticmethod
def compute_expressiveness(user_df: pd.DataFrame) -> float | None:
"""
Computes user expressiveness by analyzing the emotion probability distribution.
Args:
user_df (pd.DataFrame): Filtered dataframe with a column 'emote_prob',
where each row contains a dictionary of emotion probabilities.
Returns:
float | None: Mean expressiveness score, defined as (1 - mean(neutral)),
or None if no valid data.
Raises:
ValueError: If the structure of `emote_prob` is invalid or sums are inconsistent.
"""
        emote_df = pd.json_normalize(user_df['emote_prob'].dropna().tolist())
        if emote_df.empty:
            return None
        if set(emote_df.columns) != set(EMOTION_LABEL_LEADS):
raise ValueError(f"Emote prob table issue rendering: {emote_df}")
if not np.allclose(emote_df.sum(axis=1), 1.0, rtol=1e-3):
raise ValueError(f"Unexpected sum in chat's emote scores: {emote_df}")
return 1 - emote_df['neutral'].mean()
@staticmethod
def compute_intents(user_df: pd.DataFrame) -> dict:
"""
Computes the user's intent distribution from chat data.
Args:
user_df (pd.DataFrame): Filtered dataframe with a column 'intent_prob',
where each row contains a dictionary of intent probabilities.
Returns:
dict: A dictionary of average intent probabilities. Returns empty dict if none found.
"""
intent_df = pd.json_normalize(user_df['intent_prob'].dropna().tolist())
return intent_df.mean(axis=0).to_dict() if not intent_df.empty else {}
@staticmethod
def observe_space(dialogue: ChatSession) -> Observation:
"""
Processes a chat session to compute observational features related to user behavior.
Args:
dialogue (ChatSession): A chat session object containing dialogue messages
in a dataframe format with columns 'timestamp' and 'role'.
Returns:
Observation: An observation object with user engagement, expressiveness,
and intent probabilities.
Raises:
ValueError: If input data is malformed or contains invalid timestamps.
"""
try:
df = dialogue.messages.copy()
# prepares / process df
if df.empty or 'timestamp' not in df.columns or 'role' not in df.columns:
raise ValueError("Dialogue data is empty or missing required columns.")
df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
if df['timestamp'].isna().any():
raise ValueError("Invalid or missing timestamps detected.")
df['time_delta_sec'] = df['timestamp'].diff().dt.total_seconds()
user_df = df[df['role'] == 'user']
# fetches features
engage = State.compute_engagement(user_df)
express = State.compute_expressiveness(user_df)
intents = State.compute_intents(user_df)
return Observation(engage=engage, express=express, intents=intents)
except Exception as e:
raise ValueError(f"Failed to process dialogue features: {e}")
if __name__ == "__main__":
import yaml
sample_path = Path(__file__).resolve().parent.parent / 'tests/conftest.yaml'
save_path = Path(__file__).resolve().parent.parent / 'tests/short_chat_cat.csv'
with open(sample_path, 'r') as file:
data = yaml.safe_load(file)
chat = ChatSession(session_id="test_session_id")
if chat_msg := data.get("chat_example_2"):
for item in chat_msg:
if (role := item.get("role")) and (msg := item.get("content")):
chat.add_message(role, msg)
df = chat.messages
df.to_csv("tests/short_chat_cat.csv", index=False)