diff --git a/dose_response_models.py b/dose_response_models.py index 151c604..64326a9 100644 --- a/dose_response_models.py +++ b/dose_response_models.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod -from typing import NamedTuple +from typing import Callable, NamedTuple import numpy as np import pandas as pd @@ -16,6 +16,17 @@ from loss_functions import LossFunctions +# Default loss function (DT4 or DNORM) +DEFAULT_LOSS_FN = LossFunctions.DT4 +# Default bidirectional fit +DEFAULT_BID = True +# Default error (1e-32 or 1e-16) +DEFAULT_ERROR = 1e-32 +# Default error bounds +DEFAULT_ERROR_BOUNDS = (-32, 32) +# Default optimization solver +DEFAULT_METHOD = 'Nelder-Mead' + class DoseResponseModel(ABC): """Abstract base class defining dose-response model behavior.""" @@ -28,7 +39,10 @@ def __init__(self): self.best_fit_ = None self.best_params_ = None - DEFAULT_LOSS_FN = LossFunctions.DT4 + class _Param(NamedTuple): + name: str + guess_fn: Callable[[ArrayLike, ArrayLike, bool], float] + bounds_fn: Callable[[ArrayLike, ArrayLike, bool], tuple[float, float]] @property @abstractmethod @@ -37,75 +51,54 @@ def _name(self) -> str: @property @abstractmethod - def _fit_log_x(self) -> bool: + def _is_log_fit(self) -> bool: """Require flag to fit raw or log x for derived classes.""" + @property @abstractmethod - def _model_fn( - self, - tx: ArrayLike, - *args - ) -> ArrayLike: + def _model_params(self) -> list[_Param]: + """Require list of model parameters for derived classes.""" + + @abstractmethod + def _model_fn(self, tx: ArrayLike, *args) -> ArrayLike: """Require model curve-fitting function for derived classes.""" - class ParamGuess(NamedTuple): - """Named tuple representing an initial param guess and bounds.""" - guess: list[float] - bounds: list[tuple[float, float]] + @staticmethod + def _error_guess(y: ArrayLike): + """Initial model-agnostic estimate of error.""" + return ( + np.log(y_mad) + if (y_mad := median_abs_deviation(y, scale='normal')) > 0 + else 
np.log(DEFAULT_ERROR) + ) - @abstractmethod - def _parameterize_initial( - self, - tx: ArrayLike, - y: ArrayLike, - bid: bool - ) -> ParamGuess: - """Require method to guess initial conditions for derived classes.""" + @staticmethod + def _meds(tx: ArrayLike, y: ArrayLike, bid: bool): + """Calculate median response at each dose if multiple samples.""" + meds = pd.DataFrame({'y': y, 'tx': tx}).groupby('tx')['y'].median() + return abs(meds) if bid else meds # Absolute value if bidirectional fit def _transform_x(self, x: ArrayLike) -> ArrayLike: """Calculate log x if model is fitted in log space.""" - return x if not self._fit_log_x else np.log10(x) + return x if not self._is_log_fit else np.log10(x) - def _transform_err(self, err: float) -> float: - """Calculate exp err if model is fitted in log space.""" - return err if not self._fit_log_x else np.exp(err) - - def _obj_fn( - self, - params: ArrayLike, - tx: ArrayLike, - y: ArrayLike, - loss_fn: LossFunctions - ) -> float: - """Compute the objective function used to optimize model fitting. - - Args: - params: model parameters, including error term - tx: transformed dose data - y: response data - loss_fn: loss function - Returns: - objective function value - """ - return -loss_fn( - y, - self._model_fn(tx, *params[:-1]), - self._transform_err(params[-1]) - ) + def _transform_error(self, err: float) -> float: + """Calculate exp error if model is fitted in log space.""" + return err if not self._is_log_fit else np.exp(err) def fit( self, x: ArrayLike, y: ArrayLike, loss_fn: LossFunctions = DEFAULT_LOSS_FN, - bid: bool = True + bid: bool = DEFAULT_BID ): """Fit the model function to the provided dose-response data. 
Args: x: untransformed dose data y: response data - loss_fn: loss function (defaults dt4) + loss_fn: loss function (default dt4) bid: bidirectional fit (default true) Returns: fitted model object @@ -113,16 +106,29 @@ def fit( # Perform log transformation of data if needed tx = self._transform_x(x) - # Guess initial conditions and bounds - ic = self._parameterize_initial(tx, y, bid) + + # Define objective function + def obj_fn(params): + return -loss_fn( + y, + self._model_fn(tx, *params[:-1]), + self._transform_error(params[-1]) + ) + + # Guess initial conditions and bounds for model parameters, + # appending model-agnostic defaults for error parameter + x0 = [p.guess_fn(tx, y, bid) for p in self._model_params]\ + + [DoseResponseModel._error_guess(y)] + bounds = [p.bounds_fn(tx, y, bid) for p in self._model_params]\ + + [DEFAULT_ERROR_BOUNDS] # Perform optimization fit = minimize( - fun=self._obj_fn, - x0=ic.guess, - args=(tx, y, loss_fn), - bounds=ic.bounds, - method='L-BFGS-B' + fun=obj_fn, + x0=x0, + bounds=bounds, + method=DEFAULT_METHOD, + options={'disp': True} ) # Extract the fit information @@ -135,10 +141,7 @@ def fit( return self # Permit chaining with predict() - def predict( - self, - x: ArrayLike - ) -> ArrayLike: + def predict(self, x: ArrayLike) -> ArrayLike: """Use fitted model to perform prediction for new dose data. Args: @@ -155,7 +158,7 @@ def fit_predict( x: ArrayLike, y: ArrayLike, loss_fn: LossFunctions = DEFAULT_LOSS_FN, - bid: bool = True + bid: bool = DEFAULT_BID ) -> ArrayLike: """Fit the model then predict from the same data. 
class LogHillModel(DoseResponseModel):
    """Log-space Hill equation model fitting function.

    Fit in log10(dose) space (``_is_log_fit = True``).

    Parameters:
        tp: theoretical maximal response (top)
        ga: gain AC50 (log10 dose of half-maximal response)
        p: gain power (Hill slope)
    """

    _name = 'loghill'
    _is_log_fit = True

    @staticmethod
    def _tp_bounds_fn(tx, y, bid):
        """Bounds for tp: symmetric about zero when the fit is bidirectional,
        otherwise restricted to non-negative responses."""
        if bid:
            tp_max = 1.2 * max(abs(min(y)), abs(max(y)))
            tp_min = -tp_max
        else:
            tp_min = 0
            tp_max = 1.2 * max(y)
        return tp_min, tp_max

    # NOTE(review): referencing a staticmethod object (_tp_bounds_fn) inside
    # the class body and calling it later requires Python >= 3.10, where
    # staticmethod objects became directly callable — confirm target version.
    _model_params = [
        DoseResponseModel._Param(
            'tp',
            # Guess: largest median response across doses (absolute if bid)
            lambda tx, y, bid: DoseResponseModel._meds(tx, y, bid).max(),
            _tp_bounds_fn
        ),
        DoseResponseModel._Param(
            'ga',
            # Guess: dose of maximal median response, shifted half a log unit
            (lambda tx, y, bid:
                DoseResponseModel._meds(tx, y, bid).idxmax() - 0.5),
            # Bounds: one log unit below the lowest dose to half above highest
            lambda tx, y, bid: (min(tx) - 1, max(tx) + 0.5),
        ),
        DoseResponseModel._Param(
            'p',
            lambda tx, y, bid: 1.2,
            lambda tx, y, bid: (0.3, 8)
        )
    ]

    def _model_fn(self, tx, *params):
        """Hill curve: tp / (1 + 10 ** (p * (ga - tx)))."""
        return params[0] / (1 + 10 ** (params[2] * (params[1] - tx)))


class Poly1Model(DoseResponseModel):
    """Degree-1 polynomial (linear) model fitting function.

    Fit in raw dose space (``_is_log_fit = False``).

    Parameters:
        a: y-scale (slope)
    """

    _name = 'poly1'
    _is_log_fit = False

    @staticmethod
    def _max_slope(tx, y, bid):
        """Slope guess: largest median response divided by the largest dose."""
        meds = DoseResponseModel._meds(tx, y, bid)
        return meds.max() / max(tx)

    @staticmethod
    def _a_bounds_fn(tx, y, bid):
        """Bounds for a: wide multiple of the slope guess; signed only when
        the fit is bidirectional."""
        val = 1e8 * abs(Poly1Model._max_slope(tx, y, bid))
        return (-val, val) if bid else (0, val)

    _model_params = [DoseResponseModel._Param('a', _max_slope, _a_bounds_fn)]

    def _model_fn(self, tx, *params):
        """Line through the origin: a * x."""
        return params[0] * tx


class PowModel(DoseResponseModel):
    """Power model fitting function.

    Fit in raw dose space (``_is_log_fit = False``).

    Parameters:
        a: y-scale
        p: power
    """

    _name = 'pow'
    _is_log_fit = False

    # Removed unused PowModel._max_slope: it was never referenced — the 'a'
    # guess and bounds below both derive directly from the dose medians.

    @staticmethod
    def _a_bounds_fn(tx, y, bid):
        """Bounds for a: wide multiple of the peak median response; a small
        positive lower bound keeps a strictly positive when unidirectional."""
        meds_max_abs = abs(DoseResponseModel._meds(tx, y, bid).max())
        val = 1e8 * meds_max_abs
        return (-val, val) if bid else (1e-8 * meds_max_abs, val)

    _model_params = [
        DoseResponseModel._Param(
            'a',
            # Guess: largest median response across doses (absolute if bid)
            lambda tx, y, bid: DoseResponseModel._meds(tx, y, bid).max(),
            _a_bounds_fn
        ),
        DoseResponseModel._Param(
            'p',
            lambda tx, y, bid: 1.5,
            lambda tx, y, bid: (-20, 20)
        )
    ]

    def _model_fn(self, tx, *params):
        """Power curve: a * x ** p."""
        return params[0] * tx ** params[1]
float]: """Generic log loss function using any input log PDF function.""" - return lambda o, p, e: np.sum(pdf((o - p) / e, **kwargs) - np.log(e)) + return lambda o, p, e: LossFunctions._base(o, p, e, pdf, **kwargs) # t-distributed log error with 4 DoF DT4 = member(staticmethod(_loss_fn(t.logpdf, df=4))) diff --git a/scratch.py b/scratch.py index 49435cc..762fe81 100644 --- a/scratch.py +++ b/scratch.py @@ -3,16 +3,17 @@ import numpy as np import matplotlib.pyplot as plt -from dose_response_models import LogHillModel -# from loss_functions import LossFunctions +from dose_response_models import LogHillModel, Poly1Model, PowModel # Mock data conc = np.array([0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]) -resp = np.array([0, 0, 0.1, 0.2, 0.5, 1.0, 1.5, 2.0]) +hill_resp = np.array([0, 0, .1, .2, .5, 1, 1.5, 2]) +poly1_resp = np.array([0, .01, .1, .1, .2, .5, 2, 5]) +pow_resp = np.array([0, .01, .1, .1, .2, .5, 2, 8]) # Initialize and fit the model -model = LogHillModel() -model.fit(conc, resp, bid=False) +model = PowModel() +model.fit(conc, pow_resp) if model.success_: # Output results @@ -30,7 +31,7 @@ # Generate plot fig, ax = plt.subplots() - ax.scatter(conc, resp, label='Observed', color='black') + ax.scatter(conc, pow_resp, label='Observed', color='black') ax.plot(conc_fine, pred, label='Fit', color='blue') # Add AC50 line ax.axvline(10 ** model.best_params_[1], linestyle='--', c='red')