119 lines
4.0 KiB
Python
119 lines
4.0 KiB
Python
|
|
"""
|
||
|
|
Translates from Cohere's `/v1/rerank` input format to Bedrock's `/rerank` input format.
|
||
|
|
|
||
|
|
Why separate file? Make it easy to see how transformation works
|
||
|
|
"""
|
||
|
|
|
||
|
|
from litellm._uuid import uuid
|
||
|
|
from typing import List, Optional, Union
|
||
|
|
|
||
|
|
from litellm.types.llms.bedrock import (
|
||
|
|
BedrockRerankBedrockRerankingConfiguration,
|
||
|
|
BedrockRerankConfiguration,
|
||
|
|
BedrockRerankInlineDocumentSource,
|
||
|
|
BedrockRerankModelConfiguration,
|
||
|
|
BedrockRerankQuery,
|
||
|
|
BedrockRerankRequest,
|
||
|
|
BedrockRerankSource,
|
||
|
|
BedrockRerankTextDocument,
|
||
|
|
BedrockRerankTextQuery,
|
||
|
|
)
|
||
|
|
from litellm.types.rerank import (
|
||
|
|
RerankBilledUnits,
|
||
|
|
RerankRequest,
|
||
|
|
RerankResponse,
|
||
|
|
RerankResponseMeta,
|
||
|
|
RerankResponseResult,
|
||
|
|
RerankTokens,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class BedrockRerankConfig:
|
||
|
|
def _transform_sources(
|
||
|
|
self, documents: List[Union[str, dict]]
|
||
|
|
) -> List[BedrockRerankSource]:
|
||
|
|
"""
|
||
|
|
Transform the sources from RerankRequest format to Bedrock format.
|
||
|
|
"""
|
||
|
|
_sources = []
|
||
|
|
for document in documents:
|
||
|
|
if isinstance(document, str):
|
||
|
|
_sources.append(
|
||
|
|
BedrockRerankSource(
|
||
|
|
inlineDocumentSource=BedrockRerankInlineDocumentSource(
|
||
|
|
textDocument=BedrockRerankTextDocument(text=document),
|
||
|
|
type="TEXT",
|
||
|
|
),
|
||
|
|
type="INLINE",
|
||
|
|
)
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
_sources.append(
|
||
|
|
BedrockRerankSource(
|
||
|
|
inlineDocumentSource=BedrockRerankInlineDocumentSource(
|
||
|
|
jsonDocument=document, type="JSON"
|
||
|
|
),
|
||
|
|
type="INLINE",
|
||
|
|
)
|
||
|
|
)
|
||
|
|
return _sources
|
||
|
|
|
||
|
|
def _transform_request(self, request_data: RerankRequest) -> BedrockRerankRequest:
|
||
|
|
"""
|
||
|
|
Transform the request from RerankRequest format to Bedrock format.
|
||
|
|
"""
|
||
|
|
_sources = self._transform_sources(request_data.documents)
|
||
|
|
|
||
|
|
return BedrockRerankRequest(
|
||
|
|
queries=[
|
||
|
|
BedrockRerankQuery(
|
||
|
|
textQuery=BedrockRerankTextQuery(text=request_data.query),
|
||
|
|
type="TEXT",
|
||
|
|
)
|
||
|
|
],
|
||
|
|
rerankingConfiguration=BedrockRerankConfiguration(
|
||
|
|
bedrockRerankingConfiguration=BedrockRerankBedrockRerankingConfiguration(
|
||
|
|
modelConfiguration=BedrockRerankModelConfiguration(
|
||
|
|
modelArn=request_data.model
|
||
|
|
),
|
||
|
|
numberOfResults=request_data.top_n or len(request_data.documents),
|
||
|
|
),
|
||
|
|
type="BEDROCK_RERANKING_MODEL",
|
||
|
|
),
|
||
|
|
sources=_sources,
|
||
|
|
)
|
||
|
|
|
||
|
|
def _transform_response(self, response: dict) -> RerankResponse:
|
||
|
|
"""
|
||
|
|
Transform the response from Bedrock into the RerankResponse format.
|
||
|
|
|
||
|
|
example input:
|
||
|
|
{"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
|
||
|
|
"""
|
||
|
|
_billed_units = RerankBilledUnits(
|
||
|
|
**response.get("usage", {"search_units": 1})
|
||
|
|
) # by default 1 search unit
|
||
|
|
_tokens = RerankTokens(**response.get("usage", {}))
|
||
|
|
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
||
|
|
|
||
|
|
_results: Optional[List[RerankResponseResult]] = None
|
||
|
|
|
||
|
|
bedrock_results = response.get("results")
|
||
|
|
if bedrock_results:
|
||
|
|
_results = [
|
||
|
|
RerankResponseResult(
|
||
|
|
index=result.get("index"),
|
||
|
|
relevance_score=result.get("relevanceScore"),
|
||
|
|
)
|
||
|
|
for result in bedrock_results
|
||
|
|
]
|
||
|
|
|
||
|
|
if _results is None:
|
||
|
|
raise ValueError(f"No results found in the response={response}")
|
||
|
|
|
||
|
|
return RerankResponse(
|
||
|
|
id=response.get("id") or str(uuid.uuid4()),
|
||
|
|
results=_results,
|
||
|
|
meta=rerank_meta,
|
||
|
|
) # Return response
|