diff --git a/lime/lime_tabular.py b/lime/lime_tabular.py index 880f3d391..b9752d15d 100644 --- a/lime/lime_tabular.py +++ b/lime/lime_tabular.py @@ -305,7 +305,8 @@ def explain_instance(self, num_samples=5000, distance_metric='euclidean', model_regressor=None, - sampling_method='gaussian'): + sampling_method='gaussian', + predict_fn_accept_dense_only=False): """Generates explanations for a prediction. First, we generate neighborhood data by randomly perturbing features @@ -358,8 +359,17 @@ def explain_instance(self, metric=distance_metric ).ravel() + is_convert_to_dense = False + if sp.sparse.isspmatrix_csr(inverse) and predict_fn_accept_dense_only: + inverse = inverse.toarray() + is_convert_to_dense = True + yss = predict_fn(inverse) + if is_convert_to_dense: + inverse = sp.sparse.csr_matrix(inverse) + is_convert_to_dense = False + # for classification, the model needs to provide a list of tuples - classes # along with prediction probabilities if self.mode == "classification":