12 changes: 10 additions & 2 deletions Makefile
@@ -32,7 +32,7 @@ all: install
install: clean venv ftio_venv msg

# Installs with external dependencies
full: clean venv ftio_venv_full msg
full: venv ftio_venv_full msg

# Installs debug version external dependencies
debug: venv ftio_debug_venv msg
@@ -54,7 +54,8 @@ ftio:
$(PYTHON) -m pip install .

ftio_full:
$(PYTHON) -m pip install '.[external-libs,development-libs,plot-libs]'
$(PYTHON) -m pip install -e '.[external-libs,development-libs,plot-libs,ml-libs]' --no-cache-dir || \
(echo "Installing external libs failed, trying fallback..." && $(PYTHON) -m pip install -e . --no-cache-dir)
venv:
$(PYTHON) -m venv .venv
@echo -e "Environment created. Using python from .venv/bin/python3"
@@ -114,6 +115,13 @@ test_all:
test:
cd test && python3 -m pytest && make clean

test_parallel:
@python3 -m pip show pytest-xdist > /dev/null 2>&1 || python3 -m pip install pytest-xdist
cd test && python3 -m pytest -n 4 && make clean

test_failed:
cd test && python3 -m pytest --ff && make clean

check_style: check_tools
black .
ruff check --fix
60 changes: 60 additions & 0 deletions docs/ml_models.md
@@ -0,0 +1,60 @@
# Machine-Learning Models Documentation

## Prerequisites

Install the required packages: `pip install '.[external-libs,development-libs,ml-libs]'`

## General Usage

### Hybrid Model

The following example shows the high-level entry point for training and forecasting with the hybrid model via its
train_hybrid_model() function.

```python
import os

# train_hybrid_model and predict_next_sequence are provided by the hybrid-model module
file = os.path.join(os.path.dirname(__file__), "../examples/tmio/JSONL/8.jsonl")
model = train_hybrid_model(file, epochs=10, lr=0.003)
prediction = predict_next_sequence(model, file)
```

train_hybrid_model() also exposes parameters with default values that control the underlying model structure.
In the following example only the embedding dimension is changed, but there are also parameters for the number of
attention heads, the feed-forward dimension, etc.
Powers of two are usually the most effective values to explore.

```python
file = os.path.join(os.path.dirname(__file__), "../examples/tmio/JSONL/8.jsonl")
model = train_hybrid_model(file, epochs=10, lr=0.003, emb_dim=256)
prediction = predict_next_sequence(model, file)
```
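
A small sweep over power-of-two embedding dimensions could look like the following sketch; the chosen emb_dim values
and the simple printout are illustrative assumptions, not part of the documented API:

```python
import os

# Hypothetical sweep over power-of-two embedding dimensions; judging which
# prediction is best is left to the user here.
file = os.path.join(os.path.dirname(__file__), "../examples/tmio/JSONL/8.jsonl")
for emb_dim in (64, 128, 256):
    model = train_hybrid_model(file, epochs=10, lr=0.003, emb_dim=emb_dim)
    prediction = predict_next_sequence(model, file)
    print(emb_dim, prediction)
```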

Training of the hybrid model can be resumed by loading a `.pth` file created by the saving process; it contains the
model parameters and the state of the optimizer.

```python
file = os.path.join(os.path.dirname(__file__), "../examples/tmio/JSONL/8.jsonl")
model = train_hybrid_model(file, epochs=10, lr=0.003, save=True)
model = train_hybrid_model(
file,
epochs=10,
lr=0.003,
load_state_dict_and_optimizer_state="model_and_optimizer.pth",
)
prediction = predict_next_sequence(model, file)
```
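
Internally, such a checkpoint is presumably an ordinary PyTorch file holding both state dicts. A minimal sketch of
saving and restoring it by hand is shown below; it assumes `model` is a `torch.nn.Module` and `optimizer` a
`torch.optim` optimizer, and the key names are an assumption that may differ from FTIO's actual saving routine:

```python
import torch

# Assumed checkpoint layout: a single dict with model and optimizer state dicts.
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "model_and_optimizer.pth")

# Restoring by hand:
checkpoint = torch.load("model_and_optimizer.pth")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```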

### (S)ARIMA

The following example shows the high-level entry point for training and forecasting with the ARIMA/SARIMA models via
the train_arima() function.
The model_architecture parameter selects between ARIMA and SARIMA. max_depth should be kept relatively small, since it
defines the maximum order of differencing applied to the underlying data to reach stationarity.
Resuming training is inherently not supported by the underlying model structure; if new data becomes available,
training from scratch is the only option.

```python
file = os.path.join(os.path.dirname(__file__), "../examples/tmio/JSONL/8.jsonl")
prediction = train_arima(file, max_depth=3, model_architecture="ARIMA")
```
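
To illustrate what max_depth bounds, the sketch below repeatedly differences a series until an augmented Dickey-Fuller
test reports stationarity, stopping after max_depth steps. This mirrors the idea only; it uses statsmodels rather than
FTIO's internal implementation:

```python
import numpy as np
from statsmodels.tsa.stattools import adfuller

def differencing_order(series, max_depth=3, alpha=0.05):
    """Return the number of differencing steps (at most max_depth) needed for stationarity."""
    current = np.asarray(series, dtype=float)
    for d in range(max_depth + 1):
        # adfuller returns (statistic, p-value, ...); p-value < alpha -> stationary
        if adfuller(current)[1] < alpha:
            return d
        current = np.diff(current)
    return max_depth
```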
34 changes: 12 additions & 22 deletions ftio/cli/ftio_core.py
@@ -26,7 +26,6 @@

from ftio.freq._analysis_figures import AnalysisFigures
from ftio.freq._dft_workflow import ftio_dft
from ftio.freq._share_signal_data import SharedSignalData
from ftio.freq._wavelet_cont_workflow import ftio_wavelet_cont
from ftio.freq._wavelet_disc_workflow import ftio_wavelet_disc
from ftio.freq.autocorrelation import find_autocorrelation
@@ -119,9 +118,11 @@ def core(sim: dict, args: Namespace) -> tuple[Prediction, AnalysisFigures]:
return Prediction(), AnalysisFigures()

# Perform frequency analysis (dft/wavelet)
prediction_freq_analysis, analysis_figures, share = freq_analysis(args, sim)
prediction_freq_analysis, analysis_figures = freq_analysis(args, sim)
# Perform autocorrelation if args.autocorrelation is true + Merge the results into a single prediction
prediction_auto = find_autocorrelation(args, sim, analysis_figures, share)
prediction_auto = find_autocorrelation(
args, sim, analysis_figures, prediction_freq_analysis
)
# Merge results
prediction = merge_predictions(
args, prediction_freq_analysis, prediction_auto, analysis_figures
@@ -130,9 +131,7 @@ def core(sim: dict, args: Namespace) -> tuple[Prediction, AnalysisFigures]:
return prediction, analysis_figures


def freq_analysis(
args: Namespace, data: dict
) -> tuple[Prediction, AnalysisFigures, SharedSignalData]:
def freq_analysis(args: Namespace, data: dict) -> tuple[Prediction, AnalysisFigures]:
"""
Performs frequency analysis (DFT, continuous wavelet, or discrete wavelet) and prepares data for plotting.

@@ -153,20 +152,13 @@ def freq_analysis(

Returns:
tuple: A tuple containing:
- Prediction: Contains the prediction results, including:
- Prediction: Contains the prediction results, including
- "dominant_freq" (list): The identified dominant frequencies.
- "conf" (np.ndarray): Confidence values corresponding to the dominant frequencies.
- "t_start" (int): Start time of the analysis.
- "t_end" (int): End time of the analysis.
- "total_bytes" (int): Total bytes involved in the analysis.
- AnalysisFigures
- SharedSignalData: Contains sampled data used for sharing (e.g., autocorrelation) containing
the following fields:
- "b_sampled" (np.ndarray): The sampled bandwidth data.
- "freq" (np.ndarray): Frequencies corresponding to the sampled data.
- "t_start" (int): Start time of the sampled data.
- "t_end" (int): End time of the sampled data.
- "total_bytes" (int): Total bytes from the sampled data.
"""

#! Init
@@ -182,19 +174,17 @@ def freq_analysis(

#! Perform transformation
if "dft" in args.transformation:
prediction, analysis_figures, share = ftio_dft(
prediction, analysis_figures = ftio_dft(
args, bandwidth, time_b, total_bytes, ranks, text
)

elif "wave_disc" in args.transformation:
prediction, analysis_figures, share = ftio_wavelet_disc(
prediction, analysis_figures = ftio_wavelet_disc(
args, bandwidth, time_b, ranks, total_bytes
)

elif "wave_cont" in args.transformation:
prediction, analysis_figures, share = ftio_wavelet_cont(
args, bandwidth, time_b, ranks
)
prediction, analysis_figures = ftio_wavelet_cont(args, bandwidth, time_b, ranks)

elif any(t in args.transformation for t in ("astft", "efd", "vmd")):
# TODO: add a way to pass the results to FTIO
@@ -211,7 +201,7 @@ def freq_analysis(args: Namespace, data: dict) -> tuple[Prediction, AnalysisFigures]:

from ftio.freq._astft_workflow import ftio_astft

prediction, analysis_figures, share = ftio_astft(
prediction, analysis_figures = ftio_astft(
args, bandwidth, time_b, total_bytes, ranks, text
)
sys.exit()
@@ -221,15 +211,15 @@ def freq_analysis(args: Namespace, data: dict) -> tuple[Prediction, AnalysisFigures]:

from ftio.freq._amd_workflow import ftio_amd

prediction, analysis_figures, share = ftio_amd(
prediction, analysis_figures = ftio_amd(
args, bandwidth, time_b, total_bytes, ranks, text
)
sys.exit()

else:
raise Exception("Unsupported decomposition specified")

return prediction, analysis_figures, share
return prediction, analysis_figures


def run():
11 changes: 4 additions & 7 deletions ftio/freq/_dft_workflow.py
@@ -21,7 +21,6 @@
from ftio.freq._dft import dft
from ftio.freq._filter import filter_signal
from ftio.freq._fourier_fit import fourier_fit
from ftio.freq._share_signal_data import SharedSignalData
from ftio.freq.discretize import sample_data
from ftio.freq.helper import MyConsole
from ftio.freq.prediction import Prediction
@@ -35,7 +34,7 @@ def ftio_dft(
total_bytes: int = 0,
ranks: int = 1,
text: str = "",
) -> tuple[Prediction, AnalysisFigures, SharedSignalData]:
) -> tuple[Prediction, AnalysisFigures]:
"""
Performs a Discrete Fourier Transform (DFT) on the sampled bandwidth data, finds the dominant frequency, followed by outlier
detection to spot the dominant frequency. This function also prepares the necessary outputs for plotting or reporting.
@@ -52,10 +51,8 @@ tuple:
tuple:
- prediction (Prediction): Contains prediction results including dominant frequency, confidence, amplitude, etc.
- analysis_figures (AnalysisFigures): Data and plot figures.
- share (SharedSignalData): Contains shared information, including sampled bandwidth and total bytes.
"""
#! Default values for variables
share = SharedSignalData()
prediction = Prediction(args.transformation)
analysis_figures = AnalysisFigures(args)
console = MyConsole(verbose=args.verbose)
@@ -156,8 +153,8 @@ def ftio_dft(
plot_dft(args, prediction, analysis_figures)
console.print(" --- Done --- \n")

if args.autocorrelation:
share.set_data_from_predicition(b_sampled, prediction)
if args.autocorrelation or args.machine_learning:
prediction.b_sampled = b_sampled

precision_text = ""
# precision_text = precision_dft(amp, phi, dominant_index, b_sampled, t_sampled, frequencies, args.engine)
@@ -175,4 +172,4 @@ def ftio_dft(
console.print(
f"\n[cyan]{args.transformation.upper()} + {args.outlier} finished:[/] {time.time() - tik:.3f} s"
)
return prediction, analysis_figures, share
return prediction, analysis_figures
115 changes: 0 additions & 115 deletions ftio/freq/_share_signal_data.py

This file was deleted.

6 changes: 2 additions & 4 deletions ftio/freq/_wavelet_cont_workflow.py
@@ -17,7 +17,6 @@

from ftio.freq._analysis_figures import AnalysisFigures
from ftio.freq._dft_workflow import ftio_dft
from ftio.freq._share_signal_data import SharedSignalData
from ftio.freq._wavelet import wavelet_cont
from ftio.freq._wavelet_helpers import get_scales
from ftio.freq.discretize import sample_data
@@ -49,7 +48,6 @@ def ftio_wavelet_cont(
ranks (int): The rank value (default is 0).
"""
#! Default values for variables
share = SharedSignalData()
prediction = Prediction(args.transformation)
console = MyConsole(verbose=args.verbose)

@@ -90,7 +88,7 @@ def ftio_wavelet_cont(
use_dominant_only = False
scales = []
t_sampled = time_stamps[0] + np.arange(0, len(b_sampled)) * 1 / args.freq
prediction, analysis_figures, share = ftio_dft(args, b_sampled, t_sampled)
prediction, analysis_figures = ftio_dft(args, b_sampled, t_sampled)
dominant_freq, _ = get_dominant_and_conf(prediction)

# Adjust wavelet
@@ -237,4 +235,4 @@ def ftio_wavelet_cont(
f"\n[cyan]{args.transformation.upper()} + {args.outlier} finished:[/] {time.time() - tik:.3f} s"
)

return prediction, analysis_figures, share
return prediction, analysis_figures