Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions tests/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,44 @@
import sys
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import json # For original_main_logic
import os # For original_main_logic
import scipy.linalg as la # For original_main_logic
import random # For original_main_logic
import json # For original_main_logic
import os # For original_main_logic
import scipy.linalg as la # For original_main_logic
import random # For original_main_logic


NEAR_ZERO_THRESHOLD = 1e-9


def center_and_scale_columns_np(X):
"""Replicates the Rust PCA preprocessing: column-wise centering and scaling
using the unbiased sample standard deviation (n-1 in the denominator) with
small-value sanitization."""

X = np.asarray(X, dtype=float)
if X.size == 0:
return X.copy(), np.array([]), np.array([])

n_samples, n_features = X.shape

mean = np.mean(X, axis=0)
centered = X - mean

if n_samples > 1:
sum_sq = np.sum(centered * centered, axis=0)
variance = np.maximum(sum_sq, 0.0) / float(n_samples - 1)
else:
variance = np.zeros(n_features)

std = np.sqrt(variance, dtype=float)
sanitized_std = np.where(
(~np.isfinite(std)) | (std <= NEAR_ZERO_THRESHOLD),
1.0,
std,
)

scaled = centered / sanitized_std
return scaled, mean, sanitized_std


def print_numpy_array_for_rust(arr):
Expand Down Expand Up @@ -145,8 +178,7 @@ def generate_random_data_original(samples=5, features=5, random_seed=None):
return np.random.randn(samples, features)

def manual_pca_original(X, n_components=None):
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_scaled, _, _ = center_and_scale_columns_np(X)
n_samples, n_features = X_scaled.shape
if n_components is None: n_components = min(n_samples, n_features)
else: n_components = min(n_components, min(n_samples, n_features))
Expand Down Expand Up @@ -181,8 +213,7 @@ def manual_pca_original(X, n_components=None):
return X_transformed, components, eigvals

def library_pca_original(X, n_components=None):
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_scaled, _, _ = center_and_scale_columns_np(X)
n_samples, n_features = X_scaled.shape
max_components = min(n_samples, n_features)
if n_components is None: n_components = max_components
Expand Down Expand Up @@ -288,7 +319,19 @@ def parse_arguments_main():

# Argument for n_components, used by both modes
# For --generate-reference-pca, it's required. For original_main_logic, it's optional.
parser.add_argument("-k", "--n-components", type=int, help="Number of components for PCA.")
# Accept both hyphenated and underscored versions of the flag. The Rust tests
# currently invoke the script with "--n_components", while the original
# command-line interface exposed "--n-components". Argparse treats hyphens
# and underscores as distinct option names, so support both spellings here
# by routing them to the same destination.
parser.add_argument(
"-k",
"--n-components",
"--n_components",
dest="n_components",
type=int,
help="Number of components for PCA.",
)

# Argument to switch to reference generation mode
parser.add_argument("--generate-reference-pca",
Expand Down
Loading