From 54473db5ccb601be900284cba0166e8ec301afa7 Mon Sep 17 00:00:00 2001
From: Daniel Gomm
Date: Tue, 21 Jan 2025 09:47:02 +0100
Subject: [PATCH] Refine type hints

---
 src/pytei/client.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/pytei/client.py b/src/pytei/client.py
index 224e71f..0679f7a 100644
--- a/src/pytei/client.py
+++ b/src/pytei/client.py
@@ -17,7 +17,7 @@ class TEIClient:
     using a specified datastore.
     """
 
-    def __init__(self, embedding_store: Union[EmbeddingStore, None] = None, url: str = "http://127.0.0.1:8080", timeout: int = 10):
+    def __init__(self, embedding_store: Optional[EmbeddingStore] = None, url: str = "http://127.0.0.1:8080", timeout: int = 10):
         """Constructor method
 
         :param embedding_store: Data store used for cacheing. Defaults to in-memory caching.
@@ -47,7 +47,7 @@ def _fetch_embeddings(self, texts: List[str], body: Dict[str, Any]) -> List[np.n
         """Send a batched request to the embedding endpoint."""
         body["inputs"] = texts
         try:
-            response = requests.post(f"{self._endpoint}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"},
+            response = requests.post(f"{self._endpoint}/embed", json=body, headers={"Content-Type": "application/json"},
                                      timeout=self._timeout)
             response.raise_for_status()  # Raise an HTTPError for non-200 responses
             embeddings = json.loads(response.text)
@@ -56,9 +56,8 @@ def _fetch_embeddings(self, texts: List[str], body: Dict[str, Any]) -> List[np.n
             raise RuntimeError(f"Failed to fetch embedding: {e}")
 
     @staticmethod
-    def _build_embed_call_body(normalize: bool = True, prompt_name: Union[str, None] = None,
-                               truncate: bool = False,
-                               truncation_direction: Union[Literal['left', 'right'], None] = None) -> Dict[str, Any]:
+    def _build_embed_call_body(normalize: bool = True, prompt_name: Optional[str] = None, truncate: bool = False,
+                               truncation_direction: Optional[Literal['left', 'right']] = None) -> Dict[str, Any]:
         body = {"normalize": normalize}
         if prompt_name is not None:
             body["prompt_name"] = prompt_name
@@ -70,8 +69,8 @@ def _build_embed_call_body(normalize: bool = True, prompt_name: Union[str, None]
 
         return body
 
-    def embed(self, inputs: Union[str, List[str]], normalize: bool = True, prompt_name: Union[str, None] = None,
-              truncate: bool = False, truncation_direction: Union[Literal['left', 'right'], None] = None,
+    def embed(self, inputs: Union[str, List[str]], normalize: bool = True, prompt_name: Optional[str] = None,
+              truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None,
               skip_cache: bool = False) -> Union[np.ndarray, List[np.ndarray]]:
         """
         Get the embedding for a single string or a batch of strings.
@@ -141,8 +140,10 @@ def embed(self, inputs: Union[str, List[str]], normalize: bool = True, prompt_na
 
     def rerank(self, query: str, texts: List[str], raw_score: bool = False, return_text: bool = False,
                truncate: bool = False, truncation_direction: Union[Literal['left', 'right'], None] = None) -> List[Rank]:
+    def rerank(self, query: str, texts: List[str], raw_scores: bool = False, return_text: bool = False,
+               truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None) -> List[Rank]:
         raise NotImplementedError("Reranking is not yet implemented.")
 
     def predict(self, inputs: Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]], raw_scores: bool = False,
-                truncate: bool = False, truncation_direction: Union[Literal['left', 'right'], None] = None) -> Union[PredictionResult, List[PredictionResult]]:
+                truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None) -> Union[PredictionResult, List[PredictionResult]]:
         raise NotImplementedError("Sequence classification is not yet implemented.")
\ No newline at end of file