# Source code for tseda.decomposition.ssa_decomposition

"""Singular Spectrum Analysis decomposition and reconstruction utilities."""

from __future__ import annotations

from ssalib import SingularSpectrumAnalysis
from matplotlib import pyplot as plt
import ruptures as rpt
from statsmodels.stats.stattools import durbin_watson
from statsmodels.nonparametric.smoothers_lowess import lowess
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

from tseda.decomposition.automatic_grouping_heuristic import AutomaticGroupingHeuristic
from tseda.config.config_loader import ConfigurationManager


class SSADecomposition:
    """Perform Singular Spectrum Analysis (SSA) on a time series.

    The instance decomposes the series on construction, exposes the explained
    variance of each component, and lazily caches the reconstructed series
    (raw, smoothed, noise, and one per named group) for the grouping most
    recently passed to :meth:`set_reconstruction`.
    """

    # Shared error text for every method that needs a reconstruction map.
    _MISSING_RECONSTRUCTION_MESSAGE = (
        "Reconstruction map is not set. Call set_reconstruction() first."
    )

    def __init__(self, series: pd.Series, window: int) -> None:
        """Initialize SSA decomposition for a timestamp-indexed series.

        Args:
            series: Input numeric series indexed by datetime.
            window: Window size for SSA trajectory matrix construction.
        """
        # Number of leading eigenvalues inspected by the seasonality heuristic.
        self.SEASONALITY_HEURISTIC_LEADING_EIGENVALUES = ConfigurationManager.get(
            "seasonality_heuristic.leading_eigenvalues_to_check", 6
        )
        self._series = series
        self._window = window
        self._ssa = SingularSpectrumAnalysis(series, window=window)
        self._ssa.decompose()
        self._eigenvalues = self._ssa.eigenvalues
        self._total_variance = np.sum(self._eigenvalues)

        component_count = self._eigenvalues.shape[0]
        # Share of total variance carried by each individual component.
        self._exp_var = {
            f"var_comp-{i}": (self._eigenvalues[i] / self._total_variance).item()
            for i in range(component_count)
        }
        # Cumulative share; the prefix sum is recomputed with np.sum so the
        # last entry divides the total by itself.
        self._cum_var = {
            f"var_comp-{i}": (
                np.sum(self._eigenvalues[: i + 1]) / self._total_variance
            ).item()
            for i in range(component_count)
        }
        self._seasonality_check_heuristic = False
        self._reset_reconstruction_cache()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _require_reconstruction(self) -> None:
        """Raise ValueError unless set_reconstruction() has been called."""
        if not hasattr(self, "_recon"):
            raise ValueError(self._MISSING_RECONSTRUCTION_MESSAGE)

    def _noise_validation_bounds(self) -> tuple[float, float]:
        """Return the configured (low, high) Durbin-Watson acceptance bounds."""
        dw_low = ConfigurationManager.get("noise_validation.dw_low", 1.5)
        dw_high = ConfigurationManager.get("noise_validation.dw_high", 2.5)
        return dw_low, dw_high

    def _first_variation_containing(self, token: str) -> float:
        """Return the variance (%) of the first group whose name contains *token*.

        The comparison is case-insensitive; 0.0 is returned when no group matches.

        Raises:
            ValueError: If no reconstruction map has been set.
        """
        self._require_reconstruction()
        self._ensure_reconstruction_cache()
        for name, value in self._variation_by_group.items():
            if token in name.casefold():
                return value
        return 0.0

    # ------------------------------------------------------------------
    # Grouping heuristics
    # ------------------------------------------------------------------
    def seasonality_check_heuristic(self) -> bool:
        """Run the heuristic seasonality check on the leading eigenvalues.

        The decision is delegated to ``AutomaticGroupingHeuristic.has_seasonal_pair``
        over at most ``SEASONALITY_HEURISTIC_LEADING_EIGENVALUES`` components.
        The original documentation describes the rule as "any pair of leading
        eigenvalues with a smaller/larger ratio of at least 0.95"; the exact
        rule lives in the heuristic class.

        Returns:
            The heuristic result, which is also stored on the instance.
        """
        heuristic = self.get_automatic_grouping_heuristic()
        max_components = min(
            self.SEASONALITY_HEURISTIC_LEADING_EIGENVALUES,
            len(heuristic.eigenvalues),
        )
        self._seasonality_check_heuristic = heuristic.has_seasonal_pair(
            max_components=max_components
        )
        return self._seasonality_check_heuristic

    def get_automatic_grouping_heuristic(
        self,
        grouping_config: dict[str, int | float | str] | None = None,
    ) -> AutomaticGroupingHeuristic:
        """Return the automatic grouping heuristic for the current eigen spectrum.

        Each setting is taken from ``grouping_config`` when present, otherwise
        from the ``grouping_heuristic.*`` configuration, otherwise from the
        built-in default.

        Args:
            grouping_config: Optional per-call overrides for
                ``grouping_heuristic.*`` settings.

        Returns:
            A heuristic built from this decomposition's eigenvalues.
        """
        overrides = grouping_config or {}

        def _setting(name: str, default: int | float | str) -> int | float | str:
            configured = ConfigurationManager.get(f"grouping_heuristic.{name}", default)
            return overrides.get(name, configured)

        # Numeric settings are coerced so string overrides (permitted by the
        # parameter type) behave the same as numeric ones.
        return AutomaticGroupingHeuristic(
            eigenvalues=np.asarray(self._eigenvalues, dtype=float),
            pool_selection_method=str(_setting("pool_selection_method", "kneedle")),
            variance_threshold=float(_setting("variance_threshold", 0.10)),
            pair_similarity_tolerance=float(_setting("pair_similarity_tolerance", 0.05)),
            kneedle_min_distance=float(_setting("kneedle_min_distance", 0.03)),
            min_signal_components=int(_setting("min_signal_components", 1)),
            min_noise_components=int(_setting("min_noise_components", 2)),
        )

    def suggest_reconstruction_groups(
        self,
        grouping_config: dict[str, int | float | str] | None = None,
    ) -> tuple[dict[str, list[int]], bool]:
        """Return the best auto-inferred grouping and a Durbin-Watson satisfied flag.

        Starting from the heuristic's initial assignment, the assignment is
        expanded one step at a time (via ``suggest_next_expansion``) until the
        Durbin-Watson statistic of the noise residual falls in
        ``[dw_low, dw_high]`` or no further expansion is offered.

        * If an assignment lands in range, that assignment is returned with
          ``True``.
        * Otherwise the assignment whose DW value was closest to 2.0 is
          returned with ``False``, signalling the caller to prompt the user to
          try a different window.

        Side effect: the SSA reconstruction state is left set to the returned
        assignment, provided that assignment has a non-empty ``"Noise"`` group
        (assignments without one are never applied).

        Args:
            grouping_config: Optional per-call overrides for
                ``grouping_heuristic.*`` settings.

        Returns:
            ``(assignment, dw_in_range)``.
        """
        dw_low, dw_high = self._noise_validation_bounds()
        heuristic = self.get_automatic_grouping_heuristic(grouping_config=grouping_config)

        assignment = heuristic.suggest_reconstruction()
        dw = self._compute_dw_for_assignment(assignment)
        if dw is not None and dw_low <= dw <= dw_high:
            return assignment, True

        best_assignment = assignment
        best_distance = abs(dw - 2.0) if dw is not None else float("inf")

        while True:
            expanded, did_expand = heuristic.suggest_next_expansion(assignment)
            if not did_expand:
                break
            assignment = expanded
            dw = self._compute_dw_for_assignment(assignment)
            if dw is None:
                continue
            if dw_low <= dw <= dw_high:
                # Return the assignment that is actually in range and currently
                # applied. With asymmetric bounds an earlier out-of-range
                # assignment can sit closer to 2.0, so the "closest" one must
                # not be returned here.
                return assignment, True
            distance = abs(dw - 2.0)
            if distance < best_distance:
                best_assignment = assignment
                best_distance = distance

        # The last evaluated assignment is the one currently applied; re-apply
        # the best one so the reconstruction state matches the return value.
        if best_assignment is not assignment:
            self._compute_dw_for_assignment(best_assignment)
        return best_assignment, False

    def _compute_dw_for_assignment(self, assignment: dict[str, list[int]]) -> float | None:
        """Apply *assignment* and return the Durbin-Watson statistic of its noise.

        Empty groups are dropped before the assignment is applied.

        Returns:
            The statistic, or None when the ``"Noise"`` group is absent or
            empty (in which case the reconstruction state is left untouched).
        """
        if not assignment.get("Noise"):
            return None
        non_empty = {name: indices for name, indices in assignment.items() if indices}
        self.set_reconstruction(non_empty)
        self._ensure_reconstruction_cache()
        return self._durbin_watson

    # ------------------------------------------------------------------
    # Reconstruction cache
    # ------------------------------------------------------------------
    def _reset_reconstruction_cache(self) -> None:
        """Clear cached reconstruction products so they are rebuilt from the latest grouping."""
        self._raw_signal = None
        self._smoothed_signal = None
        self._noise_signal = None
        self._group_signals = {}
        self._durbin_watson = None
        self._variation_by_group = {}

    def _build_reconstruction_cache(self) -> None:
        """Build cached reconstructed series for raw, smoothed, noise, and each named group.

        Groups named "noise" (case-insensitive) are treated as noise; every
        other group is a signal group. The smoothed signal is the sum of all
        signal groups, and the Durbin-Watson statistic is computed on the first
        noise group.

        Raises:
            ValueError: If no reconstruction map has been set.
        """
        self._require_reconstruction()

        signal_keys = [key for key in self._recon if key.casefold() != "noise"]
        noise_keys = [key for key in self._recon if key.casefold() == "noise"]

        # Explained variance per group, as a percentage.
        self._variation_by_group = {
            group_name: sum(
                (self._exp_var.get(f"var_comp-{index}", 0.0) for index in component_indices),
                0.0,
            )
            * 100
            for group_name, component_indices in self._recon.items()
        }

        ssa_frame = self._ssa.to_frame()
        self._raw_signal = ssa_frame["ssa_preprocessed"].copy()
        # Group series are keyed by casefolded name for case-insensitive lookup.
        self._group_signals = {
            group_name.casefold(): ssa_frame[group_name].copy()
            for group_name in signal_keys
        }
        if signal_keys:
            self._smoothed_signal = ssa_frame[signal_keys].sum(axis=1).copy()
        else:
            self._smoothed_signal = None

        if noise_keys:
            self._noise_signal = ssa_frame[noise_keys[0]].copy()
            self._durbin_watson = durbin_watson(self._noise_signal)
        else:
            self._noise_signal = None
            self._durbin_watson = None

    def _ensure_reconstruction_cache(self) -> None:
        """Build the reconstruction cache on first access."""
        if self._raw_signal is None:
            self._build_reconstruction_cache()

    def get_reconstructed_series(self, group_key: str) -> pd.Series | None:
        """Return a cached reconstructed series by case-insensitive group key.

        The reserved keys ``raw``/``raw_signal``/``raw signal``,
        ``smoothed``/``smoothed_signal``/``smoothed signal`` and ``noise`` take
        precedence over user-defined group names.

        Args:
            group_key: Requested group label.

        Returns:
            Reconstructed series for the key, or None if unavailable.

        Raises:
            ValueError: If the cache must be built and no reconstruction map is set.
        """
        self._ensure_reconstruction_cache()
        normalized_key = group_key.strip().casefold()
        if normalized_key in {"raw", "raw_signal", "raw signal"}:
            return self._raw_signal
        if normalized_key in {"smoothed", "smoothed_signal", "smoothed signal"}:
            return self._smoothed_signal
        if normalized_key == "noise":
            return self._noise_signal
        return self._group_signals.get(normalized_key)

    def get_group_series(self, group_key: str) -> pd.Series | None:
        """Alias for get_reconstructed_series for external callers."""
        return self.get_reconstructed_series(group_key)

    def set_reconstruction(self, recon: dict[str, list[int]]) -> None:
        """Set decomposition groups used for signal reconstruction.

        Any cached reconstruction products are discarded.

        Args:
            recon: Mapping of group names to component index lists.
        """
        self._reset_reconstruction_cache()
        self._recon = recon
        self._ssa.reconstruct(recon)

    # ------------------------------------------------------------------
    # Plots
    # ------------------------------------------------------------------
    def eigenplot(self) -> go.Figure:
        """Create the explained-variance-by-component line plot.

        Returns:
            Plotly figure with component explained variance.
        """
        df_eig = pd.DataFrame.from_dict(self._exp_var, orient="index").reset_index()
        df_eig.columns = ["component", "explained_variance"]

        fig = px.line(df_eig, x="component", y="explained_variance", markers=True)
        fig.update_layout(
            title={
                "text": "Eigen Value Variance",
                "y": 0.9,  # vertical position in paper coordinates (0-1)
                "x": 0.5,  # horizontal position in paper coordinates (0-1)
                "xanchor": "center",
                "yanchor": "top",
                "font": dict(size=20),
            }
        )
        return fig

    def eigen_vector_plot(self) -> plt.Figure:
        """Return the SSA eigenvector matplotlib figure."""
        fig, _axes = self._ssa.plot(kind="vectors")
        return fig

    def wcorr_plot(self) -> plt.Figure:
        """Create a plot of the w-correlation matrix."""
        fig, ax = self._ssa.plot(kind="wcorr", n_components=self._window)
        ax.set_xlabel("Component Index")
        ax.set_ylabel("Component Index")
        # Label the colorbar attached to the first collection on the axes.
        cbar = ax.collections[0].colorbar
        cbar.set_label("Weighted Correlation Values")
        return fig

    def signal_reconstruction_plot(self) -> go.Figure:
        """Return a signal reconstruction plot for raw, smoothed, and grouped series.

        This call also builds the reconstruction cache if needed and updates
        the seasonality heuristic flag. The smoothed trace is omitted when the
        reconstruction map defines no non-noise group.

        Raises:
            ValueError: If no reconstruction map has been set.
        """
        self._require_reconstruction()
        self.seasonality_check_heuristic()
        self._ensure_reconstruction_cache()

        signal_keys = [key for key in self._recon if key.casefold() != "noise"]
        raw_signal = self.get_reconstructed_series("raw_signal")
        smoothed_signal = self.get_reconstructed_series("smoothed_signal")
        dates = self._series.index

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=dates,
                y=raw_signal.values,
                mode="lines",
                name="Raw Signal",
                line=dict(color="#5a5a5a", width=1.5),
                opacity=0.7,
            )
        )
        # A noise-only grouping has no smoothed signal to draw.
        if smoothed_signal is not None:
            fig.add_trace(
                go.Scatter(
                    x=dates,
                    y=smoothed_signal.values,
                    mode="lines",
                    name="Smoothed Signal",
                    line=dict(color="#000000", width=2, dash="dash"),
                )
            )

        colors = ["#1b4f72", "#145a32", "#784212", "#4a235a", "#7b241c"]
        for i, key in enumerate(signal_keys):
            group_signal = self.get_reconstructed_series(key)
            fig.add_trace(
                go.Scatter(
                    x=dates,
                    y=group_signal.values,
                    mode="lines",
                    name=key,
                    line=dict(width=1.5, color=colors[i % len(colors)]),
                )
            )

        fig.update_layout(
            title="Signal Reconstruction by Group",
            xaxis_title="Date",
            yaxis_title="Signal",
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        )
        return fig

    def loess_smother(self, fraction: float) -> go.Figure:
        """Fit a LOESS curve to the SSA preprocessed signal.

        Args:
            fraction: Fraction of data used for each local regression in LOESS.

        Returns:
            Figure with raw and LOESS-smoothed signals.

        Raises:
            ValueError: If ``fraction`` is not in the interval (0, 1].
        """
        if not 0 < fraction <= 1:
            raise ValueError("fraction must be in the interval (0, 1].")

        ssa_frame = self._ssa.to_frame()
        raw_signal = ssa_frame["ssa_preprocessed"]
        # Positions, not timestamps, are used as the regressor.
        loess_signal = lowess(
            endog=raw_signal.values,
            exog=np.arange(len(raw_signal)),
            frac=fraction,
            return_sorted=False,
        )

        plot_df = pd.DataFrame(
            {
                "date": self._series.index,
                "Raw Signal": raw_signal.values,
                "LOESS Smoothed Signal": loess_signal,
            }
        )
        combined_df = plot_df.melt(
            id_vars="date",
            value_vars=["Raw Signal", "LOESS Smoothed Signal"],
            var_name="Series",
            value_name="Value",
        )
        return px.line(
            combined_df,
            x="date",
            y="Value",
            color="Series",
            title="Raw and LOESS Smoothed Signal",
            labels={"date": "Date", "Value": "Signal"},
        )

    # Correctly spelled alias; the original (misspelled) name is kept for
    # backward compatibility.
    loess_smoother = loess_smother

    def change_point_plot(self) -> go.Figure:
        """Return a change-point analysis plot using the smoothed signal.

        Two independent PELT detectors are run:

        * **Trend shifts** -- PELT on the z-score-normalised ``trend`` group
          (falling back to the smoothed signal when no such group exists),
          ``l2`` cost, penalty ``log(n)``. Rendered as vertical *dashed* lines.
        * **Seasonal amplitude shifts** -- PELT on the z-score-normalised
          rolling-RMS envelope of the ``seasonality`` group, with a centred
          rolling window equal to the SSA window size (``l2`` cost, penalty
          ``log(n)``). Rendered as vertical *dotted* lines. Skipped when no
          such group exists.

        The smoothed signal (sum of all non-noise groups) is plotted as a
        single continuous trace. A plain-language summary of detected changes
        is appended below the plot.

        Side effects: stores the trend breakpoints in ``self._change_points``
        and a per-sample segment table in ``self._segment_frame``.

        Raises:
            ValueError: If no reconstruction map has been set, or if it defines
                no non-noise group (there is then no smoothed signal).
        """
        self._require_reconstruction()
        self._ensure_reconstruction_cache()

        smoothed_signal = self.get_reconstructed_series("smoothed_signal")
        if smoothed_signal is None:
            raise ValueError(
                "Change-point analysis requires at least one non-noise reconstruction group."
            )
        dates = self._series.index
        n = len(smoothed_signal)
        penalty = float(np.log(max(n, 2)))

        def _pelt_on(signal_1d: np.ndarray) -> np.ndarray:
            """Z-normalise then run PELT; return interior breakpoint indices."""
            std = float(np.std(signal_1d))
            if std > 0.0:
                normed = (signal_1d - np.mean(signal_1d)) / std
            else:
                # Constant input: nothing to scale.
                normed = signal_1d.copy()
            bkps = rpt.Pelt(model="l2").fit(normed.reshape(-1, 1)).predict(pen=penalty)
            # Breakpoints equal to n are dropped; only interior ones are kept.
            return np.array([b for b in bkps if b < n], dtype=int)

        # --- Trend shift detection ----------------------------------------
        trend_signal = self._group_signals.get("trend", smoothed_signal)
        trend_change_points = _pelt_on(trend_signal.values.astype(float))
        self._change_points = trend_change_points  # backward-compatible attribute

        # --- Seasonal amplitude shift detection ---------------------------
        seas_signal = self._group_signals.get("seasonality")
        seasonal_change_points: np.ndarray = np.array([], dtype=int)
        if seas_signal is not None:
            rms_envelope = (
                pd.Series(seas_signal.values.astype(float))
                .pow(2)
                .rolling(self._window, center=True, min_periods=1)
                .mean()
                .pow(0.5)
                .values
            )
            seasonal_change_points = _pelt_on(rms_envelope)

        # --- Segment labels (stored for downstream consumers) -------------
        segment_ids = np.ones(n, dtype=int)
        for seg_idx, bp in enumerate(trend_change_points):
            segment_ids[bp:] = seg_idx + 2
        self._segment_frame = pd.DataFrame(
            {
                "date": dates,
                "Smoothed Signal": smoothed_signal.values,
                "segment": [f"segment-{s}" for s in segment_ids],
            }
        )

        # --- Figure -------------------------------------------------------
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=dates,
                y=smoothed_signal.values,
                mode="lines",
                name="Smoothed Signal",
                line=dict(color="steelblue", width=2),
            )
        )

        def _ts_to_ms(idx: int) -> float:
            """Convert the date at *idx* to epoch milliseconds for add_vline."""
            return float(pd.Timestamp(dates[idx]).timestamp() * 1000)

        # Trend change points: dashed lines, Plotly qualitative palette.
        trend_colors = px.colors.qualitative.Plotly
        for i, cp_idx in enumerate(trend_change_points):
            fig.add_vline(
                x=_ts_to_ms(cp_idx),
                line_dash="dash",
                line_color=trend_colors[i % len(trend_colors)],
                annotation_text=f"T{i + 1}",
                annotation_position="top left" if i % 2 == 0 else "top right",
            )

        # Seasonal amplitude change points: dotted lines, Pastel palette.
        seasonal_colors = px.colors.qualitative.Pastel
        for i, cp_idx in enumerate(seasonal_change_points):
            fig.add_vline(
                x=_ts_to_ms(cp_idx),
                line_dash="dot",
                line_color=seasonal_colors[i % len(seasonal_colors)],
                annotation_text=f"S{i + 1}",
                annotation_position="bottom left" if i % 2 == 0 else "bottom right",
            )

        # --- Plain-language summary annotations below the plot ------------
        def _fmt_dates(indices: np.ndarray) -> str:
            if indices.size == 0:
                return "none detected"
            return ", ".join(
                str(dates[idx].date()) if hasattr(dates[idx], "date") else str(dates[idx])
                for idx in indices
            )

        trend_summary = f"Trend shifts (- -): {_fmt_dates(trend_change_points)}"
        seasonal_summary = (
            f"Seasonal amplitude shifts (···): {_fmt_dates(seasonal_change_points)}"
        )
        for row_idx, text in enumerate([trend_summary, seasonal_summary]):
            fig.add_annotation(
                text=text,
                xref="paper",
                yref="paper",
                x=0.0,
                y=-0.22 - row_idx * 0.09,
                showarrow=False,
                font=dict(size=11),
                align="left",
                xanchor="left",
            )

        fig.update_layout(
            title="Change Point Analysis",
            xaxis_title="Date",
            yaxis_title="Signal",
            margin=dict(b=160),  # room for the summary annotations
        )
        return fig

    # ------------------------------------------------------------------
    # Explained variance
    # ------------------------------------------------------------------
    def explained_variance_by_group(self, group_name: str) -> float:
        """Return the explained variance (%) for a named reconstruction group.

        The lookup is case-insensitive and ignores surrounding whitespace in
        ``group_name``. Returns 0.0 if the group is not present in the
        reconstruction map.

        Args:
            group_name: The group name as used in set_reconstruction().

        Returns:
            Explained variance percentage (0-100).

        Raises:
            ValueError: If no reconstruction map has been set.
        """
        self._require_reconstruction()
        self._ensure_reconstruction_cache()
        key = group_name.strip().casefold()
        for name, value in self._variation_by_group.items():
            if name.casefold() == key:
                return value
        return 0.0

    def explained_variance_trend(self) -> float:
        """Return the explained variance (%) attributed to the trend group.

        Uses the first reconstruction group whose name contains 'trend'
        (case-insensitive).

        Returns:
            Explained variance percentage (0-100), or 0.0 if no trend group is defined.

        Raises:
            ValueError: If no reconstruction map has been set.
        """
        return self._first_variation_containing("trend")

    def explained_variance_seasonality(self) -> float:
        """Return the explained variance (%) attributed to the seasonality group.

        Uses the first reconstruction group whose name contains 'season'
        (case-insensitive).

        Returns:
            Explained variance percentage (0-100), or 0.0 if no seasonality
            group is defined.

        Raises:
            ValueError: If no reconstruction map has been set.
        """
        return self._first_variation_containing("season")

    def explained_variance_noise(self) -> float:
        """Return the explained variance (%) attributed to the noise group.

        Uses the first reconstruction group whose name contains 'noise'
        (case-insensitive).

        Returns:
            Explained variance percentage (0-100), or 0.0 if no noise group is defined.

        Raises:
            ValueError: If no reconstruction map has been set.
        """
        return self._first_variation_containing("noise")