faiss_rag_enterprise/llama_index/llms/langchain.py

from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Sequence

if TYPE_CHECKING:
    from langchain.base_language import BaseLanguageModel

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.llms.llm import LLM
from llama_index.types import BaseOutputParser, PydanticProgramMode


class LangChainLLM(LLM):
    """Adapter for a LangChain LLM."""

    _llm: Any = PrivateAttr()

    def __init__(
        self,
        llm: "BaseLanguageModel",
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        self._llm = llm
        super().__init__(
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "LangChainLLM"

    @property
    def llm(self) -> "BaseLanguageModel":
        return self._llm

    @property
    def metadata(self) -> LLMMetadata:
        from llama_index.llms.langchain_utils import get_llm_metadata

        return get_llm_metadata(self._llm)

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        from llama_index.llms.langchain_utils import (
            from_lc_messages,
            to_lc_messages,
        )

        if not self.metadata.is_chat_model:
            # Completion-only LangChain LLMs: flatten the messages into a single
            # prompt, run complete(), and convert the result back to a chat response.
            prompt = self.messages_to_prompt(messages)
            completion_response = self.complete(prompt, formatted=True, **kwargs)
            return completion_response_to_chat_response(completion_response)

        lc_messages = to_lc_messages(messages)
        lc_message = self._llm.predict_messages(messages=lc_messages, **kwargs)
        message = from_lc_messages([lc_message])[0]
        return ChatResponse(message=message)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        output_str = self._llm.predict(prompt, **kwargs)
        return CompletionResponse(text=output_str)
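
    # Streaming works in one of two ways:
    #   1. If the wrapped LangChain LLM exposes `.stream()`, chunks are consumed
    #      directly and re-emitted as llama_index responses.
    #   2. Otherwise, a StreamingGeneratorCallbackHandler is attached to the LLM,
    #      the blocking call runs on a background thread, and tokens are drained
    #      from the handler's generator as they arrive.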

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        if not self.metadata.is_chat_model:
            prompt = self.messages_to_prompt(messages)
            stream_completion = self.stream_complete(prompt, formatted=True, **kwargs)
            return stream_completion_response_to_chat_response(stream_completion)

        if hasattr(self._llm, "stream"):

            def gen() -> Generator[ChatResponse, None, None]:
                from llama_index.llms.langchain_utils import (
                    from_lc_messages,
                    to_lc_messages,
                )

                lc_messages = to_lc_messages(messages)
                response_str = ""
                for message in self._llm.stream(lc_messages, **kwargs):
                    message = from_lc_messages([message])[0]
                    delta = message.content
                    response_str += delta
                    yield ChatResponse(
                        message=ChatMessage(role=message.role, content=response_str),
                        delta=delta,
                    )

            return gen()
        else:
            from llama_index.langchain_helpers.streaming import (
                StreamingGeneratorCallbackHandler,
            )

            handler = StreamingGeneratorCallbackHandler()

            if not hasattr(self._llm, "streaming"):
                raise ValueError("LLM must support streaming.")
            if not hasattr(self._llm, "callbacks"):
                raise ValueError("LLM must support callbacks to use streaming.")

            self._llm.callbacks = [handler]  # type: ignore
            self._llm.streaming = True  # type: ignore

            thread = Thread(target=self.chat, args=[messages], kwargs=kwargs)
            thread.start()

            response_gen = handler.get_response_gen()

            def gen() -> Generator[ChatResponse, None, None]:
                text = ""
                for delta in response_gen:
                    text += delta
                    yield ChatResponse(
                        # ChatMessage has no `text` field; build the message from the
                        # accumulated content with an assistant role.
                        message=ChatMessage(role=MessageRole.ASSISTANT, content=text),
                        delta=delta,
                    )

            return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        from llama_index.langchain_helpers.streaming import (
            StreamingGeneratorCallbackHandler,
        )

        handler = StreamingGeneratorCallbackHandler()

        if not hasattr(self._llm, "streaming"):
            raise ValueError("LLM must support streaming.")
        if not hasattr(self._llm, "callbacks"):
            raise ValueError("LLM must support callbacks to use streaming.")

        self._llm.callbacks = [handler]  # type: ignore
        self._llm.streaming = True  # type: ignore

        thread = Thread(target=self.complete, args=[prompt], kwargs=kwargs)
        thread.start()

        response_gen = handler.get_response_gen()

        def gen() -> Generator[CompletionResponse, None, None]:
            text = ""
            for delta in response_gen:
                text += delta
                yield CompletionResponse(delta=delta, text=text)

        return gen()
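
    # NOTE: the async variants below wrap the synchronous code paths (see the
    # TODOs), so they block the running event loop while LangChain executes.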

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        # TODO: Implement async chat
        return self.chat(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # TODO: Implement async complete
        return self.complete(prompt, formatted=formatted, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        # TODO: Implement async stream_chat
        async def gen() -> ChatResponseAsyncGen:
            for message in self.stream_chat(messages, **kwargs):
                yield message

        return gen()

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        # TODO: Implement async stream_complete
        async def gen() -> CompletionResponseAsyncGen:
            for response in self.stream_complete(prompt, formatted=formatted, **kwargs):
                yield response

        return gen()
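

# Illustrative usage sketch: assumes the `langchain` package with its OpenAI
# integration is installed and OPENAI_API_KEY is set; any LangChain
# BaseLanguageModel would work in place of ChatOpenAI.
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI

    # Wrap a LangChain chat model so it can be used wherever a llama_index
    # LLM is expected (e.g. service contexts, query engines).
    lc_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    llm = LangChainLLM(llm=lc_llm)

    # Chat-style call through the adapter.
    response = llm.chat([ChatMessage(role=MessageRole.USER, content="Say hello.")])
    print(response.message.content)

    # Completion-style call; the prompt is passed through completion_to_prompt.
    print(llm.complete("Write a one-line haiku about FAISS.").text)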