Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/pytei/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TEIClient:
using a specified datastore.
"""

def __init__(self, embedding_store: Union[EmbeddingStore, None] = None, url: str = "http://127.0.0.1:8080", timeout: int = 10):
def __init__(self, embedding_store: Optional[EmbeddingStore] = None, url: str = "http://127.0.0.1:8080", timeout: int = 10):
"""Constructor method

:param embedding_store: Data store used for cacheing. Defaults to in-memory caching.
Expand Down Expand Up @@ -47,7 +47,7 @@ def _fetch_embeddings(self, texts: List[str], body: Dict[str, Any]) -> List[np.n
"""Send a batched request to the embedding endpoint."""
body["inputs"] = texts
try:
response = requests.post(f"{self._endpoint}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"},
response = requests.post(f"{self._endpoint}/embed", json=body, headers={"Content-Type": "application/json"},
timeout=self._timeout)
response.raise_for_status() # Raise an HTTPError for non-200 responses
embeddings = json.loads(response.text)
Expand All @@ -56,9 +56,8 @@ def _fetch_embeddings(self, texts: List[str], body: Dict[str, Any]) -> List[np.n
raise RuntimeError(f"Failed to fetch embedding: {e}")

@staticmethod
def _build_embed_call_body(normalize: bool = True, prompt_name: Union[str, None] = None,
truncate: bool = False,
truncation_direction: Union[Literal['left', 'right'], None] = None) -> Dict[str, Any]:
def _build_embed_call_body(normalize: bool = True, prompt_name: Optional[str] = None, truncate: bool = False,
truncation_direction: Optional[Literal['left', 'right']] = None) -> Dict[str, Any]:
body = {"normalize": normalize}
if prompt_name is not None:
body["prompt_name"] = prompt_name
Expand All @@ -70,8 +69,8 @@ def _build_embed_call_body(normalize: bool = True, prompt_name: Union[str, None]
return body


def embed(self, inputs: Union[str, List[str]], normalize: bool = True, prompt_name: Union[str, None] = None,
truncate: bool = False, truncation_direction: Union[Literal['left', 'right'], None] = None,
def embed(self, inputs: Union[str, List[str]], normalize: bool = True, prompt_name: Optional[str] = None,
truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None,
skip_cache: bool = False) -> Union[np.ndarray, List[np.ndarray]]:
"""
Get the embedding for a single string or a batch of strings.
Expand Down Expand Up @@ -141,8 +140,10 @@ def embed(self, inputs: Union[str, List[str]], normalize: bool = True, prompt_na

def rerank(self, query: str, texts: List[str], raw_score: bool = False, return_text: bool = False,
truncate: bool = False, truncation_direction: Union[Literal['left', 'right'], None] = None) -> List[Rank]:
def rerank(self, query: str, texts: List[str], raw_scores: bool = False, return_text: bool = False,
truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None) -> List[Rank]:
raise NotImplementedError("Reranking is not yet implemented.")

def predict(self, inputs: Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]], raw_scores: bool = False,
truncate: bool = False, truncation_direction: Union[Literal['left', 'right'], None] = None) -> Union[PredictionResult, List[PredictionResult]]:
truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None) -> Union[PredictionResult, List[PredictionResult]]:
raise NotImplementedError("Sequence classification is not yet implemented.")
Loading