Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions src/linalg_backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@ impl<F: 'static + Copy + Send + Sync> LinAlgBackendProvider<F> {
feature = "faer_links_ndarray_static_openblas"
))]
use ndarray::s;
#[cfg(any(
feature = "backend_openblas",
feature = "backend_openblas_system",
feature = "backend_mkl",
feature = "backend_mkl_system",
feature = "faer_links_ndarray_static_openblas"
use ndarray::{Array1, Array2};
#[cfg(all(
not(feature = "backend_faer"),
any(
feature = "backend_openblas",
feature = "backend_openblas_system",
feature = "backend_mkl",
feature = "backend_mkl_system",
feature = "faer_links_ndarray_static_openblas"
)
))]
use ndarray_linalg::Lapack;
use ndarray::{Array1, Array2};
// use num_traits::Float; // No longer needed directly by provider
use std::error::Error;
use std::marker::PhantomData;
Expand Down Expand Up @@ -90,9 +93,11 @@ mod ndarray_backend_impl {
use ndarray_linalg::{Eigh, Lapack, SVDInto, QR, UPLO};
use std::error::Error;

#[cfg_attr(feature = "backend_faer", allow(dead_code))]
#[derive(Debug, Default, Copy, Clone)]
pub struct NdarrayLinAlgBackend;

#[cfg_attr(feature = "backend_faer", allow(dead_code))]
fn to_dyn_error<E: Error + Send + Sync + 'static>(e: E) -> Box<dyn Error + Send + Sync> {
Box::new(e)
}
Expand Down Expand Up @@ -185,12 +190,15 @@ mod ndarray_backend_impl {
pub use NdarrayLinAlgBackend as Backend;
}

#[cfg(any(
feature = "backend_openblas",
feature = "backend_openblas_system",
feature = "backend_mkl",
feature = "backend_mkl_system",
feature = "faer_links_ndarray_static_openblas"
#[cfg(all(
not(feature = "backend_faer"),
any(
feature = "backend_openblas",
feature = "backend_openblas_system",
feature = "backend_mkl",
feature = "backend_mkl_system",
feature = "faer_links_ndarray_static_openblas"
)
))]
use ndarray_backend_impl::Backend as NdarrayLinAlgBackend;

Expand Down
15 changes: 12 additions & 3 deletions tests/pca_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ fn eigenvalues_descending(matrix: &Array2<f64>) -> Vec<f64> {
fn eigenvalues_descending(matrix: &Array2<f64>) -> Vec<f64> {
use ndarray_linalg::{Eigh, UPLO};

let (mut eigenvalues, _) = matrix.eigh(UPLO::Upper).expect("eigendecomposition failed");
let (eigenvalues, _) = matrix.eigh(UPLO::Upper).expect("eigendecomposition failed");
let mut values = eigenvalues.to_vec();
values.sort_by(|a, b| b.partial_cmp(a).expect("eigenvalues must be comparable"));
values
Expand Down Expand Up @@ -1953,8 +1953,17 @@ mod pca_tests {
}

if explained_variance_fit_eff_v161.len() > 0 {
if !approx::abs_diff_eq!(explained_variance_fit_eff_v161[0], 4.0, epsilon = TOLERANCE) {
panic!("Efficient PCA (fit) first explained variance for hardcoded data should be approx 4.0. Got: {}", explained_variance_fit_eff_v161[0]);
let expected_first_ev = n_features as f64;
if !approx::abs_diff_eq!(
explained_variance_fit_eff_v161[0],
expected_first_ev,
epsilon = TOLERANCE
) {
panic!(
"Efficient PCA (fit) first explained variance for hardcoded data should reflect sample-variance scaling (~{}). Got: {}",
expected_first_ev,
explained_variance_fit_eff_v161[0]
);
}
}

Expand Down
Loading