from typing import Any, Sequence

from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
)
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.llms.llm import LLM


class CustomLLM(LLM):
    """Simple abstract base class for custom LLMs.

    Subclasses must implement the `__init__`, `complete`,
    `stream_complete`, and `metadata` methods.
    """

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response_gen = self.stream_complete(
            prompt, formatted=True, **kwargs
        )
        return stream_completion_response_to_chat_response(completion_response_gen)

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        return self.chat(messages, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        async def gen() -> ChatResponseAsyncGen:
            for message in self.stream_chat(messages, **kwargs):
                yield message

        # NOTE: convert generator to async generator
        return gen()

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return self.complete(prompt, formatted=formatted, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        async def gen() -> CompletionResponseAsyncGen:
            for message in self.stream_complete(prompt, formatted=formatted, **kwargs):
                yield message

        # NOTE: convert generator to async generator
        return gen()

    @classmethod
    def class_name(cls) -> str:
        return "custom_llm"
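

# ---------------------------------------------------------------------------
# A minimal sketch (not part of the original module) of a `CustomLLM`
# subclass, showing the members the docstring above asks subclasses to
# provide. It assumes `LLMMetadata` and `CompletionResponseGen` are exposed
# by the same `llama_index.core.llms.types` module imported at the top of
# this file; the import sits here only to keep the sketch self-contained.
from llama_index.core.llms.types import CompletionResponseGen, LLMMetadata


class EchoLLM(CustomLLM):
    """Toy LLM that echoes the prompt back, token by token."""

    @property
    def metadata(self) -> LLMMetadata:
        # Minimal metadata; real integrations would also set context_window,
        # num_output, is_chat_model, and similar fields.
        return LLMMetadata(model_name="echo")

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return CompletionResponse(text=prompt)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        text = ""
        for token in prompt.split():
            text += token + " "
            yield CompletionResponse(text=text, delta=token + " ")


# The inherited `chat`, `stream_chat`, and async variants defined above are
# driven entirely by these two overrides, e.g.:
#     EchoLLM().complete("hello world").text  # -> "hello world"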