chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from litellm.llms.langgraph.chat.transformation import LangGraphConfig
|
||||
|
||||
__all__ = ["LangGraphConfig"]
|
||||
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
SSE Stream Iterator for LangGraph.
|
||||
|
||||
Handles Server-Sent Events (SSE) streaming responses from LangGraph.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class LangGraphSSEStreamIterator:
|
||||
"""
|
||||
Iterator for LangGraph SSE streaming responses.
|
||||
Supports both sync and async iteration.
|
||||
|
||||
LangGraph stream format with stream_mode="messages-tuple":
|
||||
Each SSE event is a tuple: (event_type, data)
|
||||
Common event types: "messages", "metadata"
|
||||
"""
|
||||
|
||||
def __init__(self, response: httpx.Response, model: str):
|
||||
self.response = response
|
||||
self.model = model
|
||||
self.finished = False
|
||||
self.line_iterator = None
|
||||
self.async_line_iterator = None
|
||||
|
||||
def __iter__(self):
|
||||
"""Initialize sync iteration."""
|
||||
self.line_iterator = self.response.iter_lines()
|
||||
return self
|
||||
|
||||
def __aiter__(self):
|
||||
"""Initialize async iteration."""
|
||||
self.async_line_iterator = self.response.aiter_lines()
|
||||
return self
|
||||
|
||||
def _parse_sse_line(self, line: str) -> Optional[ModelResponseStream]:
|
||||
"""
|
||||
Parse a single SSE line and return a ModelResponse chunk if applicable.
|
||||
|
||||
LangGraph SSE format can vary:
|
||||
- data: [...] (tuple format)
|
||||
- event: ...\ndata: ...
|
||||
"""
|
||||
line = line.strip()
|
||||
if not line:
|
||||
return None
|
||||
|
||||
# Handle SSE data lines
|
||||
if line.startswith("data:"):
|
||||
json_str = line[5:].strip()
|
||||
if not json_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
return self._process_data(data)
|
||||
except json.JSONDecodeError:
|
||||
verbose_logger.debug(f"Skipping non-JSON SSE line: {line[:100]}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _process_data(self, data) -> Optional[ModelResponseStream]:
|
||||
"""
|
||||
Process parsed data from SSE stream.
|
||||
|
||||
LangGraph uses tuple format: [event_type, payload]
|
||||
"""
|
||||
# Handle tuple format: ["messages", ...]
|
||||
if isinstance(data, list) and len(data) >= 2:
|
||||
event_type = data[0]
|
||||
payload = data[1]
|
||||
|
||||
if event_type == "messages":
|
||||
return self._process_messages_event(payload)
|
||||
elif event_type == "metadata":
|
||||
# Metadata event, might contain usage info
|
||||
return self._process_metadata_event(payload)
|
||||
|
||||
# Handle dict format (alternative response format)
|
||||
elif isinstance(data, dict):
|
||||
if "content" in data:
|
||||
return self._create_content_chunk(data.get("content", ""))
|
||||
elif "messages" in data:
|
||||
messages = data.get("messages", [])
|
||||
if messages:
|
||||
last_msg = messages[-1]
|
||||
if isinstance(last_msg, dict) and last_msg.get("type") == "ai":
|
||||
return self._create_content_chunk(last_msg.get("content", ""))
|
||||
|
||||
return None
|
||||
|
||||
def _process_messages_event(self, payload) -> Optional[ModelResponseStream]:
|
||||
"""
|
||||
Process a messages event from the stream.
|
||||
|
||||
payload format: [[message_object, metadata], ...]
|
||||
"""
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
if isinstance(item, list) and len(item) >= 1:
|
||||
msg = item[0]
|
||||
if isinstance(msg, dict):
|
||||
msg_type = msg.get("type", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# Only return AI messages with content
|
||||
if msg_type == "ai" and content:
|
||||
return self._create_content_chunk(content)
|
||||
elif msg_type == "AIMessageChunk" and content:
|
||||
return self._create_content_chunk(content)
|
||||
elif isinstance(item, dict):
|
||||
msg_type = item.get("type", "")
|
||||
content = item.get("content", "")
|
||||
if msg_type in ("ai", "AIMessageChunk") and content:
|
||||
return self._create_content_chunk(content)
|
||||
|
||||
return None
|
||||
|
||||
def _process_metadata_event(self, payload) -> Optional[ModelResponseStream]:
|
||||
"""
|
||||
Process a metadata event, which may signal the end of the stream.
|
||||
"""
|
||||
if isinstance(payload, dict):
|
||||
# Check if this is a final event
|
||||
if "run_id" in payload:
|
||||
self.finished = True
|
||||
return self._create_final_chunk()
|
||||
return None
|
||||
|
||||
def _create_content_chunk(self, text: str) -> ModelResponseStream:
|
||||
"""Create a ModelResponseStream chunk with content."""
|
||||
chunk = ModelResponseStream(
|
||||
id=f"chatcmpl-{uuid.uuid4()}",
|
||||
created=0,
|
||||
model=self.model,
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
|
||||
chunk.choices = [
|
||||
StreamingChoices(
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
delta=Delta(content=text, role="assistant"),
|
||||
)
|
||||
]
|
||||
|
||||
return chunk
|
||||
|
||||
def _create_final_chunk(self) -> ModelResponseStream:
|
||||
"""Create a final ModelResponseStream chunk with finish_reason."""
|
||||
chunk = ModelResponseStream(
|
||||
id=f"chatcmpl-{uuid.uuid4()}",
|
||||
created=0,
|
||||
model=self.model,
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
|
||||
chunk.choices = [
|
||||
StreamingChoices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
delta=Delta(),
|
||||
)
|
||||
]
|
||||
|
||||
return chunk
|
||||
|
||||
def __next__(self) -> ModelResponseStream:
|
||||
"""Sync iteration - parse SSE events and yield ModelResponse chunks."""
|
||||
try:
|
||||
if self.line_iterator is None:
|
||||
raise StopIteration
|
||||
|
||||
for line in self.line_iterator:
|
||||
result = self._parse_sse_line(line)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Stream ended naturally - send final chunk if not already finished
|
||||
if not self.finished:
|
||||
self.finished = True
|
||||
return self._create_final_chunk()
|
||||
|
||||
raise StopIteration
|
||||
|
||||
except StopIteration:
|
||||
raise
|
||||
except httpx.StreamConsumed:
|
||||
raise StopIteration
|
||||
except httpx.StreamClosed:
|
||||
raise StopIteration
|
||||
except Exception as e:
|
||||
verbose_logger.error(f"Error in LangGraph SSE stream: {str(e)}")
|
||||
raise StopIteration
|
||||
|
||||
async def __anext__(self) -> ModelResponseStream:
|
||||
"""Async iteration - parse SSE events and yield ModelResponse chunks."""
|
||||
try:
|
||||
if self.async_line_iterator is None:
|
||||
raise StopAsyncIteration
|
||||
|
||||
async for line in self.async_line_iterator:
|
||||
result = self._parse_sse_line(line)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Stream ended naturally - send final chunk if not already finished
|
||||
if not self.finished:
|
||||
self.finished = True
|
||||
return self._create_final_chunk()
|
||||
|
||||
raise StopAsyncIteration
|
||||
|
||||
except StopAsyncIteration:
|
||||
raise
|
||||
except httpx.StreamConsumed:
|
||||
raise StopAsyncIteration
|
||||
except httpx.StreamClosed:
|
||||
raise StopAsyncIteration
|
||||
except Exception as e:
|
||||
verbose_logger.error(f"Error in LangGraph SSE stream: {str(e)}")
|
||||
raise StopAsyncIteration
|
||||
@@ -0,0 +1,510 @@
|
||||
"""
|
||||
Transformation for LangGraph API.
|
||||
|
||||
LangGraph provides streaming (/runs/stream) and non-streaming (/runs/wait) endpoints
|
||||
for running agents.
|
||||
|
||||
Streaming endpoint: POST /runs/stream
|
||||
Non-streaming endpoint: POST /runs/wait
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.langgraph.chat.sse_iterator import LangGraphSSEStreamIterator
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
CustomStreamWrapper = Any
|
||||
|
||||
|
||||
class LangGraphError(BaseLLMException):
|
||||
"""Exception class for LangGraph API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LangGraphConfig(BaseConfig):
|
||||
"""
|
||||
Configuration for LangGraph API.
|
||||
|
||||
LangGraph is a framework for building stateful, multi-actor applications with LLMs.
|
||||
It provides a streaming and non-streaming API for running agents.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Get LangGraph API base and key from params or environment.
|
||||
|
||||
Returns:
|
||||
Tuple of (api_base, api_key)
|
||||
"""
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
api_base = (
|
||||
api_base or get_secret_str("LANGGRAPH_API_BASE") or "http://localhost:2024"
|
||||
)
|
||||
|
||||
api_key = api_key or get_secret_str("LANGGRAPH_API_KEY")
|
||||
|
||||
return api_base, api_key
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""
|
||||
LangGraph supports minimal OpenAI params since it's an agent runtime.
|
||||
"""
|
||||
return ["stream"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI params to LangGraph params.
|
||||
"""
|
||||
return optional_params
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for the LangGraph request.
|
||||
|
||||
Streaming: /runs/stream
|
||||
Non-streaming: /runs/wait
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"api_base is required for LangGraph. Set it via LANGGRAPH_API_BASE env var or api_base parameter."
|
||||
)
|
||||
|
||||
# Remove trailing slash if present
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
# Choose endpoint based on streaming mode
|
||||
if stream:
|
||||
return f"{api_base}/runs/stream"
|
||||
else:
|
||||
return f"{api_base}/runs/wait"
|
||||
|
||||
def _get_assistant_id(self, model: str, optional_params: dict) -> str:
|
||||
"""
|
||||
Get the assistant ID from model or optional_params.
|
||||
|
||||
model format: "langgraph/assistant_id" or just "assistant_id"
|
||||
"""
|
||||
assistant_id = optional_params.get("assistant_id")
|
||||
if assistant_id:
|
||||
return assistant_id
|
||||
|
||||
# Extract from model name
|
||||
if "/" in model:
|
||||
parts = model.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[1]
|
||||
return model
|
||||
|
||||
def _convert_messages_to_langgraph_format(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Convert OpenAI-format messages to LangGraph format.
|
||||
|
||||
OpenAI format: {"role": "user", "content": "..."}
|
||||
LangGraph format: {"role": "human", "content": "..."}
|
||||
"""
|
||||
langgraph_messages: List[Dict[str, str]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# Convert OpenAI roles to LangGraph roles
|
||||
if role == "user":
|
||||
langgraph_role = "human"
|
||||
elif role == "assistant":
|
||||
langgraph_role = "assistant"
|
||||
elif role == "system":
|
||||
langgraph_role = "system"
|
||||
else:
|
||||
langgraph_role = "human"
|
||||
|
||||
# Handle content that might be a list
|
||||
if isinstance(content, list):
|
||||
content = convert_content_list_to_str(msg)
|
||||
|
||||
# Ensure content is a string
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
|
||||
langgraph_messages.append({"role": langgraph_role, "content": content})
|
||||
|
||||
return langgraph_messages
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the request to LangGraph format.
|
||||
|
||||
LangGraph request format:
|
||||
{
|
||||
"assistant_id": "agent",
|
||||
"input": {
|
||||
"messages": [{"role": "human", "content": "..."}]
|
||||
},
|
||||
"stream_mode": "messages-tuple" # for streaming
|
||||
}
|
||||
"""
|
||||
assistant_id = self._get_assistant_id(model, optional_params)
|
||||
langgraph_messages = self._convert_messages_to_langgraph_format(messages)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"assistant_id": assistant_id,
|
||||
"input": {"messages": langgraph_messages},
|
||||
}
|
||||
|
||||
# Add stream_mode for streaming requests
|
||||
stream = litellm_params.get("stream", False)
|
||||
if stream:
|
||||
stream_mode = optional_params.get("stream_mode", "messages-tuple")
|
||||
payload["stream_mode"] = stream_mode
|
||||
|
||||
# Add optional config if provided
|
||||
if "config" in optional_params:
|
||||
payload["config"] = optional_params["config"]
|
||||
|
||||
# Add optional metadata if provided
|
||||
if "metadata" in optional_params:
|
||||
payload["metadata"] = optional_params["metadata"]
|
||||
|
||||
# Add thread_id if provided (for stateful conversations)
|
||||
if "thread_id" in optional_params:
|
||||
payload["thread_id"] = optional_params["thread_id"]
|
||||
|
||||
verbose_logger.debug(f"LangGraph request payload: {payload}")
|
||||
return payload
|
||||
|
||||
def _extract_content_from_response(self, response_json: dict) -> str:
|
||||
"""
|
||||
Extract content from LangGraph non-streaming response.
|
||||
|
||||
Response format varies, but commonly:
|
||||
{
|
||||
"messages": [...], # or could be in different structure
|
||||
"values": {...}
|
||||
}
|
||||
"""
|
||||
# Try to get the last AI message from the response
|
||||
messages = response_json.get("messages", [])
|
||||
if isinstance(messages, list) and messages:
|
||||
# Find the last AI/assistant message
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, dict):
|
||||
msg_type = msg.get("type", "")
|
||||
role = msg.get("role", "")
|
||||
if msg_type == "ai" or role == "assistant":
|
||||
return msg.get("content", "")
|
||||
|
||||
# Check values for output
|
||||
values = response_json.get("values", {})
|
||||
if isinstance(values, dict):
|
||||
output_messages = values.get("messages", [])
|
||||
if isinstance(output_messages, list) and output_messages:
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, dict):
|
||||
msg_type = msg.get("type", "")
|
||||
if msg_type == "ai":
|
||||
return msg.get("content", "")
|
||||
|
||||
# Fallback: try to serialize the whole response
|
||||
verbose_logger.warning(
|
||||
"Could not extract content from LangGraph response, returning raw"
|
||||
)
|
||||
return json.dumps(response_json)
|
||||
|
||||
def get_streaming_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
) -> LangGraphSSEStreamIterator:
|
||||
"""
|
||||
Return a streaming iterator for SSE responses.
|
||||
"""
|
||||
return LangGraphSSEStreamIterator(response=raw_response, model=model)
|
||||
|
||||
def get_sync_custom_stream_wrapper(
|
||||
self,
|
||||
model: str,
|
||||
custom_llm_provider: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
messages: list,
|
||||
client: Optional[Union[HTTPHandler, "AsyncHTTPHandler"]] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
signed_json_body: Optional[bytes] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
"""
|
||||
Get a CustomStreamWrapper for synchronous streaming.
|
||||
"""
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client(params={})
|
||||
|
||||
verbose_logger.debug(f"Making sync streaming request to: {api_base}")
|
||||
|
||||
# Make streaming request
|
||||
response = client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=True,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise LangGraphError(
|
||||
status_code=response.status_code, message=str(response.read())
|
||||
)
|
||||
|
||||
# Create iterator for SSE stream
|
||||
completion_stream = self.get_streaming_response(
|
||||
model=model, raw_response=response
|
||||
)
|
||||
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return streaming_response
|
||||
|
||||
async def get_async_custom_stream_wrapper(
|
||||
self,
|
||||
model: str,
|
||||
custom_llm_provider: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
messages: list,
|
||||
client: Optional["AsyncHTTPHandler"] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
signed_json_body: Optional[bytes] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
"""
|
||||
Get a CustomStreamWrapper for asynchronous streaming.
|
||||
"""
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=cast(Any, "langgraph"), params={}
|
||||
)
|
||||
|
||||
verbose_logger.debug(f"Making async streaming request to: {api_base}")
|
||||
|
||||
# Make async streaming request
|
||||
response = await client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=True,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise LangGraphError(
|
||||
status_code=response.status_code, message=str(await response.aread())
|
||||
)
|
||||
|
||||
# Create iterator for SSE stream
|
||||
completion_stream = self.get_streaming_response(
|
||||
model=model, raw_response=response
|
||||
)
|
||||
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return streaming_response
|
||||
|
||||
@property
|
||||
def has_custom_stream_wrapper(self) -> bool:
|
||||
"""Indicates that this config has custom streaming support."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def supports_stream_param_in_request_body(self) -> bool:
|
||||
"""
|
||||
LangGraph does not use a stream param in request body.
|
||||
Streaming is determined by the endpoint URL.
|
||||
"""
|
||||
return False
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Transform the LangGraph response to LiteLLM ModelResponse format.
|
||||
"""
|
||||
try:
|
||||
response_json = raw_response.json()
|
||||
verbose_logger.debug(f"LangGraph response: {response_json}")
|
||||
|
||||
content = self._extract_content_from_response(response_json)
|
||||
|
||||
# Create the message
|
||||
message = Message(content=content, role="assistant")
|
||||
|
||||
# Create choices
|
||||
choice = Choices(finish_reason="stop", index=0, message=message)
|
||||
|
||||
# Update model response
|
||||
model_response.choices = [choice]
|
||||
model_response.model = model
|
||||
|
||||
# LangGraph doesn't provide token usage, so we estimate it
|
||||
try:
|
||||
from litellm.utils import token_counter
|
||||
|
||||
prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
|
||||
completion_tokens = token_counter(
|
||||
model="gpt-3.5-turbo", text=content, count_response_tokens=True
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")
|
||||
|
||||
return model_response
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error(f"Error processing LangGraph response: {str(e)}")
|
||||
raise LangGraphError(
|
||||
message=f"Error processing response: {str(e)}",
|
||||
status_code=raw_response.status_code,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate and set up environment for LangGraph requests.
|
||||
"""
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
# Add API key if provided
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return LangGraphError(status_code=status_code, message=error_message)
|
||||
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
LangGraph has native streaming support, so we don't need to fake stream.
|
||||
"""
|
||||
return False
|
||||
Reference in New Issue
Block a user