From 3bdb4438f211acc2300e72195c1a8a56ef643ee4 Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Mon, 17 Mar 2025 14:37:52 -0400 Subject: [PATCH 1/9] predict_proba support for sklearn --- vetiver/handlers/sklearn.py | 17 ++++++++++------- vetiver/handlers/statsmodels.py | 10 +++------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/vetiver/handlers/sklearn.py b/vetiver/handlers/sklearn.py index a9401186..931b8211 100644 --- a/vetiver/handlers/sklearn.py +++ b/vetiver/handlers/sklearn.py @@ -16,7 +16,7 @@ class SKLearnHandler(BaseHandler): model_class = staticmethod(lambda: sklearn.base.BaseEstimator) pip_name = "scikit-learn" - def handler_predict(self, input_data, check_prototype): + def handler_predict(self, input_data, check_prototype: bool, **kw): """ Generates method for /predict endpoint in VetiverAPI @@ -28,16 +28,19 @@ def handler_predict(self, input_data, check_prototype): ---------- input_data: Test data + check_prototype: bool + prediction_type: str + Type of prediction to make. One of "predict", "predict_proba", or "predict_log_proba". + Default is "predict". Returns ------- prediction: Prediction from model """ + prediction_type = kw.get("prediction_type", "predict") + if prediction_type not in ["predict", "predict_proba", "predict_log_proba"]: + raise ValueError('prediction_type must be "predict", "predict_proba", or "predict_log_proba"') - if not check_prototype or isinstance(input_data, pd.DataFrame): - prediction = self.model.predict(input_data) - else: - prediction = self.model.predict([input_data]) - - return prediction.tolist() + input_data = [input_data] if check_prototype and not isinstance(input_data, pd.DataFrame) else input_data + return getattr(self.model, prediction_type)(input_data).tolist() diff --git a/vetiver/handlers/statsmodels.py b/vetiver/handlers/statsmodels.py index 084b5ffc..416fbd3a 100644 --- a/vetiver/handlers/statsmodels.py +++ b/vetiver/handlers/statsmodels.py @@ -42,10 +42,6 @@ def handler_predict(self, input_data, check_prototype): """ if not sm_exists: raise ImportError("Cannot import `statsmodels`") - - if isinstance(input_data, (list, pd.DataFrame)): - prediction = self.model.predict(input_data) - else: - prediction = self.model.predict([input_data]) - - return prediction.tolist() + + input_data = input_data if isinstance(input_data, (list, pd.DataFrame)) else [input_data] + return self.model.predict(input_data).tolist() From 9bec20c7352dfb1753c2dea5dd68adb8cc972ee3 Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Tue, 18 Mar 2025 16:12:12 -0400 Subject: [PATCH 2/9] add SklearnPredictionTypes --- vetiver/handlers/base.py | 2 +- vetiver/handlers/spacy.py | 2 +- vetiver/handlers/statsmodels.py | 2 +- vetiver/handlers/torch.py | 2 +- vetiver/handlers/xgboost.py | 2 +- vetiver/server.py | 88 ++++++++++++++++++--------------- vetiver/types.py | 4 ++ 7 files changed, 58 insertions(+), 44 deletions(-) diff --git a/vetiver/handlers/base.py b/vetiver/handlers/base.py index 3f0044c7..7139f2ad 100644 --- a/vetiver/handlers/base.py +++ b/vetiver/handlers/base.py @@ -121,7 +121,7 @@ def handler_startup(): """ ... - def handler_predict(self, input_data, check_prototype): + def handler_predict(self, input_data, check_prototype, **kw): """Generates method for /predict endpoint in VetiverAPI The `handler_predict` function executes at each API call. Use this diff --git a/vetiver/handlers/spacy.py b/vetiver/handlers/spacy.py index 80dfdaca..eb4d0de3 100644 --- a/vetiver/handlers/spacy.py +++ b/vetiver/handlers/spacy.py @@ -53,7 +53,7 @@ def construct_prototype(self): return prototype - def handler_predict(self, input_data, check_prototype): + def handler_predict(self, input_data, check_prototype, **kw): """ Generates method for /predict endpoint in VetiverAPI diff --git a/vetiver/handlers/statsmodels.py b/vetiver/handlers/statsmodels.py index 416fbd3a..3eb59739 100644 --- a/vetiver/handlers/statsmodels.py +++ b/vetiver/handlers/statsmodels.py @@ -22,7 +22,7 @@ class StatsmodelsHandler(BaseHandler): if sm_exists: pip_name = "statsmodels" - def handler_predict(self, input_data, check_prototype): + def handler_predict(self, input_data, check_prototype, **kw): """ Generates method for /predict endpoint in VetiverAPI diff --git a/vetiver/handlers/torch.py b/vetiver/handlers/torch.py index 15dafa5e..625d4cb1 100644 --- a/vetiver/handlers/torch.py +++ b/vetiver/handlers/torch.py @@ -22,7 +22,7 @@ class TorchHandler(BaseHandler): if torch_exists: pip_name = "torch" - def handler_predict(self, input_data, check_prototype): + def handler_predict(self, input_data, check_prototype, **kw): """ Generates method for /predict endpoint in VetiverAPI diff --git a/vetiver/handlers/xgboost.py b/vetiver/handlers/xgboost.py index 9c112343..21bbb6be 100644 --- a/vetiver/handlers/xgboost.py +++ b/vetiver/handlers/xgboost.py @@ -22,7 +22,7 @@ class XGBoostHandler(BaseHandler): if xgb_exists: pip_name = "xgboost" - def handler_predict(self, input_data, check_prototype): + def handler_predict(self, input_data, check_prototype, **kw): """ Generates method for /predict endpoint in VetiverAPI diff --git a/vetiver/server.py b/vetiver/server.py index ea552986..d56da12a 100644 --- a/vetiver/server.py +++ b/vetiver/server.py @@ -11,15 +11,16 @@ import pandas as pd import requests import uvicorn -from fastapi import FastAPI, Request +from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse - from .helpers import api_data_to_frame, response_to_frame +from .handlers.sklearn import SKLearnHandler from .meta import VetiverMeta from .utils import _jupyter_nb, get_workbench_path from .vetiver_model import VetiverModel +from .types import SklearnPredictionTypes class VetiverAPI: @@ -111,7 +112,6 @@ async def startup_event(): @app.get("/", include_in_schema=False) def docs_redirect(): - redirect = "__docs__" return RedirectResponse(redirect) @@ -200,65 +200,75 @@ async def validation_exception_handler(request, exc): return app - def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw): - """Create new POST endpoint that is aware of model input data + def vetiver_post( + self, + endpoint_fx: Union[Callable, SklearnPredictionTypes], + endpoint_name: str = None, + **kw, + ): + """Define a new POST endpoint that utilizes the model's input data. Parameters ---------- - endpoint_fx : typing.Callable - Custom function to be run at endpoint + endpoint_fx : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]] + A callable function that specifies the custom logic to execute when the endpoint is called. + This function should take input data (e.g., a DataFrame or dictionary) and return the desired output + (e.g., predictions or transformed data). For scikit-learn models, endpoint_fx can also be one of + "predict", "predict_proba", or "predict_log_proba" if the model supports these methods. + endpoint_name : str - Name of endpoint + The name of the endpoint to be created. Examples ------- - ```{python} + ```python from vetiver import mock, VetiverModel, VetiverAPI X, y = mock.get_mock_data() model = mock.get_mock_model().fit(X, y) - v = VetiverModel(model = model, model_name = "model", prototype_data = X) - v_api = VetiverAPI(model = v, check_prototype = True) + v = VetiverModel(model=model, model_name="model", prototype_data=X) + v_api = VetiverAPI(model=v, check_prototype=True) def sum_values(x): return x.sum() + v_api.vetiver_post(sum_values, "sums") ``` """ - if not endpoint_name: - endpoint_name = endpoint_fx.__name__ - if endpoint_fx.__doc__ is not None: - api_desc = dedent(endpoint_fx.__doc__) - else: - api_desc = None - - if self.check_prototype is True: - - @self.app.post( - urljoin("/", endpoint_name), - name=endpoint_name, - description=api_desc, + if isinstance(endpoint_fx, SklearnPredictionTypes): + if not isinstance(self.model, SKLearnHandler): + raise ValueError( + "The 'endpoint_fx' parameter can only be a string when using scikit-learn models." + ) + self.vetiver_post( + self.model.handler_predict, + SklearnPredictionTypes, + check_prototype=self.check_prototype, + prediction_type=endpoint_fx, ) - async def custom_endpoint(input_data: List[self.model.prototype]): - _to_frame = api_data_to_frame(input_data) - predictions = endpoint_fx(_to_frame, **kw) - if isinstance(predictions, List): - return {endpoint_name: predictions} - else: - return predictions + return - else: + endpoint_name = endpoint_name or endpoint_fx.__name__ + endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None - @self.app.post(urljoin("/", endpoint_name)) - async def custom_endpoint(input_data: Request): + @self.app.post( + urljoin("/", endpoint_name), + name=endpoint_name, + description=endpoint_doc, + ) + async def custom_endpoint(input_data: List[self.model.prototype]): + if self.check_prototype: + served_data = api_data_to_frame(input_data) + else: served_data = await input_data.json() - predictions = endpoint_fx(served_data, **kw) - if isinstance(predictions, List): - return {endpoint_name: predictions} - else: - return predictions + predictions = endpoint_fx(served_data, **kw) + + if isinstance(predictions, List): + return {endpoint_name: predictions} + else: + return predictions def run(self, port: int = 8000, host: str = "127.0.0.1", quiet_open=False, **kw): """ diff --git a/vetiver/types.py b/vetiver/types.py index a0970274..e36eeba1 100644 --- a/vetiver/types.py +++ b/vetiver/types.py @@ -1,4 +1,5 @@ from pydantic import BaseModel, create_model +from typing import Literal all = ["Prototype", "create_prototype"] @@ -7,5 +8,8 @@ class Prototype(BaseModel): pass +SklearnPredictionTypes = Literal["predict", "predict_proba", "predict_log_proba"] + + def create_prototype(**dict_data): return create_model("prototype", __base__=Prototype, **dict_data) From 74f9a4448024e6c8328662f391d3e1c4ea6a55c6 Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Tue, 18 Mar 2025 17:49:19 -0400 Subject: [PATCH 3/9] clean up vetiver_post --- vetiver/server.py | 45 +++++++++++++++++++-------- vetiver/tests/test_add_endpoint.py | 32 ------------------- vetiver/tests/test_server.py | 50 ++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 45 deletions(-) delete mode 100644 vetiver/tests/test_add_endpoint.py diff --git a/vetiver/server.py b/vetiver/server.py index d56da12a..9f96d3cb 100644 --- a/vetiver/server.py +++ b/vetiver/server.py @@ -11,7 +11,7 @@ import pandas as pd import requests import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse @@ -210,11 +210,13 @@ def vetiver_post( Parameters ---------- - endpoint_fx : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]] - A callable function that specifies the custom logic to execute when the endpoint is called. - This function should take input data (e.g., a DataFrame or dictionary) and return the desired output - (e.g., predictions or transformed data). For scikit-learn models, endpoint_fx can also be one of - "predict", "predict_proba", or "predict_log_proba" if the model supports these methods. + endpoint_fx + : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]] + A callable function that specifies the custom logic to execute when the + endpoint is called. This function should take input data (e.g., a DataFrame + or dictionary) and return the desired output(e.g., predictions or transformed + data). For scikit-learn models, endpoint_fx can also be one of "predict", + "predict_proba", or "predict_log_proba" if the model supports these methods. endpoint_name : str The name of the endpoint to be created. @@ -236,10 +238,20 @@ def sum_values(x): ``` """ - if isinstance(endpoint_fx, SklearnPredictionTypes): + if not isinstance(endpoint_fx, Callable): + if endpoint_fx not in SklearnPredictionTypes: + raise ValueError( + f""" + Prediction type {endpoint_fx} not available. + Available prediction types: {SklearnPredictionTypes} + """ + ) if not isinstance(self.model, SKLearnHandler): raise ValueError( - "The 'endpoint_fx' parameter can only be a string when using scikit-learn models." + """ + The 'endpoint_fx' parameter can only be a + string when using scikit-learn models. + """ ) self.vetiver_post( self.model.handler_predict, @@ -252,17 +264,24 @@ def sum_values(x): endpoint_name = endpoint_name or endpoint_fx.__name__ endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None + # this must be split up this way to preserve the correct type hints for + # the input_data schema validation via Pydantic + FastAPI + input_data_type = ( + List[self.model.prototype] if self.check_prototype else Request + ) + @self.app.post( urljoin("/", endpoint_name), name=endpoint_name, description=endpoint_doc, ) - async def custom_endpoint(input_data: List[self.model.prototype]): - if self.check_prototype: - served_data = api_data_to_frame(input_data) - else: - served_data = await input_data.json() + async def custom_endpoint(input_data: input_data_type): + served_data = ( + api_data_to_frame(input_data) + if self.check_prototype + else await input_data.json() + ) predictions = endpoint_fx(served_data, **kw) if isinstance(predictions, List): diff --git a/vetiver/tests/test_add_endpoint.py b/vetiver/tests/test_add_endpoint.py deleted file mode 100644 index 5a5f1a21..00000000 --- a/vetiver/tests/test_add_endpoint.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest -import pandas as pd -from vetiver import mock, VetiverModel - - -@pytest.fixture() -def model(): - X, y = mock.get_mock_data() - model = mock.get_mock_model() - - return VetiverModel(model.fit(X, y), "model", prototype_data=X) - - -@pytest.fixture -def data() -> pd.DataFrame: - return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) - - -def test_endpoint_adds(client, data): - response = client.post("/sum/", data=data.to_json(orient="records")) - - assert response.status_code == 200 - assert response.json() == {"sum": [3, 6, 9]} - - -def test_endpoint_adds_no_prototype(client_no_prototype, data): - - data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) - response = client_no_prototype.post("/sum/", data=data.to_json(orient="records")) - - assert response.status_code == 200 - assert response.json() == {"sum": [3, 6, 9]} diff --git a/vetiver/tests/test_server.py b/vetiver/tests/test_server.py index 97150c05..16dd91c3 100644 --- a/vetiver/tests/test_server.py +++ b/vetiver/tests/test_server.py @@ -12,6 +12,8 @@ import numpy as np import pytest import sys +import pandas as pd +from vetiver.handlers.sklearn import SKLearnHandler @pytest.fixture @@ -125,3 +127,51 @@ def test_vetiver_endpoint(): url = vetiver_endpoint(url_raw) assert url == "http://127.0.0.1:8000/predict" + + +@pytest.fixture +def data() -> pd.DataFrame: + return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) + + +def test_endpoint_adds(client, data): + response = client.post("/sum/", data=data.to_json(orient="records")) + + assert response.status_code == 200 + assert response.json() == {"sum": [3, 6, 9]} + + +def test_endpoint_adds_no_prototype(client_no_prototype, data): + + data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) + response = client_no_prototype.post("/sum/", data=data.to_json(orient="records")) + + assert response.status_code == 200 + assert response.json() == {"sum": [3, 6, 9]} + + +def test_vetiver_post_sklearn_predict(model): + vetiver_api = VetiverAPI(model=model) + if not isinstance(vetiver_api.model, SKLearnHandler): + pytest.skip("Test only applicable for SKLearnHandler models") + + vetiver_api.vetiver_post("predict_proba") + + client = TestClient(vetiver_api.app) + response = client.post( + "/predict_proba", json=vetiver_api.model.prototype.construct().dict() + ) + assert response.status_code == 200 + + +def test_vetiver_post_invalid_sklearn_type(model): + vetiver_api = VetiverAPI(model=model) + if not isinstance(vetiver_api.model, SKLearnHandler): + pytest.skip("Test only applicable for SKLearnHandler models") + + with pytest.raises( + ValueError, + match="The 'endpoint_fx' parameter can only be a string \ + when using scikit-learn models.", + ): + vetiver_api.vetiver_post("invalid_type") From 5e05568593b1abb0a079f7a21acda1767555ca01 Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Wed, 19 Mar 2025 13:59:19 -0400 Subject: [PATCH 4/9] lint --- .pre-commit-config.yaml | 2 +- vetiver/handlers/sklearn.py | 17 ++++++++++++----- vetiver/handlers/statsmodels.py | 6 ++++-- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d74dc3f7..eff92427 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: types: - python args: - - "--max-line-length=90" + - "--max-line-length=100" - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml diff --git a/vetiver/handlers/sklearn.py b/vetiver/handlers/sklearn.py index 931b8211..3f27eaa4 100644 --- a/vetiver/handlers/sklearn.py +++ b/vetiver/handlers/sklearn.py @@ -30,8 +30,8 @@ def handler_predict(self, input_data, check_prototype: bool, **kw): Test data check_prototype: bool prediction_type: str - Type of prediction to make. One of "predict", "predict_proba", or "predict_log_proba". - Default is "predict". + Type of prediction to make. One of "predict", "predict_proba", + or "predict_log_proba". Default is "predict". Returns ------- @@ -40,7 +40,14 @@ def handler_predict(self, input_data, check_prototype: bool, **kw): """ prediction_type = kw.get("prediction_type", "predict") if prediction_type not in ["predict", "predict_proba", "predict_log_proba"]: - raise ValueError('prediction_type must be "predict", "predict_proba", or "predict_log_proba"') - - input_data = [input_data] if check_prototype and not isinstance(input_data, pd.DataFrame) else input_data + raise ValueError( + 'prediction_type must be "predict", "predict_proba", \ + or "predict_log_proba"' + ) + + input_data = ( + [input_data] + if check_prototype and not isinstance(input_data, pd.DataFrame) + else input_data + ) return getattr(self.model, prediction_type)(input_data).tolist() diff --git a/vetiver/handlers/statsmodels.py b/vetiver/handlers/statsmodels.py index 3eb59739..ab392898 100644 --- a/vetiver/handlers/statsmodels.py +++ b/vetiver/handlers/statsmodels.py @@ -42,6 +42,8 @@ def handler_predict(self, input_data, check_prototype, **kw): """ if not sm_exists: raise ImportError("Cannot import `statsmodels`") - - input_data = input_data if isinstance(input_data, (list, pd.DataFrame)) else [input_data] + + input_data = ( + input_data if isinstance(input_data, (list, pd.DataFrame)) else [input_data] + ) return self.model.predict(input_data).tolist() From 1554296af07a691b9d3cfd2e62ec6505824fa297 Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Thu, 20 Mar 2025 17:36:37 -0400 Subject: [PATCH 5/9] update tests --- vetiver/__init__.py | 2 +- vetiver/handlers/sklearn.py | 11 ++-- vetiver/mock.py | 22 +++++++- vetiver/server.py | 6 +- vetiver/tests/test_server.py | 106 ++++++++++++++++++++++++----------- 5 files changed, 101 insertions(+), 46 deletions(-) diff --git a/vetiver/__init__.py b/vetiver/__init__.py index 31a56bcf..ae500395 100644 --- a/vetiver/__init__.py +++ b/vetiver/__init__.py @@ -10,7 +10,7 @@ ) # noqa from .vetiver_model import VetiverModel # noqa from .server import VetiverAPI, vetiver_endpoint, predict # noqa -from .mock import get_mock_data, get_mock_model # noqa +from .mock import get_mock_data, get_mock_model, get_mtcars_model # noqa from .pin_read_write import vetiver_pin_write # noqa from .attach_pkgs import load_pkgs, get_board_pkgs # noqa from .meta import VetiverMeta # noqa diff --git a/vetiver/handlers/sklearn.py b/vetiver/handlers/sklearn.py index 3f27eaa4..156c3212 100644 --- a/vetiver/handlers/sklearn.py +++ b/vetiver/handlers/sklearn.py @@ -39,15 +39,14 @@ def handler_predict(self, input_data, check_prototype: bool, **kw): Prediction from model """ prediction_type = kw.get("prediction_type", "predict") - if prediction_type not in ["predict", "predict_proba", "predict_log_proba"]: - raise ValueError( - 'prediction_type must be "predict", "predict_proba", \ - or "predict_log_proba"' - ) input_data = ( [input_data] if check_prototype and not isinstance(input_data, pd.DataFrame) else input_data ) - return getattr(self.model, prediction_type)(input_data).tolist() + + if prediction_type in ["predict_proba", "predict_log_proba"]: + return getattr(self.model, prediction_type)(input_data).tolist() + + return self.model.predict(input_data).to_list() diff --git a/vetiver/mock.py b/vetiver/mock.py index 780e4a54..e18fe8a3 100644 --- a/vetiver/mock.py +++ b/vetiver/mock.py @@ -1,7 +1,11 @@ -from sklearn.dummy import DummyRegressor import pandas as pd import numpy as np +from sklearn.dummy import DummyRegressor +from sklearn.linear_model import LogisticRegression + +from .data import mtcars + def get_mock_data(): """Create mock data for testing @@ -26,5 +30,17 @@ def get_mock_model(): model : sklearn.dummy.DummyRegressor Arbitrary model for testing purposes """ - model = DummyRegressor() - return model + return DummyRegressor() + + +def get_mtcars_model(): + """Create mock model for testing + + Returns + ------- + model : sklearn.dummy.DummyRegressor + Arbitrary model for testing purposes + """ + return LogisticRegression(max_iter=1000).fit( + mtcars.drop(columns="cyl"), mtcars["cyl"] + ) diff --git a/vetiver/server.py b/vetiver/server.py index 9f96d3cb..a5be4207 100644 --- a/vetiver/server.py +++ b/vetiver/server.py @@ -239,14 +239,14 @@ def sum_values(x): """ if not isinstance(endpoint_fx, Callable): - if endpoint_fx not in SklearnPredictionTypes: + if endpoint_fx not in ["predict", "predict_proba", "predict_log_proba"]: raise ValueError( f""" Prediction type {endpoint_fx} not available. Available prediction types: {SklearnPredictionTypes} """ ) - if not isinstance(self.model, SKLearnHandler): + if not isinstance(self.model.handler_predict.__self__, SKLearnHandler): raise ValueError( """ The 'endpoint_fx' parameter can only be a @@ -255,7 +255,7 @@ def sum_values(x): ) self.vetiver_post( self.model.handler_predict, - SklearnPredictionTypes, + endpoint_fx, check_prototype=self.check_prototype, prediction_type=endpoint_fx, ) diff --git a/vetiver/tests/test_server.py b/vetiver/tests/test_server.py index 16dd91c3..ff40e9b3 100644 --- a/vetiver/tests/test_server.py +++ b/vetiver/tests/test_server.py @@ -1,3 +1,11 @@ +import pytest +import sys +import pandas as pd +import numpy as np +from fastapi.testclient import TestClient +from pydantic import BaseModel, conint + +from vetiver.data import mtcars from vetiver import ( mock, VetiverModel, @@ -7,26 +15,18 @@ vetiver_endpoint, predict, ) -from pydantic import BaseModel, conint -from fastapi.testclient import TestClient -import numpy as np -import pytest -import sys -import pandas as pd -from vetiver.handlers.sklearn import SKLearnHandler @pytest.fixture def model(): np.random.seed(500) - X, y = mock.get_mock_data() - model = mock.get_mock_model().fit(X, y) + model = mock.get_mtcars_model() v = VetiverModel( model=model, - prototype_data=X, + prototype_data=mtcars.drop(columns="cyl"), model_name="my_model", versioned=None, - description="A regression model for testing purposes", + description="A logistic regression model for testing purposes", ) return v @@ -84,11 +84,29 @@ def test_get_prototype(client, model): assert response.status_code == 200, response.text assert response.json() == { "properties": { - "B": {"example": 55, "type": "integer"}, - "C": {"example": 65, "type": "integer"}, - "D": {"example": 17, "type": "integer"}, + "mpg": {"example": 21.0, "type": "number"}, + "disp": {"example": 160.0, "type": "number"}, + "hp": {"example": 110.0, "type": "number"}, + "drat": {"example": 3.9, "type": "number"}, + "wt": {"example": 2.62, "type": "number"}, + "qsec": {"example": 16.46, "type": "number"}, + "vs": {"example": 0.0, "type": "number"}, + "am": {"example": 1.0, "type": "number"}, + "gear": {"example": 4.0, "type": "number"}, + "carb": {"example": 4.0, "type": "number"}, }, - "required": ["B", "C", "D"], + "required": [ + "mpg", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb", + ], "title": "prototype", "type": "object", } @@ -131,14 +149,28 @@ def test_vetiver_endpoint(): @pytest.fixture def data() -> pd.DataFrame: - return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) + return pd.DataFrame( + { + "mpg": [20, 20], + "disp": [160, 160], + "hp": [110, 110], + "drat": [3.9, 3.9], + "wt": [2.62, 2.62], + "qsec": [16.00, 16.00], + "vs": [0, 0], + "am": [1, 1], + "gear": [4, 4], + "carb": [4, 4], + } + ) def test_endpoint_adds(client, data): + response = client.post("/sum/", data=data.to_json(orient="records")) assert response.status_code == 200 - assert response.json() == {"sum": [3, 6, 9]} + assert response.json() == {"sum": [40, 320, 220, 7.8, 5.24, 32.00, 0, 2, 8, 8]} def test_endpoint_adds_no_prototype(client_no_prototype, data): @@ -150,28 +182,36 @@ def test_endpoint_adds_no_prototype(client_no_prototype, data): assert response.json() == {"sum": [3, 6, 9]} -def test_vetiver_post_sklearn_predict(model): - vetiver_api = VetiverAPI(model=model) - if not isinstance(vetiver_api.model, SKLearnHandler): - pytest.skip("Test only applicable for SKLearnHandler models") - - vetiver_api.vetiver_post("predict_proba") - - client = TestClient(vetiver_api.app) - response = client.post( - "/predict_proba", json=vetiver_api.model.prototype.construct().dict() - ) - assert response.status_code == 200 +def test_vetiver_post_sklearn_predict(model, data): + api = VetiverAPI(model=model) + api.vetiver_post("predict_proba") + + client = TestClient(api.app) + response = predict(endpoint="/predict_proba/", data=data, test_client=client) + + assert isinstance(response, pd.DataFrame) + assert len(response) == 2 + assert response.to_dict() == { + "predict_proba": { + 0: [ + 0.00627480416153554, + 0.9937251958346092, + 3.855256735904704e-12, + ], + 1: [ + 0.00627480416153554, + 0.9937251958346092, + 3.855256735904704e-12, + ], + }, + } def test_vetiver_post_invalid_sklearn_type(model): vetiver_api = VetiverAPI(model=model) - if not isinstance(vetiver_api.model, SKLearnHandler): - pytest.skip("Test only applicable for SKLearnHandler models") with pytest.raises( ValueError, - match="The 'endpoint_fx' parameter can only be a string \ - when using scikit-learn models.", + match="Prediction type invalid_type not available", ): vetiver_api.vetiver_post("invalid_type") From 0d089ab4fb718bbd73dcc8f061818cfaab5ce068 Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Thu, 20 Mar 2025 17:38:31 -0400 Subject: [PATCH 6/9] generalize model calls --- vetiver/handlers/sklearn.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vetiver/handlers/sklearn.py b/vetiver/handlers/sklearn.py index 156c3212..fbac67ae 100644 --- a/vetiver/handlers/sklearn.py +++ b/vetiver/handlers/sklearn.py @@ -46,7 +46,4 @@ def handler_predict(self, input_data, check_prototype: bool, **kw): else input_data ) - if prediction_type in ["predict_proba", "predict_log_proba"]: - return getattr(self.model, prediction_type)(input_data).tolist() - - return self.model.predict(input_data).to_list() + return getattr(self.model, prediction_type)(input_data).tolist() From df66f8cd7563fbad34981be7413917899848591a Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Thu, 20 Mar 2025 17:40:36 -0400 Subject: [PATCH 7/9] seed for mock model --- vetiver/mock.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vetiver/mock.py b/vetiver/mock.py index e18fe8a3..e914d37b 100644 --- a/vetiver/mock.py +++ b/vetiver/mock.py @@ -41,6 +41,7 @@ def get_mtcars_model(): model : sklearn.dummy.DummyRegressor Arbitrary model for testing purposes """ + np.random.seed(500) return LogisticRegression(max_iter=1000).fit( mtcars.drop(columns="cyl"), mtcars["cyl"] ) From 16fdf99d0a880be5fc7b0442572080d261217ace Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Fri, 21 Mar 2025 11:50:06 -0400 Subject: [PATCH 8/9] move seed --- vetiver/mock.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vetiver/mock.py b/vetiver/mock.py index e914d37b..0ed3de41 100644 --- a/vetiver/mock.py +++ b/vetiver/mock.py @@ -41,7 +41,6 @@ def get_mtcars_model(): model : sklearn.dummy.DummyRegressor Arbitrary model for testing purposes """ - np.random.seed(500) - return LogisticRegression(max_iter=1000).fit( + return LogisticRegression(max_iter=1000, random_state=500).fit( mtcars.drop(columns="cyl"), mtcars["cyl"] ) From c50513750c54b9b0bf263a05aca35bf776ee476a Mon Sep 17 00:00:00 2001 From: isabel zimmerman Date: Fri, 21 Mar 2025 12:01:17 -0400 Subject: [PATCH 9/9] allow approx values for differences in version/arch --- vetiver/tests/test_server.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/vetiver/tests/test_server.py b/vetiver/tests/test_server.py index ff40e9b3..6d953787 100644 --- a/vetiver/tests/test_server.py +++ b/vetiver/tests/test_server.py @@ -191,21 +191,26 @@ def test_vetiver_post_sklearn_predict(model, data): assert isinstance(response, pd.DataFrame) assert len(response) == 2 - assert response.to_dict() == { + # Allow for slight differences in architecture or library versions + expected = { "predict_proba": { 0: [ - 0.00627480416153554, - 0.9937251958346092, - 3.855256735904704e-12, + 0.0063, + 0.9937, + 3.59e-12, ], 1: [ - 0.00627480416153554, - 0.9937251958346092, - 3.855256735904704e-12, + 0.0063, + 0.9937, + 3.59e-12, ], }, } + response_dict = response.to_dict() + for key, value in expected["predict_proba"].items(): + assert response_dict["predict_proba"][key] == pytest.approx(value, rel=1e-2) + def test_vetiver_post_invalid_sklearn_type(model): vetiver_api = VetiverAPI(model=model)