faiss_rag_enterprise/llama_index/callbacks/finetuning_handler.py

import json
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from llama_index.callbacks.base import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload


class BaseFinetuningHandler(BaseCallbackHandler):
    """
    Callback handler for fine-tuning.

    This handler collects all messages sent to the LLM, along with their
    responses. It also defines a `get_finetuning_events` method as well as
    a `save_finetuning_events` method.
    """

    def __init__(self) -> None:
        """Initialize the base callback handler."""
        super().__init__(
            event_starts_to_ignore=[],
            event_ends_to_ignore=[],
        )
        self._finetuning_events: Dict[str, List[Any]] = {}
        self._function_calls: Dict[str, List[Any]] = {}

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return the id of the event."""
        from llama_index.core.llms.types import ChatMessage, MessageRole

        if event_type == CBEventType.LLM:
            cur_messages = []
            if payload and EventPayload.PROMPT in payload:
                # completion-style call: wrap the raw prompt as a user message
                message = ChatMessage(
                    role=MessageRole.USER, content=str(payload[EventPayload.PROMPT])
                )
                cur_messages = [message]
            elif payload and EventPayload.MESSAGES in payload:
                cur_messages = payload[EventPayload.MESSAGES]

            if len(cur_messages) > 0:
                if event_id in self._finetuning_events:
                    self._finetuning_events[event_id].extend(cur_messages)
                else:
                    self._finetuning_events[event_id] = cur_messages

            # if function definitions were passed to the LLM, record them too
            if payload and EventPayload.ADDITIONAL_KWARGS in payload:
                kwargs_dict = payload[EventPayload.ADDITIONAL_KWARGS]
                if "functions" in kwargs_dict:
                    self._function_calls[event_id] = kwargs_dict["functions"]

        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        from llama_index.core.llms.types import ChatMessage, MessageRole

        if (
            event_type == CBEventType.LLM
            and event_id in self._finetuning_events
            and payload is not None
        ):
            # string responses come from completion calls; otherwise extract
            # the message from the chat response object
            if isinstance(payload[EventPayload.RESPONSE], str):
                response = ChatMessage(
                    role=MessageRole.ASSISTANT,
                    content=str(payload[EventPayload.RESPONSE]),
                )
            else:
                response = payload[EventPayload.RESPONSE].message
            self._finetuning_events[event_id].append(response)

    @abstractmethod
    def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
        """Get finetuning events."""

    @abstractmethod
    def save_finetuning_events(self, path: str) -> None:
        """Save the finetuning events to a file."""

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""


class OpenAIFineTuningHandler(BaseFinetuningHandler):
    """
    Callback handler for OpenAI fine-tuning.

    This handler will collect all messages
    sent to the LLM, along with their responses. It will then save these messages
    in a `.jsonl` format that can be used for fine-tuning with OpenAI's API.
    """

    def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
        events_dict = {}
        for event_id, event in self._finetuning_events.items():
            events_dict[event_id] = {"messages": event[:-1], "response": event[-1]}
        return events_dict

    def save_finetuning_events(self, path: str) -> None:
        """
        Save the finetuning events to a file.

        This saved format can be used for fine-tuning with OpenAI's API.
        The structure for each JSON line is as follows:

        {
            "messages": [
                {"role": "system", "content": "Text"},
                {"role": "user", "content": "Text"},
            ]
        },
        ...
        """
        from llama_index.llms.openai_utils import to_openai_message_dicts

        events_dict = self.get_finetuning_events()
        json_strs = []
        for event_id, event in events_dict.items():
            all_messages = event["messages"] + [event["response"]]
            message_dicts = to_openai_message_dicts(all_messages, drop_none=True)
            event_dict = {"messages": message_dicts}

            # include the function definitions if the request carried any
            if event_id in self._function_calls:
                event_dict["functions"] = self._function_calls[event_id]

            json_strs.append(json.dumps(event_dict))

        with open(path, "w") as f:
            f.write("\n".join(json_strs))
        print(f"Wrote {len(json_strs)} examples to {path}")

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""


class GradientAIFineTuningHandler(BaseFinetuningHandler):
    """
    Callback handler for Gradient AI fine-tuning.

    This handler will collect all messages
    sent to the LLM, along with their responses. It will then save these messages
    in a `.jsonl` format that can be used for fine-tuning with Gradient AI's API.
    """

    def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
        events_dict = {}
        for event_id, event in self._finetuning_events.items():
            events_dict[event_id] = {"messages": event[:-1], "response": event[-1]}
        return events_dict

    def save_finetuning_events(self, path: str) -> None:
        """
        Save the finetuning events to a file.

        This saved format can be used for fine-tuning with Gradient AI's API.
        The structure for each JSON line is as follows:

        {
            "inputs": "<full_prompt_str>"
        },
        ...
        """
        from llama_index.llms.generic_utils import messages_to_history_str

        events_dict = self.get_finetuning_events()
        json_strs = []
        for event in events_dict.values():
            all_messages = event["messages"] + [event["response"]]

            # TODO: come up with model-specific message->prompt serialization format
            prompt_str = messages_to_history_str(all_messages)
            input_dict = {"inputs": prompt_str}
            json_strs.append(json.dumps(input_dict))

        with open(path, "w") as f:
            f.write("\n".join(json_strs))
        print(f"Wrote {len(json_strs)} examples to {path}")

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""