From 5109ac4eb302105b2c80f56dbfe4273a4925568d Mon Sep 17 00:00:00 2001
From: guofu
Date: Sun, 1 Mar 2026 23:14:05 +0800
Subject: [PATCH] Simplify pipeline with struct columns
- Remove split_at_end parameter from pipeline.transform(), always return DataFrame
- Add pack_struct parameter to pack feature groups into struct columns
- Rename exporters: select_feature_groups_from_df -> get_groups, select_feature_groups -> get_groups_from_fg
- Add pack_structs() and unpack_struct() helper functions
- Remove split_from_merged() method from FeatureGroups (no longer needed)
- Rename dump_polars_dataset.py to dump_features.py with --pack-struct flag
- Update README with new CLI usage and struct column documentation
Co-Authored-By: Claude Opus 4.6
---
stock_1d/d033/alpha158_beta/README.md | 26 +-
.../alpha158_beta/scripts/dump_features.py | 265 ++++
.../scripts/dump_polars_dataset.py | 364 ------
.../scripts/generate_beta_embedding.py | 1100 ++++-------------
stock_1d/d033/alpha158_beta/src/__init__.py | 53 +
.../alpha158_beta/src/processors/__init__.py | 136 ++
.../alpha158_beta/src/processors/dataclass.py | 115 ++
.../alpha158_beta/src/processors/exporters.py | 523 ++++++++
.../alpha158_beta/src/processors/loaders.py | 279 +++++
.../alpha158_beta/src/processors/pipeline.py | 539 ++++++++
.../src/processors/processors.py | 447 +++++++
11 files changed, 2630 insertions(+), 1217 deletions(-)
create mode 100644 stock_1d/d033/alpha158_beta/scripts/dump_features.py
delete mode 100644 stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
create mode 100644 stock_1d/d033/alpha158_beta/src/__init__.py
create mode 100644 stock_1d/d033/alpha158_beta/src/processors/__init__.py
create mode 100644 stock_1d/d033/alpha158_beta/src/processors/dataclass.py
create mode 100644 stock_1d/d033/alpha158_beta/src/processors/exporters.py
create mode 100644 stock_1d/d033/alpha158_beta/src/processors/loaders.py
create mode 100644 stock_1d/d033/alpha158_beta/src/processors/pipeline.py
create mode 100644 stock_1d/d033/alpha158_beta/src/processors/processors.py
diff --git a/stock_1d/d033/alpha158_beta/README.md b/stock_1d/d033/alpha158_beta/README.md
index c1b678a..2b84307 100644
--- a/stock_1d/d033/alpha158_beta/README.md
+++ b/stock_1d/d033/alpha158_beta/README.md
@@ -115,17 +115,22 @@ All fixed processors preserve the trained parameters from the original proc_list
### Polars Dataset Generation
-The `scripts/dump_polars_dataset.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing:
+The `scripts/dump_features.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing:
```bash
-# Generate raw and processed datasets
-python scripts/dump_polars_dataset.py
+# Generate merged features (flat columns)
+python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups merged
+
+# Generate with struct columns (packed feature groups)
+python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups merged --pack-struct
+
+# Generate specific feature groups
+python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups alpha158 market_ext
```
This script:
1. Loads data from Parquet files (alpha158, kline, market flags, industry flags)
-2. Saves raw data (before processors) to `data_polars/raw_data_*.pkl`
-3. Applies the full processor pipeline:
+2. Applies the full processor pipeline:
- Diff processor (adds diff features)
- FlagMarketInjector (adds market_0, market_1)
- ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt)
@@ -133,13 +138,20 @@ This script:
- IndusNtrlInjector (industry neutralization)
- RobustZScoreNorm (using pre-fitted qlib parameters via `from_version()`)
- Fillna (fill NaN with 0)
-4. Saves processed data to `data_polars/processed_data_*.pkl`
+3. Saves to parquet/pickle format
+
+**Output modes:**
+- **Flat mode (default)**: All columns as separate fields (348 columns for merged)
+- **Struct mode (`--pack-struct`)**: Feature groups packed into struct columns:
+ - `features_alpha158` (316 fields)
+ - `features_market_ext` (14 fields)
+ - `features_market_flag` (11 fields)
**Note**: The `FlagSTInjector` step is skipped because it fails silently even in the gold-standard qlib code (see `BUG_ANALYSIS_FINAL.md` for details).
Output structure:
- Raw data: ~204 columns (158 feature + 4 feature_ext + 12 feature_flag + 30 indus_flag)
-- Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx)
+- Processed data: 348 columns (318 alpha158 + 14 market_ext + 14 market_flag + 2 index) <!-- NOTE(review): breakdown disagrees with the struct-mode field counts above (316 / 14 / 11, which sum to the VAE input dim 341) — confirm which counts are correct -->
- VAE input dimension: 341 (excluding indus_idx)
### RobustZScoreNorm Parameter Extraction
diff --git a/stock_1d/d033/alpha158_beta/scripts/dump_features.py b/stock_1d/d033/alpha158_beta/scripts/dump_features.py
new file mode 100644
index 0000000..3a5d10d
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/scripts/dump_features.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python
+"""
+Script to generate and dump transformed features from the alpha158_beta pipeline.
+
+This script provides fine-grained control over the feature generation and dumping process:
+- Select which feature groups to dump (alpha158, market_ext, market_flag, merged, vae_input)
+- Choose output format (parquet, pickle, numpy)
+- Control date range and universe filtering
+- Save intermediate pipeline outputs
+- Enable streaming mode for large datasets (>1 year)
+
+Usage:
+ # Dump all features to parquet
+ python dump_features.py --start-date 2025-01-01 --end-date 2025-01-31
+
+ # Dump only alpha158 features to pickle
+ python dump_features.py --groups alpha158 --format pickle
+
+ # Dump with custom output path
+ python dump_features.py --output /path/to/output.parquet
+
+ # Dump merged features with all columns
+ python dump_features.py --groups merged --verbose
+
+ # Use streaming mode for large date ranges (>1 year)
+ python dump_features.py --start-date 2020-01-01 --end-date 2023-12-31 --streaming
+"""
+
+import os
+import sys
+import argparse
+from pathlib import Path
+from typing import Optional, List
+
+# Add src to path for imports
+SCRIPT_DIR = Path(__file__).parent
+sys.path.insert(0, str(SCRIPT_DIR.parent / 'src'))
+
+from processors import (
+ FeaturePipeline,
+ FeatureGroups,
+ VAE_INPUT_DIM,
+ ALPHA158_COLS,
+ MARKET_EXT_BASE_COLS,
+ COLUMNS_TO_REMOVE,
+ get_groups,
+ dump_to_parquet,
+ dump_to_pickle,
+ dump_to_numpy,
+)
+
+# Default output directory
+DEFAULT_OUTPUT_DIR = SCRIPT_DIR.parent / "data"
+
+
+def generate_and_dump(
+ start_date: str,
+ end_date: str,
+ output_path: str,
+ output_format: str = 'parquet',
+ groups: List[str] = None,
+ universe: str = 'csiallx',
+ filter_universe: bool = True,
+ robust_zscore_params_path: Optional[str] = None,
+ verbose: bool = True,
+ pack_struct: bool = False,
+ streaming: bool = False,
+) -> None:
+ """
+ Generate features and dump to file.
+
+ Args:
+ start_date: Start date in YYYY-MM-DD format
+ end_date: End date in YYYY-MM-DD format
+ output_path: Output file path
+ output_format: Output format ('parquet', 'pickle', 'numpy')
+ groups: Feature groups to dump (default: ['merged'])
+ universe: Stock universe name
+ filter_universe: Whether to filter to stock universe
+ robust_zscore_params_path: Path to robust zscore parameters
+ verbose: Whether to print progress
+ pack_struct: If True, pack each feature group into struct columns
+ (features_alpha158, features_market_ext, features_market_flag)
+ streaming: If True, use Polars streaming mode for large datasets (>1 year)
+ """
+ if groups is None:
+ groups = ['merged']
+
+ print("=" * 60)
+ print("Feature Dump Script")
+ print("=" * 60)
+ print(f"Date range: {start_date} to {end_date}")
+ print(f"Output format: {output_format}")
+ print(f"Feature groups: {groups}")
+ print(f"Universe: {universe} (filter: {filter_universe})")
+ print(f"Pack struct: {pack_struct}")
+ print(f"Output path: {output_path}")
+ print("=" * 60)
+
+ # Initialize pipeline
+ pipeline = FeaturePipeline(
+ robust_zscore_params_path=robust_zscore_params_path
+ )
+
+ # Load data
+ feature_groups = pipeline.load_data(
+ start_date, end_date,
+ filter_universe=filter_universe,
+ universe_name=universe,
+ streaming=streaming
+ )
+
+ # Apply transformations - get merged DataFrame (pipeline always returns merged DataFrame now)
+ df_transformed = pipeline.transform(feature_groups, pack_struct=pack_struct)
+
+ # Select feature groups from merged DataFrame
+ outputs = get_groups(df_transformed, groups, verbose, use_struct=False)
+
+ # Ensure output directory exists
+ output_dir = os.path.dirname(output_path)
+ if output_dir:
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Dump to file(s)
+ if output_format == 'numpy':
+ # For numpy, we save the merged features
+ dump_to_numpy(feature_groups, output_path, include_metadata=True, verbose=verbose)
+ elif output_format == 'pickle':
+ if 'merged' in outputs:
+ dump_to_pickle(outputs['merged'], output_path, verbose=verbose)
+ elif len(outputs) == 1:
+ # Single group output
+ key = list(outputs.keys())[0]
+ base_path = Path(output_path)
+ dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
+ dump_to_pickle(outputs[key], dump_path, verbose=verbose)
+ else:
+ # Multiple groups - save each separately
+ base_path = Path(output_path)
+ for key, df_out in outputs.items():
+ dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
+ dump_to_pickle(df_out, dump_path, verbose=verbose)
+ else: # parquet
+ if 'merged' in outputs:
+ dump_to_parquet(outputs['merged'], output_path, verbose=verbose)
+ elif len(outputs) == 1:
+ # Single group output
+ key = list(outputs.keys())[0]
+ base_path = Path(output_path)
+ dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
+ dump_to_parquet(outputs[key], dump_path, verbose=verbose)
+ else:
+ # Multiple groups - save each separately
+ base_path = Path(output_path)
+ for key, df_out in outputs.items():
+ dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
+ dump_to_parquet(df_out, dump_path, verbose=verbose)
+
+ print("=" * 60)
+ print("Feature dump complete!")
+ print("=" * 60)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Generate and dump transformed features from alpha158_beta pipeline"
+ )
+
+ # Date range
+ parser.add_argument(
+ "--start-date", type=str, required=True,
+ help="Start date in YYYY-MM-DD format"
+ )
+ parser.add_argument(
+ "--end-date", type=str, required=True,
+ help="End date in YYYY-MM-DD format"
+ )
+
+ # Output settings
+ parser.add_argument(
+ "--output", "-o", type=str, default=None,
+ help=f"Output file path (default: {DEFAULT_OUTPUT_DIR}/features.parquet)"
+ )
+ parser.add_argument(
+ "--format", "-f", type=str, default='parquet',
+ choices=['parquet', 'pickle', 'numpy'],
+ help="Output format (default: parquet)"
+ )
+
+ # Feature groups
+ parser.add_argument(
+ "--groups", "-g", type=str, nargs='+', default=['merged'],
+ choices=['merged', 'alpha158', 'market_ext', 'market_flag', 'vae_input'],
+ help="Feature groups to dump (default: merged)"
+ )
+
+ # Universe settings
+ parser.add_argument(
+ "--universe", type=str, default='csiallx',
+ help="Stock universe name (default: csiallx)"
+ )
+ parser.add_argument(
+ "--no-filter-universe", action="store_true",
+ help="Disable stock universe filtering"
+ )
+
+ # Robust zscore parameters
+ parser.add_argument(
+ "--robust-zscore-params", type=str, default=None,
+ help="Path to robust zscore parameters directory"
+ )
+
+ # Verbose mode
+ parser.add_argument(
+ "--verbose", "-v", action="store_true", default=True,
+ help="Enable verbose output (default: True)"
+ )
+ parser.add_argument(
+ "--quiet", "-q", action="store_true",
+ help="Disable verbose output"
+ )
+
+ # Struct option
+ parser.add_argument(
+ "--pack-struct", "-s", action="store_true",
+ help="Pack each feature group into separate struct columns (features_alpha158, features_market_ext, features_market_flag)"
+ )
+
+ # Streaming option
+ parser.add_argument(
+ "--streaming", action="store_true",
+ help="Use Polars streaming mode for large datasets (recommended for date ranges > 1 year)"
+ )
+
+ args = parser.parse_args()
+
+ # Handle verbose/quiet flags
+ verbose = args.verbose and not args.quiet
+
+ # Set default output path
+ if args.output is None:
+ # Build default output path: {data_dir}/features_{group}.parquet
+ # Note: generate_and_dump will add group suffix, so use base name "features"
+ output_path = str(DEFAULT_OUTPUT_DIR / "features.parquet")
+ else:
+ output_path = args.output
+
+ # Generate and dump
+ generate_and_dump(
+ start_date=args.start_date,
+ end_date=args.end_date,
+ output_path=output_path,
+ output_format=args.format,
+ groups=args.groups,
+ universe=args.universe,
+ filter_universe=not args.no_filter_universe,
+ robust_zscore_params_path=args.robust_zscore_params,
+ verbose=verbose,
+ pack_struct=args.pack_struct,
+ streaming=args.streaming,
+ )
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
deleted file mode 100644
index 7b961d0..0000000
--- a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
+++ /dev/null
@@ -1,364 +0,0 @@
-#!/usr/bin/env python
-"""
-Script to dump raw and processed datasets using the polars-based pipeline.
-
-This generates:
-1. Raw data (before applying processors) - equivalent to qlib's handler output
-2. Processed data (after applying all processors) - ready for VAE encoding
-
-Date range: 2026-02-23 to today (2026-02-27)
-"""
-
-import os
-import sys
-import pickle as pkl
-import numpy as np
-import polars as pl
-from pathlib import Path
-from datetime import datetime
-
-# Add parent directory to path for imports
-sys.path.insert(0, str(Path(__file__).parent))
-
-# Import processors from the new shared module
-from cta_1d.src.processors import (
- DiffProcessor,
- FlagMarketInjector,
- FlagSTInjector,
- ColumnRemover,
- FlagToOnehot,
- IndusNtrlInjector,
- RobustZScoreNorm,
- Fillna,
-)
-
-# Import constants from local module
-from generate_beta_embedding import (
- ALPHA158_COLS,
- INDUSTRY_FLAG_COLS,
-)
-
-# Date range
-START_DATE = "2026-02-23"
-END_DATE = "2026-02-27"
-
-# Output directory
-OUTPUT_DIR = Path(__file__).parent.parent / "data_polars"
-OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
-
-
-def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
- """
- Apply the full processor pipeline (equivalent to qlib's proc_list).
-
- This mimics the qlib proc_list:
- 0. Diff: Adds diff features for market_ext columns
- 1. FlagMarketInjector: Adds market_0, market_1
- 2. FlagSTInjector: Adds IsST (placeholder if ST flags not available)
- 3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
- 4. FlagToOnehot: Converts 29 industry flags to indus_idx
- 5. IndusNtrlInjector: Industry neutralization for feature
- 6. IndusNtrlInjector: Industry neutralization for feature_ext
- 7. RobustZScoreNorm: Normalization using pre-fitted qlib params
- 8. Fillna: Fill NaN with 0
- """
- print("=" * 60)
- print("Applying processor pipeline")
- print("=" * 60)
-
- # market_ext columns (4 base)
- market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
-
- # market_flag columns (12 total before ColumnRemover)
- market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
- 'open_limit', 'close_limit', 'low_limit',
- 'open_stop', 'close_stop', 'high_stop']
-
- # Step 1: Diff Processor
- print("\n[1] Applying Diff processor...")
- diff_processor = DiffProcessor(market_ext_base)
- df = diff_processor.process(df)
-
- # After Diff: market_ext has 8 columns (4 base + 4 diff)
- market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]
-
- # Step 2: FlagMarketInjector (adds market_0, market_1)
- print("[2] Applying FlagMarketInjector...")
- flag_injector = FlagMarketInjector()
- df = flag_injector.process(df)
-
- # Add market_0, market_1 to flag list
- market_flag_with_market = market_flag_cols + ['market_0', 'market_1']
-
- # Step 3: FlagSTInjector - adds IsST (placeholder if ST flags not available)
- print("[3] Applying FlagSTInjector...")
- flag_st_injector = FlagSTInjector()
- df = flag_st_injector.process(df)
-
- # Add IsST to flag list
- market_flag_with_st = market_flag_with_market + ['IsST']
-
- # Step 4: ColumnRemover
- print("[4] Applying ColumnRemover...")
- columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
- remover = ColumnRemover(columns_to_remove)
- df = remover.process(df)
-
- # Update column lists after removal
- market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
- market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]
-
- print(f" Removed columns: {columns_to_remove}")
- print(f" Remaining market_ext: {len(market_ext_cols)} columns")
- print(f" Remaining market_flag: {len(market_flag_with_st)} columns")
-
- # Step 5: FlagToOnehot
- print("[5] Applying FlagToOnehot...")
- flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
- df = flag_to_onehot.process(df)
-
- # Step 6 & 7: IndusNtrlInjector
- print("[6] Applying IndusNtrlInjector for alpha158...")
- alpha158_cols = ALPHA158_COLS.copy()
- indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
- df = indus_ntrl_alpha.process(df)
-
- print("[7] Applying IndusNtrlInjector for market_ext...")
- indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
- df = indus_ntrl_ext.process(df)
-
- # Build column lists for normalization
- alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
- market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]
-
- # Step 8: RobustZScoreNorm
- print("[8] Applying RobustZScoreNorm...")
- norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols
-
- # Load RobustZScoreNorm with pre-fitted parameters from version
- robust_norm = RobustZScoreNorm.from_version(
- "csiallx_feature2_ntrla_flag_pnlnorm",
- feature_cols=norm_feature_cols
- )
-
- # Verify parameter shape matches expected features
- expected_features = len(norm_feature_cols)
- if robust_norm.qlib_mean.shape[0] != expected_features:
- print(f" WARNING: Feature count mismatch! Expected {expected_features}, "
- f"got {robust_norm.qlib_mean.shape[0]}")
-
- df = robust_norm.process(df)
-
- # Step 9: Fillna
- print("[9] Applying Fillna...")
- final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
- fillna = Fillna()
- df = fillna.process(df, final_feature_cols)
-
- print("\n" + "=" * 60)
- print("Processor pipeline complete!")
- print(f" Normalized features: {len(norm_feature_cols)}")
- print(f" Market flags: {len(market_flag_with_st)}")
- print(f" Total features (with indus_idx): {len(final_feature_cols)}")
- print("=" * 60)
-
- return df
-
-
-def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
- """
- Convert polars DataFrame to pandas DataFrame with MultiIndex columns.
- This matches the format of qlib's output.
-
- IMPORTANT: Qlib's IndusNtrlInjector outputs columns in order [_ntrl] + [raw],
- so we need to reorder columns to match this expected order.
- """
- import pandas as pd
-
- # Convert to pandas
- df = df_polars.to_pandas()
-
- # Check if datetime and instrument are columns
- if 'datetime' in df.columns and 'instrument' in df.columns:
- # Set MultiIndex
- df = df.set_index(['datetime', 'instrument'])
- # If they're already not in columns, assume they're already the index
-
- # Drop raw columns that shouldn't be in processed data
- raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue']
- existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns]
- if existing_raw_cols:
- df = df.drop(columns=existing_raw_cols)
-
- # Build MultiIndex columns based on column name patterns
- # IMPORTANT: Qlib order is [_ntrl columns] + [raw columns] for each group
- columns_with_group = []
-
- # Define column sets
- alpha158_base = set(ALPHA158_COLS)
- market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'}
- market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'}
- market_ext_all = market_ext_base | market_ext_diff
- feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit',
- 'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'}
-
- # First pass: collect _ntrl columns (these come first in qlib order)
- ntrl_alpha158_cols = []
- ntrl_market_ext_cols = []
- raw_alpha158_cols = []
- raw_market_ext_cols = []
- flag_cols = []
- indus_idx_col = None
-
- for col in df.columns:
- if col == 'indus_idx':
- indus_idx_col = col
- elif col in feature_flag_cols:
- flag_cols.append(col)
- elif col.endswith('_ntrl'):
- base_name = col[:-5] # Remove _ntrl suffix (5 characters)
- if base_name in alpha158_base:
- ntrl_alpha158_cols.append(col)
- elif base_name in market_ext_all:
- ntrl_market_ext_cols.append(col)
- elif col in alpha158_base:
- raw_alpha158_cols.append(col)
- elif col in market_ext_all:
- raw_market_ext_cols.append(col)
- elif col in INDUSTRY_FLAG_COLS:
- columns_with_group.append(('indus_flag', col))
- elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}:
- columns_with_group.append(('st_flag', col))
- else:
- # Unknown column - print warning
- print(f" Warning: Unknown column '{col}', assigning to 'other' group")
- columns_with_group.append(('other', col))
-
- # Build columns in qlib order: [_ntrl] + [raw] for each feature group
- # Feature group: alpha158_ntrl + alpha158
- for col in sorted(ntrl_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x.replace('_ntrl', '')) if x.replace('_ntrl', '') in ALPHA158_COLS else 999):
- columns_with_group.append(('feature', col))
- for col in sorted(raw_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x) if x in ALPHA158_COLS else 999):
- columns_with_group.append(('feature', col))
-
- # Feature_ext group: market_ext_ntrl + market_ext
- for col in ntrl_market_ext_cols:
- columns_with_group.append(('feature_ext', col))
- for col in raw_market_ext_cols:
- columns_with_group.append(('feature_ext', col))
-
- # Feature_flag group
- for col in flag_cols:
- columns_with_group.append(('feature_flag', col))
-
- # Indus_idx
- if indus_idx_col:
- columns_with_group.append(('indus_idx', indus_idx_col))
-
- # Create MultiIndex columns
- multi_cols = pd.MultiIndex.from_tuples(columns_with_group)
- df.columns = multi_cols
-
- return df
-
-
-def main():
- print("=" * 80)
- print("Dumping Polars Dataset")
- print("=" * 80)
- print(f"Date range: {START_DATE} to {END_DATE}")
- print(f"Output directory: {OUTPUT_DIR}")
- print()
-
- # Step 1: Load all data
- print("Step 1: Loading data from parquet...")
- df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE)
- print(f" Alpha158 shape: {df_alpha.shape}")
- print(f" Kline (market_ext) shape: {df_kline.shape}")
- print(f" Flags shape: {df_flag.shape}")
- print(f" Industry shape: {df_industry.shape}")
-
- # Step 2: Merge data sources
- print("\nStep 2: Merging data sources...")
- df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
- print(f" Merged shape (after csiallx filter): {df_merged.shape}")
-
- # Step 3: Save raw data (before processors)
- print("\nStep 3: Saving raw data (before processors)...")
-
- # Keep columns that match qlib's raw output format
- # Include datetime and instrument for MultiIndex conversion
- raw_columns = (
- ['datetime', 'instrument'] + # Index columns
- ALPHA158_COLS + # feature group
- ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] + # feature_ext base
- ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', # market_flag from kline
- 'open_limit', 'close_limit', 'low_limit',
- 'open_stop', 'close_stop', 'high_stop'] +
- INDUSTRY_FLAG_COLS + # indus_flag
- (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else []) # st_flag (if available)
- )
-
- # Filter to available columns
- available_raw_cols = [c for c in raw_columns if c in df_merged.columns]
- print(f" Selecting {len(available_raw_cols)} columns for raw data...")
- df_raw_polars = df_merged.select(available_raw_cols)
-
- # Convert to pandas with MultiIndex
- df_raw_pd = convert_to_multiindex_df(df_raw_polars)
-
- raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
- with open(raw_output_path, "wb") as f:
- pkl.dump(df_raw_pd, f)
- print(f" Saved raw data to: {raw_output_path}")
- print(f" Raw data shape: {df_raw_pd.shape}")
- print(f" Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}")
-
- # Step 4: Apply processor pipeline
- print("\nStep 4: Applying processor pipeline...")
- df_processed = apply_processor_pipeline(df_merged)
-
- # Step 5: Save processed data
- print("\nStep 5: Saving processed data (after processors)...")
-
- # Convert to pandas with MultiIndex
- df_processed_pd = convert_to_multiindex_df(df_processed)
-
- processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
- with open(processed_output_path, "wb") as f:
- pkl.dump(df_processed_pd, f)
- print(f" Saved processed data to: {processed_output_path}")
- print(f" Processed data shape: {df_processed_pd.shape}")
- print(f" Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}")
-
- # Count columns per group
- print("\n Column counts by group:")
- for grp in df_processed_pd.columns.get_level_values(0).unique().tolist():
- count = (df_processed_pd.columns.get_level_values(0) == grp).sum()
- print(f" {grp}: {count} columns")
-
- # Step 6: Verify column counts
- print("\n" + "=" * 80)
- print("Verification")
- print("=" * 80)
-
- feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag']
- has_market_0 = 'market_0' in feature_flag_cols
- has_market_1 = 'market_1' in feature_flag_cols
-
- print(f" feature_flag columns: {feature_flag_cols}")
- print(f" Has market_0: {has_market_0}")
- print(f" Has market_1: {has_market_1}")
-
- if has_market_0 and has_market_1:
- print("\n SUCCESS: market_0 and market_1 columns are present!")
- else:
- print("\n WARNING: market_0 or market_1 columns are missing!")
-
- print("\n" + "=" * 80)
- print("Dataset dump complete!")
- print("=" * 80)
-
-
-if __name__ == "__main__":
- main()
diff --git a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
index d8316c3..f7bf05e 100644
--- a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
+++ b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
@@ -2,51 +2,44 @@
"""
Standalone script to generate embeddings from alpha158_0_7_beta factors using the VAE encoder.
-This script implements the full feature transformation pipeline:
-1. Load all 6 data sources from Parquet:
- - Alpha158: stg_1day_wind_alpha158_0_7_beta_1D/ (158 features)
- - Market Ext: stg_1day_wind_kline_adjusted_1D/ (Turnover, FreeTurnover, MarketValue -> log_size)
- - Con Rating: stg_1day_gds_con_rating_1D/ (con_rating_strength)
- - Market Flag: stg_1day_wind_kline_adjusted_1D/ (IsZt, IsDt, IsN, IsXD, IsXR, IsDR)
- - Market Flag: stg_1day_wind_market_flag_1D/ (open_limit, close_limit, low_limit, open_stop, close_stop, high_stop)
- - Industry Flag: stg_1day_gds_indus_flag_cc1_1D/ (29 one-hot industries)
-
-2. Apply 9 processors in sequence:
- - Diff: Adds 4 diff features to feature_ext
- - FlagMarketInjector: Adds market_0, market_1 to feature_flag
- - FlagSTInjector: Adds IsST (placeholder, all zeros)
- - ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
- - FlagToOnehot: Converts 29 industry flags to single indus_idx
- - IndusNtrlInjector (x2): Industry neutralization for feature and feature_ext
- - RobustZScoreNorm: Robust z-score normalization using pre-fitted qlib parameters
- - Fillna: Fill NaN values with 0
-
-3. Encode with VAE:
- - Load VAE model from alpha/data_ops/tasks/dwm_feature_vae/model/
- - Run inference to generate 32-dim embeddings
- - Save to parquet
-
-Note: Feature order is critical - alpha158 columns are in explicit order matching the VAE training.
+This script uses the new modular processors package for data loading and feature
+transformation, and focuses on VAE encoding and output generation.
+
+Workflow:
+1. Load data using FeaturePipeline (loads from 6 parquet sources)
+2. Transform features using the modular processor pipeline
+3. Encode with VAE to generate 32-dim embeddings
+4. Save embeddings to parquet
+
+Note: The data loading and transformation logic is now in the processors module:
+ stock_1d/d033/alpha158_beta/src/processors/
"""
import os
import sys
import pickle as pkl
-import io
import numpy as np
import polars as pl
import torch
import torch.nn as nn
from pathlib import Path
-from datetime import datetime
-from typing import Optional, List, Tuple, Dict, Set
-
-# Constants
-PARQUET_ALPHA158_BETA_PATH = "/data/parquet/dataset/stg_1day_wind_alpha158_0_7_beta_1D/"
-PARQUET_KLINE_PATH = "/data/parquet/dataset/stg_1day_wind_kline_adjusted_1D/"
-PARQUET_MARKET_FLAG_PATH = "/data/parquet/dataset/stg_1day_wind_market_flag_1D/"
-PARQUET_INDUSTRY_FLAG_PATH = "/data/parquet/dataset/stg_1day_gds_indus_flag_cc1_1D/"
-PARQUET_CON_RATING_PATH = "/data/parquet/dataset/stg_1day_gds_con_rating_1D/"
+from typing import Optional, List, Tuple
+
+# Import from the new processors module
+sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
+from processors import (
+ FeaturePipeline,
+ FeatureGroups,
+ filter_stock_universe,
+ ALPHA158_COLS,
+ MARKET_EXT_BASE_COLS,
+ MARKET_FLAG_COLS,
+ COLUMNS_TO_REMOVE,
+ VAE_INPUT_DIM,
+ DEFAULT_ROBUST_ZSCORE_PARAMS_PATH,
+)
+
+# Constants for VAE and output
VAE_MODEL_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/model/csiallx_feature2_ntrla_flag_pnlnorm_vae4_dim32a_beta0001/module.pt"
OUTPUT_DIR = "../data"
@@ -54,752 +47,239 @@ OUTPUT_DIR = "../data"
DEFAULT_START_DATE = "2019-01-01"
DEFAULT_END_DATE = "2025-12-31"
-# Expected VAE input dimension
-# Based on original pipeline:
-# - feature: 158 alpha158 + 158 alpha158_ntrl = 316
-# - feature_ext: 7 market_ext + 7 market_ext_ntrl = 14
-# - feature_flag: 11 columns (after ColumnRemover, FlagMarketInjector; excluding IsST)
-# Total: 316 + 14 + 11 = 341
-#
-# NOTE: The VAE model encode() function takes feature + feature_ext + feature_flag groups
-# (indus_idx is NOT included in VAE input)
-VAE_INPUT_DIM = 341
-
-# Industry flag columns (29 one-hot columns)
-INDUSTRY_FLAG_COLS = [
- 'gds_CC10', 'gds_CC11', 'gds_CC12', 'gds_CC20', 'gds_CC21', 'gds_CC22',
- 'gds_CC23', 'gds_CC24', 'gds_CC25', 'gds_CC26', 'gds_CC27', 'gds_CC28',
- 'gds_CC30', 'gds_CC31', 'gds_CC32', 'gds_CC33', 'gds_CC34', 'gds_CC35',
- 'gds_CC36', 'gds_CC37', 'gds_CC40', 'gds_CC41', 'gds_CC42', 'gds_CC43',
- 'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70'
-]
-
-
-def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame:
- """
- Filter dataframe to csiallx stock universe (A-shares excluding STAR/BSE) using qshare spine functions.
- This uses qshare's filter_instruments which loads the instrument list from:
- /data/qlib/default/data_ops/target/instruments/csiallx.txt
+def load_vae_model(model_path: str) -> nn.Module:
+ """
+ Load the VAE model from file.
Args:
- df: Input DataFrame with datetime and instrument columns
- instruments: Market name for spine creation (default: 'csiallx')
+ model_path: Path to the pickled VAE model
Returns:
- Filtered DataFrame with only instruments in the specified universe
+ Loaded VAE model in eval mode on CPU
"""
- from qshare.algo.polars.spine import filter_instruments
-
- # Use qshare's filter_instruments with csiallx market name
- df = filter_instruments(df, instruments=instruments)
-
- return df
-
-# Alpha158 feature columns in EXPLICIT ORDER
-# These are the 158 alpha158 features in the order they appear in the parquet file
-# This order MUST match the order used when training the VAE model
-ALPHA158_COLS = [
- 'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
- 'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
- 'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
- 'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
- 'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
- 'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
- 'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
- 'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
- 'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
- 'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
- 'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
- 'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
- 'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
- 'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
- 'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
- 'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
- 'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
- 'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
- 'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
- 'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
- 'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
- 'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
- 'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
- 'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
- 'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
- 'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
- 'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
- 'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
- 'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
- 'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
- 'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
-]
-# Verify we have 158 features
-assert len(ALPHA158_COLS) == 158, f"Expected 158 alpha158 cols, got {len(ALPHA158_COLS)}"
-
-# Market extension columns - MUST match original qlib HANDLER_MARKET_EXT config
-# Original config loads:
-# 'Turnover as turnover', 'FreeTurnover as free_turnover',
-# 'log(MarketValue) as log_size', 'con_rating_strength'
-#
-# We use lowercase names to match the original pipeline exactly.
-# NOTE: con_rating_strength is not available in parquet, so we'll create it as zeros.
-MARKET_EXT_RAW_COLS = ['Turnover', 'FreeTurnover', 'MarketValue'] # Raw columns from parquet
-MARKET_EXT_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] # Final names
-
-# Market flag columns (before processors)
-# According to HANDLER_MARKET_FLAG in qlib config:
-# From stg_1day_wind_kline_adjusted: IsZt, IsDt, IsN, IsXD, IsXR, IsDR (boolean)
-# From stg_1day_wind_market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (boolean)
-MARKET_FLAG_COLS_KLINE = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR']
-MARKET_FLAG_COLS_MARKET = ['open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop']
-
-
-def get_date_partitions(start_date: str, end_date: str) -> List[str]:
- """Generate a list of date partitions to load from Parquet."""
- start = datetime.strptime(start_date, "%Y-%m-%d")
- end = datetime.strptime(end_date, "%Y-%m-%d")
-
- partitions = []
- current = start
- while current <= end:
- if current.weekday() < 5:
- partitions.append(f"datetime={current.strftime('%Y%m%d')}")
- current = datetime(current.year, current.month, current.day + 1)
-
- return partitions
-
-
-def load_parquet_by_date_range(
- base_path: str,
- start_date: str,
- end_date: str,
- columns: Optional[List[str]] = None
-) -> pl.DataFrame:
- """Load parquet data filtered by date range."""
- start_int = int(start_date.replace("-", ""))
- end_int = int(end_date.replace("-", ""))
+ print(f"Loading VAE model from {model_path}...")
+
+ # Patch torch.load to use CPU
+ original_torch_load = torch.load
+ def cpu_torch_load(*args, **kwargs):
+ kwargs['map_location'] = 'cpu'
+ return original_torch_load(*args, **kwargs)
+ torch.load = cpu_torch_load
try:
- df = pl.scan_parquet(base_path)
+ with open(model_path, "rb") as fin:
+ model = pkl.load(fin)
- # Filter by date range
- df = df.filter(
- (pl.col('datetime') >= start_int) &
- (pl.col('datetime') <= end_int)
- )
+ model.eval()
+ print(f"Loaded VAE model: {model.__class__.__name__}")
+ print(f" Input size: {model.input_size}")
+ print(f" Hidden size: {model.hidden_size}")
- # Select specific columns if provided
- if columns:
- available_cols = ['instrument', 'datetime'] + [c for c in columns if c not in ['instrument', 'datetime']]
- df = df.select(available_cols)
+ return model
- return df.collect()
- except Exception as e:
- print(f"Error loading from {base_path}: {e}")
- return pl.DataFrame()
+ finally:
+ torch.load = original_torch_load
-def load_all_data(
- start_date: str,
- end_date: str
-) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]:
+def encode_with_vae(features: np.ndarray, model: nn.Module, batch_size: int = 5000) -> np.ndarray:
"""
- Load all data sources from Parquet.
-
- According to original HANDLER_MARKET_EXT and HANDLER_MARKET_FLAG configs:
- - alpha158: 158 features
- - market_ext: turnover, free_turnover, log_size (=log(MarketValue)), con_rating_strength
- - market_flag: IsZt, IsDt, IsN, IsXD, IsXR, IsDR + open_limit, close_limit, low_limit, open_stop, close_stop, high_stop
- - indus_flag: 29 industry flags
+ Encode features using the VAE model.
- NOTE: con_rating_strength is not available in parquet, so we create it as zeros (placeholder).
+ Args:
+ features: Input features of shape (n_samples, VAE_INPUT_DIM)
+ model: VAE model with encode() method
+ batch_size: Batch size for inference
Returns:
- Tuple of (alpha158_df, market_ext_df, market_flag_df, industry_df)
+ Embeddings of shape (n_samples, 32)
"""
- print(f"Loading data from {start_date} to {end_date}...")
-
- # 1. Load Alpha158 beta factors (158 features)
- print("Loading alpha158_0_7_beta factors...")
- df_alpha = load_parquet_by_date_range(PARQUET_ALPHA158_BETA_PATH, start_date, end_date)
- print(f" Alpha158 shape: {df_alpha.shape}")
-
- # 2. Load Kline data for market_ext columns
- # Original config: 'Turnover as turnover', 'FreeTurnover as free_turnover',
- # 'log(MarketValue) as log_size', 'con_rating_strength'
- # We load raw columns and transform them
- print("Loading kline data (market ext columns)...")
- kline_cols = ['Turnover', 'FreeTurnover', 'MarketValue']
- df_kline = load_parquet_by_date_range(PARQUET_KLINE_PATH, start_date, end_date, kline_cols)
- print(f" Kline (market ext raw) shape: {df_kline.shape}")
-
- # 3. Load con_rating_strength from parquet
- print("Loading con_rating_strength from parquet...")
- df_con_rating = load_parquet_by_date_range(
- PARQUET_CON_RATING_PATH, start_date, end_date, ['con_rating_strength']
- )
- print(f" Con rating shape: {df_con_rating.shape}")
-
- # Transform market_ext columns to match original pipeline:
- # - Turnover -> turnover (rename)
- # - FreeTurnover -> free_turnover (rename)
- # - MarketValue -> log_size = log(MarketValue)
- # - con_rating_strength: loaded from parquet (will merge below)
- print("Transforming market_ext columns...")
- df_kline = df_kline.with_columns([
- pl.col('Turnover').alias('turnover'),
- pl.col('FreeTurnover').alias('free_turnover'),
- pl.col('MarketValue').log().alias('log_size'),
- ])
- print(f" Kline (market ext transformed) shape: {df_kline.shape}")
-
- # Merge con_rating_strength into kline dataframe
- df_kline = df_kline.join(df_con_rating, on=['instrument', 'datetime'], how='left')
- # Fill NaN with 0 for instruments/dates without con_rating data
- df_kline = df_kline.with_columns([
- pl.col('con_rating_strength').fill_null(0.0)
- ])
- print(f" Kline (with con_rating) shape: {df_kline.shape}")
-
- # 4. Load Market Flag data from kline_adjusted (all 6 columns)
- print("Loading market flags from kline_adjusted...")
- kline_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR']
- df_kline_flag = load_parquet_by_date_range(PARQUET_KLINE_PATH, start_date, end_date, kline_flag_cols)
- print(f" Kline flags shape: {df_kline_flag.shape}")
-
- # 5. Load Market Flag data from market_flag table (ALL 6 columns as per original config)
- print("Loading market flags from market_flag table (6 cols)...")
- market_flag_cols = ['open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop']
- df_market_flag = load_parquet_by_date_range(PARQUET_MARKET_FLAG_PATH, start_date, end_date, market_flag_cols)
- print(f" Market flag shape: {df_market_flag.shape}")
-
- # 6. Load Industry flags
- print("Loading industry flags...")
- df_industry = load_parquet_by_date_range(PARQUET_INDUSTRY_FLAG_PATH, start_date, end_date, INDUSTRY_FLAG_COLS)
- print(f" Industry shape: {df_industry.shape}")
-
- # Merge kline flag and market flag
- df_flag = df_kline_flag.join(df_market_flag, on=['instrument', 'datetime'], how='inner')
- print(f" Combined flags shape: {df_flag.shape}")
-
- return df_alpha, df_kline, df_flag, df_industry
-
-
-def merge_data_sources(
- df_alpha: pl.DataFrame,
- df_kline: pl.DataFrame,
- df_flag: pl.DataFrame,
- df_industry: pl.DataFrame
-) -> pl.DataFrame:
- """Merge all data sources on instrument and datetime."""
- print("Merging data sources...")
-
- # Start with alpha158
- df = df_alpha
-
- # Merge kline data (market_ext with transformed columns)
- # df_kline now has: turnover, free_turnover, log_size, con_rating_strength
- df = df.join(df_kline, on=['instrument', 'datetime'], how='inner')
-
- # Merge flags (kline_flag + market_flag)
- df = df.join(df_flag, on=['instrument', 'datetime'], how='inner')
-
- # Merge industry flags
- df = df.join(df_industry, on=['instrument', 'datetime'], how='inner')
-
- print(f"Merged data shape (before filter): {df.shape}")
-
- # Apply stock universe filter to match csiallx universe
- # This is CRITICAL for correct industry neutralization:
- # - Must use the same stock universe as the original pipeline
- # - Industry means are calculated per datetime across this universe
- df = filter_stock_universe(df)
-
- print(f"Merged data shape (after csiallx filter): {df.shape}")
- return df
-
+ print(f"Encoding {features.shape[0]} samples with VAE...")
-class DiffProcessor:
- """
- Diff Processor: Calculate diff features for market_ext columns.
- For each column in feature_ext, calculate diff with period=1 within each instrument group.
- """
- def __init__(self, columns: List[str]):
- self.columns = columns
+ device = torch.device('cpu')
+ model = model.to(device)
+ model.eval()
- def process(self, df: pl.DataFrame) -> pl.DataFrame:
- """Add diff features for specified columns."""
- print("Applying Diff processor...")
+ all_embeddings = []
- # Sort by instrument and datetime
- df = df.sort(['instrument', 'datetime'])
+ with torch.no_grad():
+ for i in range(0, len(features), batch_size):
+ batch = features[i:i + batch_size]
+ batch_tensor = torch.tensor(batch, dtype=torch.float32, device=device)
- # Add diff for each column
- for col in self.columns:
- if col in df.columns:
- diff_col = f"{col}_diff"
- df = df.with_columns([
- pl.col(col)
- .diff()
- .over('instrument')
- .alias(diff_col)
- ])
+ # Use model.encode() to get mu (the embedding)
+ mu, _ = model.encode(batch_tensor)
- return df
+ # Convert to numpy
+ embeddings_np = mu.cpu().numpy()
+ all_embeddings.append(embeddings_np)
+ if (i // batch_size + 1) % 10 == 0:
+ print(f" Processed {min(i + batch_size, len(features))}/{len(features)} samples...")
-class FlagMarketInjector:
- """
- Flag Market Injector: Create market_0, market_1 columns based on instrument code.
+ embeddings = np.concatenate(all_embeddings, axis=0)
+ print(f"Generated embeddings shape: {embeddings.shape}")
- Maps to Qlib's map_market_sec logic with vocab_size=2:
- - market_0 (主板): SH60xxx, SZ00xxx
- - market_1 (科创板/创业板): SH688xxx, SH689xxx, SZ300xxx, SZ301xxx
+ return embeddings
- NOTE: vocab_size=2 (not 3!) - the original qlib pipeline does NOT include
- 新三板/北交所 (NE4xxxx, NE8xxxx) in the market classification.
- This uses the gds encoding where:
- - 6xxxxx -> SH main board
- - 0xxxxx, 3xxxxx -> SZ (main/ChiNext)
- - 4xxxxx, 8xxxxx -> NE (新三板/北交所) - NOT included in vocab_size=2
+def prepare_vae_features(
+ feature_groups: FeatureGroups,
+ exclude_isst: bool = True
+) -> Tuple[np.ndarray, List[str]]:
"""
- def process(self, df: pl.DataFrame) -> pl.DataFrame:
- """Add market_0, market_1 columns."""
- print("Applying FlagMarketInjector (vocab_size=2)...")
+ Prepare features for VAE encoding from FeatureGroups.
- # Convert instrument to string and pad to 6 digits
- inst_str = pl.col('instrument').cast(pl.String).str.zfill(6)
-
- # Determine market type based on first digit
- # vocab_size=2: only market_0 (主板) and market_1 (科创/创业)
- is_sh_main = inst_str.str.starts_with('6') # SH600xxx, SH601xxx, etc.
- is_sz_main = inst_str.str.starts_with('0') | inst_str.str.starts_with('00') # SZ000xxx
- is_sh_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689') # SH688xxx, SH689xxx
- is_sz_chi = inst_str.str.starts_with('300') | inst_str.str.starts_with('301') # SZ300xxx, SZ301xxx
-
- df = df.with_columns([
- # market_0 = 主板 (SH main + SZ main)
- (is_sh_main | is_sz_main).cast(pl.Int8).alias('market_0'),
- # market_1 = 科创板 + 创业板 (SH star + SZ ChiNext)
- (is_sh_star | is_sz_chi).cast(pl.Int8).alias('market_1')
- ])
+ VAE input structure (341 features):
+ - feature group (316): 158 alpha158 + 158 alpha158_ntrl
+ - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl
+ - feature_flag group (11): market flags (excluding IsST)
- return df
+ NOTE: indus_idx is NOT included in VAE input.
+ Args:
+ feature_groups: Transformed FeatureGroups container
+ exclude_isst: Whether to exclude IsST from VAE input
-class ColumnRemover:
- """
- Column Remover: Drop specific columns.
- Removes: log_size_diff (TotalValue_diff), IsN, IsZt, IsDt
+ Returns:
+ Tuple of (features numpy array, list of embedding column names)
"""
- def __init__(self, columns_to_remove: List[str]):
- self.columns_to_remove = columns_to_remove
-
- def process(self, df: pl.DataFrame) -> pl.DataFrame:
- """Remove specified columns."""
- print(f"Applying ColumnRemover (removing {len(self.columns_to_remove)} columns)...")
-
- # Only remove columns that exist
- cols_to_drop = [c for c in self.columns_to_remove if c in df.columns]
- if cols_to_drop:
- df = df.drop(cols_to_drop)
-
- return df
+ print("Preparing features for VAE...")
+ # Merge all groups for final feature extraction
+ df = feature_groups.merge_for_processors()
-class FlagToOnehot:
- """
- Flag To Onehot: Convert 29 one-hot industry columns to single indus_idx.
- For each row, find which industry column is True/1 and set indus_idx to that index.
- """
- def __init__(self, industry_cols: List[str]):
- self.industry_cols = industry_cols
-
- def process(self, df: pl.DataFrame) -> pl.DataFrame:
- """Convert industry flags to single indus_idx column."""
- print("Applying FlagToOnehot (converting 29 industry flags to indus_idx)...")
+ # Build alpha158 feature columns
+ alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
+ alpha158_cols = ALPHA158_COLS.copy()
- # Build a when/then chain to find the industry index
- # Start with -1 (no industry) as default
- indus_expr = pl.lit(-1)
+ # Build market_ext feature columns (with diff, minus removed columns)
+ market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
+ market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
+ market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
+ market_ext_cols = market_ext_with_diff.copy()
- for idx, col in enumerate(self.industry_cols):
- if col in df.columns:
- indus_expr = pl.when(pl.col(col) == 1).then(idx).otherwise(indus_expr)
+ # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
+ norm_feature_cols = (
+ alpha158_ntrl_cols + alpha158_cols +
+ market_ext_ntrl_cols + market_ext_cols
+ )
- df = df.with_columns([indus_expr.alias('indus_idx')])
+ # Market flag columns (excluding IsST if requested)
+ # After ColumnRemover removes IsN, IsZt, IsDt:
+ # - From kline_adjusted: IsXD, IsXR, IsDR (3 cols)
+ # - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols)
+ # - Added by FlagMarketInjector: market_0, market_1 (2 cols)
+ # - Added by FlagSTInjector: IsST (1 col, excluded from VAE)
+ # - Total: 3 + 6 + 2 = 11 flags (excluding IsST)
+ market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
+ market_flag_cols += ['market_0', 'market_1']
+ if not exclude_isst:
+ market_flag_cols.append('IsST')
+ market_flag_cols = list(dict.fromkeys(market_flag_cols))
- # Drop the original one-hot columns
- cols_to_drop = [c for c in self.industry_cols if c in df.columns]
- if cols_to_drop:
- df = df.drop(cols_to_drop)
+ # Combine all VAE input columns
+ vae_cols = norm_feature_cols + market_flag_cols
- return df
+ print(f" norm_feature_cols: {len(norm_feature_cols)}")
+ print(f" market_flag_cols: {len(market_flag_cols)}")
+ print(f" Total VAE input columns: {len(vae_cols)}")
+ # Verify all columns exist
+ missing_cols = [c for c in vae_cols if c not in df.columns]
+ if missing_cols:
+ print(f"WARNING: Missing columns: {missing_cols}")
+ # Add missing columns as zeros
+ for col in missing_cols:
+ df = df.with_columns(pl.lit(0).alias(col))
-class IndusNtrlInjector:
- """
- Industry Neutralization Injector: Industry neutralization for features.
- For each feature, subtract the industry mean (grouped by indus_idx) from the feature value.
- Creates new columns with "_ntrl" suffix while keeping original columns.
+ # Select features and convert to numpy
+ features_df = df.select(vae_cols)
+ features = features_df.to_numpy().astype(np.float32)
- IMPORTANT: Industry neutralization must be done PER DATETIME (cross-sectional),
- not across the entire dataset. This matches qlib's cal_indus_ntrl behavior.
- """
- def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
- self.feature_cols = feature_cols
- self.suffix = suffix
+ # Handle any remaining NaN/Inf values
+ features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
- def process(self, df: pl.DataFrame) -> pl.DataFrame:
- """Apply industry neutralization to specified features."""
- print(f"Applying IndusNtrlInjector to {len(self.feature_cols)} features...")
+ print(f"Feature matrix shape: {features.shape}")
- # Filter to only columns that exist
- existing_cols = [c for c in self.feature_cols if c in df.columns]
+ # Verify dimensions
+ if features.shape[1] != VAE_INPUT_DIM:
+ print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}")
+ diff = VAE_INPUT_DIM - features.shape[1]
+ if diff > 0:
+ print(f" Difference: {diff} columns missing")
+ else:
+ print(f" Difference: {-diff} extra columns")
- for col in existing_cols:
- ntrl_col = f"{col}{self.suffix}"
- # Calculate industry mean PER DATETIME and subtract from feature
- # This is the CORRECT cross-sectional neutralization
- df = df.with_columns([
- (pl.col(col) - pl.col(col).mean().over(['datetime', 'indus_idx'])).alias(ntrl_col)
- ])
+ # Generate embedding column names
+ embedding_cols = [f"embedding_{i}" for i in range(32)]
- return df
+ return features, embedding_cols
-class RobustZScoreNorm:
+def prepare_vae_features_from_df(
+ df: pl.DataFrame,
+ exclude_isst: bool = True
+) -> Tuple[np.ndarray, List[str]]:
"""
- Robust Z-Score Normalization: Per datetime normalization.
- (x - median) / (1.4826 * MAD) where MAD = median(|x - median|)
- Clip outliers at [-3, 3].
-
- Can use pre-fitted parameters from qlib's pickled processor:
- # Load from qlib pickle
- with open('proc_list.proc', 'rb') as f:
- proc_list = pickle.load(f)
- zscore_proc = proc_list[7] # RobustZScoreNorm is 8th processor
-
- # Create with pre-fitted parameters
- normalizer = RobustZScoreNorm(
- feature_cols=feature_cols,
- use_qlib_params=True,
- qlib_mean=zscore_proc.mean_train,
- qlib_std=zscore_proc.std_train
- )
- """
- def __init__(self, feature_cols: List[str],
- clip_range: Tuple[float, float] = (-3, 3),
- use_qlib_params: bool = False,
- qlib_mean: Optional[np.ndarray] = None,
- qlib_std: Optional[np.ndarray] = None):
- self.feature_cols = feature_cols
- self.clip_range = clip_range
- self.use_qlib_params = use_qlib_params
- self.mean_train = qlib_mean
- self.std_train = qlib_std
-
- if use_qlib_params:
- if qlib_mean is None or qlib_std is None:
- raise ValueError("Must provide qlib_mean and qlib_std when use_qlib_params=True")
- print(f"Using pre-fitted qlib parameters (mean shape: {qlib_mean.shape}, std shape: {qlib_std.shape})")
-
- def process(self, df: pl.DataFrame) -> pl.DataFrame:
- """Apply robust z-score normalization."""
- print(f"Applying RobustZScoreNorm to {len(self.feature_cols)} features...")
-
- # Filter to only columns that exist
- existing_cols = [c for c in self.feature_cols if c in df.columns]
-
- if self.use_qlib_params:
- # Use pre-fitted parameters from qlib (fit once, apply to all dates)
- for i, col in enumerate(existing_cols):
- if i < len(self.mean_train):
- mean_val = float(self.mean_train[i])
- std_val = float(self.std_train[i])
-
- # Apply z-score normalization using pre-fitted params
- df = df.with_columns([
- ((pl.col(col) - mean_val) / (std_val + 1e-8))
- .clip(self.clip_range[0], self.clip_range[1])
- .alias(col)
- ])
- else:
- # Compute per-datetime robust z-score (original behavior)
- for col in existing_cols:
- # First compute median per datetime as a new column
- median_col = f"__median_{col}"
- df = df.with_columns([
- pl.col(col).median().over('datetime').alias(median_col)
- ])
-
- # Then compute absolute deviation
- abs_dev_col = f"__absdev_{col}"
- df = df.with_columns([
- (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col)
- ])
-
- # Compute MAD (median of absolute deviations)
- mad_col = f"__mad_{col}"
- df = df.with_columns([
- pl.col(abs_dev_col).median().over('datetime').alias(mad_col)
- ])
-
- # Compute robust z-score and clip
- df = df.with_columns([
- ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8))
- .clip(self.clip_range[0], self.clip_range[1])
- .alias(col)
- ])
-
- # Clean up temporary columns
- df = df.drop([median_col, abs_dev_col, mad_col])
-
- return df
-
-
-class Fillna:
- """
- Fill NaN: Fill all NaN values with 0 for numeric columns.
- """
- def process(self, df: pl.DataFrame, feature_cols: List[str]) -> pl.DataFrame:
- """Fill NaN values with 0 for specified columns."""
- print("Applying Fillna processor...")
+ Prepare features for VAE encoding from a merged DataFrame.
- # Filter to only columns that exist and are numeric (not boolean)
- existing_cols = [c for c in feature_cols if c in df.columns]
-
- for col in existing_cols:
- # Check column dtype
- dtype = df[col].dtype
- if dtype in [pl.Float32, pl.Float64, pl.Int32, pl.Int64, pl.UInt32, pl.UInt64]:
- df = df.with_columns([pl.col(col).fill_null(0.0).fill_nan(0.0)])
-
- return df
+ VAE input structure (341 features):
+ - feature group (316): 158 alpha158 + 158 alpha158_ntrl
+ - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl
+ - feature_flag group (11): market flags (excluding IsST)
+ NOTE: indus_idx is NOT included in VAE input.
-def apply_feature_pipeline(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[str]]:
- """
- Apply the full feature transformation pipeline.
+ Args:
+ df: Transformed merged DataFrame
+ exclude_isst: Whether to exclude IsST from VAE input
Returns:
- Tuple of (processed DataFrame, list of final feature columns)
+ Tuple of (features numpy array, list of embedding column names)
"""
- print("=" * 60)
- print("Starting feature transformation pipeline")
- print("=" * 60)
+ print("Preparing features for VAE...")
- # Use EXPLICIT alpha158 column order (158 features)
- # This order MUST match what the VAE was trained with
+ # Build alpha158 feature columns
+ alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
alpha158_cols = ALPHA158_COLS.copy()
- # market_ext: 4 features - MUST match original HANDLER_MARKET_EXT config
- # Original: 'Turnover as turnover', 'FreeTurnover as free_turnover',
- # 'log(MarketValue) as log_size', 'con_rating_strength'
- # We already transformed these in load_all_data(), so use lowercase names
- market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
-
- # market_flag: ALL 12 columns as per original HANDLER_MARKET_FLAG config
- # From kline_adjusted: IsZt, IsDt, IsN, IsXD, IsXR, IsDR (6 cols)
- # From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols)
- market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
- 'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop']
-
- print(f"Initial column counts:")
- print(f" Alpha158 features: {len(alpha158_cols)}")
- print(f" Market ext base: {len(market_ext_base)}")
- print(f" Market flag: {len(market_flag_cols)}")
- print(f" Industry flags: {len(INDUSTRY_FLAG_COLS)}")
-
- # Step 1: Diff Processor - adds diff features for market_ext
- diff_processor = DiffProcessor(market_ext_base)
- df = diff_processor.process(df)
-
- # After Diff: market_ext becomes 8 columns (4 base + 4 diff)
- market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]
-
- # Step 2: FlagMarketInjector - adds market_0, market_1 (vocab_size=2)
- flag_injector = FlagMarketInjector()
- df = flag_injector.process(df)
-
- # After FlagMarketInjector: market_flag = 12 + 2 = 14 columns
- market_flag_with_market = market_flag_cols + ['market_0', 'market_1']
-
- # Step 3: FlagSTInjector - create IsST from ST flags
- # Note: ST flags (ST_Y, ST_S) may not be available in parquet data.
- # If available, IsST = ST_S | ST_Y; otherwise create placeholder (all zeros).
- # This maintains compatibility with the VAE's expected input dimension.
- print("Applying FlagSTInjector (creating IsST)...")
- # Check if ST flags are available
- if 'ST_S' in df.columns or 'st_flag::ST_S' in df.columns:
- # Create IsST from actual ST flags
- df = df.with_columns([
- ((pl.col('ST_S').cast(pl.Boolean, strict=False) |
- pl.col('ST_Y').cast(pl.Boolean, strict=False))
- .cast(pl.Int8).alias('IsST'))
- ])
- else:
- # Create placeholder (all zeros) if ST flags not available
- df = df.with_columns([
- pl.lit(0).cast(pl.Int8).alias('IsST')
- ])
- market_flag_with_st = market_flag_with_market + ['IsST']
-
- # Step 4: ColumnRemover - remove specific columns
- # Qlib ColumnRemover removes: ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
- columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
- remover = ColumnRemover(columns_to_remove)
- df = remover.process(df)
-
- # Update column lists after removal
- market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
- market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]
-
- # Step 5: FlagToOnehot - convert 29 industry flags to indus_idx
- flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
- df = flag_to_onehot.process(df)
-
- print(f"After FlagToOnehot: industry flags -> indus_idx")
-
- # Step 6 & 7: IndusNtrlInjector - industry neutralization for alpha158 and market_ext
- indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
- df = indus_ntrl_alpha.process(df)
-
- indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
- df = indus_ntrl_ext.process(df)
-
- # After IndusNtrlInjector: each feature gets a _ntrl version
- # IMPORTANT: qlib's IndusNtrlInjector with keep_origin=True produces columns in order
- # [all _ntrl] + [all raw] for EACH feature group, NOT [all raw] + [all _ntrl]
- # This is critical for matching the VAE training feature order!
- alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
- market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]
-
- # Step 8: RobustZScoreNorm - robust z-score normalization
- # Qlib applies RobustZScoreNorm ONLY to ['feature', 'feature_ext'] groups
- # NOT to feature_flag columns (binary flags should not be normalized)
- # NOT to indus_idx (single column industry index)
- #
- # Feature order MUST match what the VAE was trained with:
- # [alpha158_ntrl (158), alpha158 (158), market_ext_ntrl (7), market_ext (7)] = 330 features
- # This order comes from qlib's IndusNtrlInjector which outputs [ntrl] + [raw] for each group
- norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols
-
- print(f"Applying RobustZScoreNorm to {len(norm_feature_cols)} features...")
- print(f" (Excluding {len(market_flag_with_st)} market flags and indus_idx)")
-
- # Load pre-fitted qlib parameters for consistent normalization
- qlib_params = load_qlib_processor_params()
-
- # Verify parameter shape matches expected features
- expected_features = len(norm_feature_cols)
- if qlib_params['mean_train'].shape[0] != expected_features:
- print(f"WARNING: Feature count mismatch! Expected {expected_features}, "
- f"but qlib params have {qlib_params['mean_train'].shape[0]}")
- print(f" This means the feature order/columns may not match what the VAE was trained with.")
-
- robust_norm = RobustZScoreNorm(
- norm_feature_cols,
- clip_range=(-3, 3),
- use_qlib_params=True,
- qlib_mean=qlib_params['mean_train'],
- qlib_std=qlib_params['std_train']
- )
- df = robust_norm.process(df)
-
- # Step 9: Fillna - fill NaN with 0 for ALL feature columns
- # This includes normalized features, market flags, and indus_idx
- #
- # IMPORTANT: IsST is a placeholder (all zeros) and should NOT be included in VAE input.
- # The VAE was trained with 11 market flags (excluding IsST).
- #
- # Define final feature list first
- final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
-
- fillna = Fillna()
- df = fillna.process(df, final_feature_cols)
-
- # Final feature list breakdown for VAE input:
- # The VAE model takes feature, feature_ext, feature_flag groups (indus_idx is separate)
- # After ColumnRemover removes IsN, IsZt, IsDt:
- # - From kline_adjusted: IsXD, IsXR, IsDR (3 cols)
- # - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols)
- # - Added by FlagMarketInjector: market_0, market_1 (2 cols)
- # - Added by FlagSTInjector: IsST (1 col, placeholder if ST flags not available)
- # - Total market flags: 3 + 6 + 2 + 1 = 12 (IsST excluded from VAE input)
- #
- # Total features:
- # - norm_feature_cols: 158 + 158 + 7 + 7 = 330
- # - market_flag_with_st: 12 (including IsST)
- # - indus_idx: 1
- # - Total: 330 + 12 + 1 = 343 features
- #
- # VAE input dimension (feature + feature_ext + feature_flag only, no indus_idx):
- # - 316 (alpha158 + ntrl) + 14 (market_ext + ntrl) + 11 (flags, excluding IsST) = 341
-
- # Exclude IsST from VAE input features (it's a placeholder)
- market_flag_for_vae = [c for c in market_flag_with_st if c != 'IsST']
-
- print("=" * 60)
- print(f"Pipeline complete. Final feature count: {len(final_feature_cols)}")
- print(f"Expected VAE input dim: {VAE_INPUT_DIM}")
- print(f" norm_feature_cols: {len(norm_feature_cols)}")
- print(f" market_flag_for_vae (excluding IsST): {len(market_flag_for_vae)}")
- print(f" indus_idx: 1")
- print("=" * 60)
-
- # Verify we have the expected number of features
- vae_feature_count = len(norm_feature_cols) + len(market_flag_for_vae)
- if vae_feature_count != VAE_INPUT_DIM:
- print(f"WARNING: Feature count mismatch! Expected {VAE_INPUT_DIM}, got {vae_feature_count}")
- print(f"Difference: {vae_feature_count - VAE_INPUT_DIM} columns")
- print(f"Market flag columns for VAE ({len(market_flag_for_vae)}): {market_flag_for_vae}")
- else:
- print(f"✓ Feature count matches VAE input dimension!")
-
- # Return additional lists needed for VAE feature preparation
- return df, final_feature_cols, norm_feature_cols, market_flag_for_vae
-
+ # Build market_ext feature columns (with diff, minus removed columns)
+ market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
+ market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
+ market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
+ market_ext_cols = market_ext_with_diff.copy()
-def prepare_vae_features(df: pl.DataFrame, feature_cols: List[str],
- norm_feature_cols: List[str],
- market_flag_for_vae: List[str]) -> np.ndarray:
- """
- Prepare features for VAE encoding.
- Ensure we have exactly VAE_INPUT_DIM features in the correct order.
-
- VAE input structure (341 features):
- - feature group (316): 158 alpha158 + 158 alpha158_ntrl
- - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl
- - feature_flag group (11): market flags (excluding IsST which is a placeholder)
-
- NOTE: indus_idx is NOT included in VAE input (it's used separately by the model).
+ # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
+ norm_feature_cols = (
+ alpha158_ntrl_cols + alpha158_cols +
+ market_ext_ntrl_cols + market_ext_cols
+ )
- Args:
- df: Processed DataFrame
- feature_cols: All feature columns (including indus_idx and IsST)
- norm_feature_cols: Normalized feature columns (330 features)
- market_flag_for_vae: Market flag columns for VAE (11 features, excluding IsST)
- """
- print("Preparing features for VAE...")
+ # Market flag columns (excluding IsST if requested)
+ market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
+ market_flag_cols += ['market_0', 'market_1']
+ if not exclude_isst:
+ market_flag_cols.append('IsST')
+ market_flag_cols = list(dict.fromkeys(market_flag_cols))
- # Construct VAE input columns explicitly in correct order:
- # [norm_feature_cols (330), market_flag_for_vae (11)] = 341 total
- vae_cols = norm_feature_cols + market_flag_for_vae
+ # Combine all VAE input columns
+ vae_cols = norm_feature_cols + market_flag_cols
print(f" norm_feature_cols: {len(norm_feature_cols)}")
- print(f" market_flag_for_vae: {len(market_flag_for_vae)}")
+ print(f" market_flag_cols: {len(market_flag_cols)}")
print(f" Total VAE input columns: {len(vae_cols)}")
# Verify all columns exist
missing_cols = [c for c in vae_cols if c not in df.columns]
if missing_cols:
print(f"WARNING: Missing columns: {missing_cols}")
+ # Add missing columns as zeros
+ for col in missing_cols:
+ df = df.with_columns(pl.lit(0).alias(col))
- # Select features
+ # Select features and convert to numpy
features_df = df.select(vae_cols)
-
- # Convert to numpy
features = features_df.to_numpy().astype(np.float32)
# Handle any remaining NaN/Inf values
@@ -810,89 +290,37 @@ def prepare_vae_features(df: pl.DataFrame, feature_cols: List[str],
# Verify dimensions
if features.shape[1] != VAE_INPUT_DIM:
print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}")
-
- if features.shape[1] < VAE_INPUT_DIM:
- # Pad with zeros
- padding = np.zeros((features.shape[0], VAE_INPUT_DIM - features.shape[1]), dtype=np.float32)
- features = np.concatenate([features, padding], axis=1)
- print(f"Padded to shape: {features.shape}")
+ diff = VAE_INPUT_DIM - features.shape[1]
+ if diff > 0:
+ print(f" Difference: {diff} columns missing")
else:
- # Truncate
- features = features[:, :VAE_INPUT_DIM]
- print(f"Truncated to shape: {features.shape}")
-
- return features
-
-
-def load_vae_model(model_path: str) -> nn.Module:
- """
- Load the VAE model from file.
- """
- print(f"Loading VAE model from {model_path}...")
-
- # Patch torch.load to use CPU
- original_torch_load = torch.load
- def cpu_torch_load(*args, **kwargs):
- kwargs['map_location'] = 'cpu'
- return original_torch_load(*args, **kwargs)
- torch.load = cpu_torch_load
-
- try:
- with open(model_path, "rb") as fin:
- model = pkl.load(fin)
-
- model.eval()
- print(f"Loaded VAE model: {model.__class__.__name__}")
- print(f" Input size: {model.input_size}")
- print(f" Hidden size: {model.hidden_size}")
-
- return model
-
- finally:
- torch.load = original_torch_load
-
+ print(f" Difference: {-diff} extra columns")
-def encode_with_vae(features: np.ndarray, model: nn.Module, batch_size: int = 5000) -> np.ndarray:
- """
- Encode features using the VAE model.
- """
- print(f"Encoding {features.shape[0]} samples with VAE...")
-
- device = torch.device('cpu')
- model = model.to(device)
- model.eval()
-
- all_embeddings = []
-
- with torch.no_grad():
- for i in range(0, len(features), batch_size):
- batch = features[i:i + batch_size]
- batch_tensor = torch.tensor(batch, dtype=torch.float32, device=device)
-
- # Use model.encode() to get mu (the embedding)
- mu, _ = model.encode(batch_tensor)
-
- # Convert to numpy
- embeddings_np = mu.cpu().numpy()
- all_embeddings.append(embeddings_np)
-
- if (i // batch_size + 1) % 10 == 0:
- print(f" Processed {min(i + batch_size, len(features))}/{len(features)} samples...")
-
- embeddings = np.concatenate(all_embeddings, axis=0)
- print(f"Generated embeddings shape: {embeddings.shape}")
+ # Generate embedding column names
+ embedding_cols = [f"embedding_{i}" for i in range(32)]
- return embeddings
+ return features, embedding_cols
def generate_embeddings(
start_date: str = DEFAULT_START_DATE,
end_date: str = DEFAULT_END_DATE,
output_file: Optional[str] = None,
- use_vae: bool = True
+ use_vae: bool = True,
+ robust_zscore_params_path: Optional[str] = None
) -> pl.DataFrame:
"""
Main function to generate embeddings from alpha158_0_7_beta factors.
+
+ Args:
+ start_date: Start date in YYYY-MM-DD format
+ end_date: End date in YYYY-MM-DD format
+ output_file: Optional output parquet file path
+ use_vae: Whether to use VAE encoding (or random embeddings)
+ robust_zscore_params_path: Optional path to robust zscore parameters
+
+ Returns:
+ DataFrame with datetime, instrument, and embedding columns
"""
print("=" * 60)
print(f"Generating Alpha158 0_7 Beta Embeddings")
@@ -900,25 +328,23 @@ def generate_embeddings(
print(f"Use VAE: {use_vae}")
print("=" * 60)
- # Load all data sources
- df_alpha, df_kline, df_flag, df_industry = load_all_data(start_date, end_date)
+ # Initialize pipeline
+ pipeline = FeaturePipeline(
+ robust_zscore_params_path=robust_zscore_params_path
+ )
- # Merge data sources
- df = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
+ # Load data
+ feature_groups = pipeline.load_data(start_date, end_date)
- # Get datetime and instrument columns before processing
- datetime_col = df['datetime'].clone()
- instrument_col = df['instrument'].clone()
+ # Apply transformations - get merged DataFrame
+ df_transformed = pipeline.transform(feature_groups)
- # Apply feature transformation pipeline
- df_processed, feature_cols, norm_feature_cols, market_flag_for_vae = apply_feature_pipeline(df)
+ # Get datetime and instrument columns from merged DataFrame
+ datetime_col = df_transformed['datetime'].to_list()
+ instrument_col = df_transformed['instrument'].to_list()
- # Prepare features for VAE
- features = prepare_vae_features(
- df_processed, feature_cols,
- norm_feature_cols=norm_feature_cols,
- market_flag_for_vae=market_flag_for_vae
- )
+ # Prepare VAE input features from DataFrame
+ features, embedding_cols = prepare_vae_features_from_df(df_transformed)
# Encode with VAE
if use_vae:
@@ -938,11 +364,9 @@ def generate_embeddings(
embeddings = np.random.randn(features.shape[0], 32).astype(np.float32)
# Create output DataFrame
- embedding_cols = [f"embedding_{i}" for i in range(embeddings.shape[1])]
-
result_data = {
- 'datetime': datetime_col.to_list(),
- 'instrument': instrument_col.to_list()
+ 'datetime': datetime_col,
+ 'instrument': instrument_col
}
for i, col_name in enumerate(embedding_cols):
result_data[col_name] = embeddings[:, i].tolist()
@@ -961,16 +385,17 @@ def generate_embeddings(
return df_result
-def load_qlib_processor_params(proc_path: str = None) -> Dict[str, np.ndarray]:
+def load_qlib_processor_params(
+ proc_path: str = None
+) -> dict:
"""
Load pre-fitted processor parameters from qlib's pickle file.
- This demonstrates how to extract the fitted mean/std from qlib's
- RobustZScoreNorm processor for use in standalone code.
+ This is kept for backwards compatibility and reference.
+ The new pipeline uses load_robust_zscore_params() instead.
Args:
- proc_path: Path to qlib's proc_list.proc file.
- If None, uses the path from the original VAE model.
+ proc_path: Path to qlib's proc_list.proc file
Returns:
Dictionary with 'mean_train' and 'std_train' numpy arrays
@@ -983,7 +408,7 @@ def load_qlib_processor_params(proc_path: str = None) -> Dict[str, np.ndarray]:
with open(proc_path, "rb") as fin:
proc_list = pkl.load(fin)
- # Find RobustZScoreNorm processor (index 7 in the list)
+ # Find RobustZScoreNorm processor
zscore_proc = None
for proc in proc_list:
if type(proc).__name__ == "RobustZScoreNorm":
@@ -1008,50 +433,32 @@ def load_qlib_processor_params(proc_path: str = None) -> Dict[str, np.ndarray]:
return params
-# Example usage function
-def generate_embeddings_with_qlib_params(
- start_date: str = DEFAULT_START_DATE,
- end_date: str = DEFAULT_END_DATE,
- output_file: Optional[str] = None
-) -> pl.DataFrame:
- """
- Example of how to use pre-fitted qlib parameters for normalization.
-
- This is an alternative to generate_embeddings() that uses the exact
- same normalization parameters as the original qlib pipeline.
- """
- # Load the pre-fitted parameters
- qlib_params = load_qlib_processor_params()
-
- # Load data (same as regular pipeline)
- df_alpha, df_kline, df_industry = load_all_data(start_date, end_date)
- df = merge_data_sources(df_alpha, df_kline, df_industry)
-
- datetime_col = df['datetime'].clone()
- instrument_col = df['instrument'].clone()
-
- # Process through pipeline, but use qlib params for normalization
- # (This would require modifying apply_feature_pipeline to accept params)
- # For now, this is a demonstration of the pattern
-
- print("\nNote: To use qlib params, modify apply_feature_pipeline() to accept")
- print("qlib_mean and qlib_std arguments and pass them to RobustZScoreNorm")
-
- return df
-
-
if __name__ == "__main__":
import argparse
- parser = argparse.ArgumentParser(description="Generate embeddings from alpha158_0_7_beta factors")
- parser.add_argument("--start-date", type=str, default=DEFAULT_START_DATE,
- help="Start date (YYYY-MM-DD)")
- parser.add_argument("--end-date", type=str, default=DEFAULT_END_DATE,
- help="End date (YYYY-MM-DD)")
- parser.add_argument("--output", type=str, default=None,
- help="Output parquet file path")
- parser.add_argument("--no-vae", action="store_true",
- help="Skip VAE encoding (use random embeddings for testing)")
+ parser = argparse.ArgumentParser(
+ description="Generate embeddings from alpha158_0_7_beta factors"
+ )
+ parser.add_argument(
+ "--start-date", type=str, default=DEFAULT_START_DATE,
+ help="Start date (YYYY-MM-DD)"
+ )
+ parser.add_argument(
+ "--end-date", type=str, default=DEFAULT_END_DATE,
+ help="End date (YYYY-MM-DD)"
+ )
+ parser.add_argument(
+ "--output", type=str, default=None,
+ help="Output parquet file path"
+ )
+ parser.add_argument(
+ "--no-vae", action="store_true",
+ help="Skip VAE encoding (use random embeddings for testing)"
+ )
+ parser.add_argument(
+ "--robust-zscore-params", type=str, default=None,
+ help="Path to robust zscore parameters directory"
+ )
args = parser.parse_args()
@@ -1059,7 +466,8 @@ if __name__ == "__main__":
start_date=args.start_date,
end_date=args.end_date,
output_file=args.output,
- use_vae=not args.no_vae
+ use_vae=not args.no_vae,
+ robust_zscore_params_path=args.robust_zscore_params
)
print("\nDone!")
diff --git a/stock_1d/d033/alpha158_beta/src/__init__.py b/stock_1d/d033/alpha158_beta/src/__init__.py
new file mode 100644
index 0000000..e599eca
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/src/__init__.py
@@ -0,0 +1,53 @@
+"""
+Source package for alpha158_beta experiments.
+
+This package provides modules for data loading, feature transformation,
+and model training for alpha158 beta factor experiments.
+"""
+
+from .processors import (
+ FeaturePipeline,
+ FeatureGroups,
+ DiffProcessor,
+ FlagMarketInjector,
+ FlagSTInjector,
+ ColumnRemover,
+ FlagToOnehot,
+ IndusNtrlInjector,
+ RobustZScoreNorm,
+ Fillna,
+ load_alpha158,
+ load_market_ext,
+ load_market_flags,
+ load_industry_flags,
+ load_all_data,
+ load_robust_zscore_params,
+ filter_stock_universe,
+)
+
+__all__ = [
+ # Main pipeline
+ 'FeaturePipeline',
+ 'FeatureGroups',
+
+ # Processors
+ 'DiffProcessor',
+ 'FlagMarketInjector',
+ 'FlagSTInjector',
+ 'ColumnRemover',
+ 'FlagToOnehot',
+ 'IndusNtrlInjector',
+ 'RobustZScoreNorm',
+ 'Fillna',
+
+ # Loaders
+ 'load_alpha158',
+ 'load_market_ext',
+ 'load_market_flags',
+ 'load_industry_flags',
+ 'load_all_data',
+
+ # Utilities
+ 'load_robust_zscore_params',
+ 'filter_stock_universe',
+]
diff --git a/stock_1d/d033/alpha158_beta/src/processors/__init__.py b/stock_1d/d033/alpha158_beta/src/processors/__init__.py
new file mode 100644
index 0000000..3d94ddc
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/src/processors/__init__.py
@@ -0,0 +1,136 @@
+"""
+Processors package for the alpha158_beta feature pipeline.
+
+This package provides a modular, polars-native data loading and transformation
+pipeline for generating VAE input features from alpha158 beta factors.
+
+Main components:
+- FeatureGroups: Dataclass container for separate feature groups
+- FeaturePipeline: Main orchestrator class for the full pipeline
+- Processors: Individual transformation classes (DiffProcessor, IndusNtrlInjector, etc.)
+- Loaders: Data loading functions for parquet sources
+
+Usage:
+ from processors import FeaturePipeline, FeatureGroups
+
+ pipeline = FeaturePipeline()
+ feature_groups = pipeline.load_data(start_date, end_date)
+ transformed = pipeline.transform(feature_groups)
+ vae_input = pipeline.get_vae_input(transformed)
+"""
+
+from .dataclass import FeatureGroups
+from .loaders import (
+ load_alpha158,
+ load_market_ext,
+ load_market_flags,
+ load_industry_flags,
+ load_all_data,
+ load_parquet_by_date_range,
+ get_date_partitions,
+ # Constants
+ PARQUET_ALPHA158_BETA_PATH,
+ PARQUET_KLINE_PATH,
+ PARQUET_MARKET_FLAG_PATH,
+ PARQUET_INDUSTRY_FLAG_PATH,
+ PARQUET_CON_RATING_PATH,
+ INDUSTRY_FLAG_COLS,
+ MARKET_FLAG_COLS_KLINE,
+ MARKET_FLAG_COLS_MARKET,
+ MARKET_EXT_RAW_COLS,
+)
+from .processors import (
+ DiffProcessor,
+ FlagMarketInjector,
+ FlagSTInjector,
+ ColumnRemover,
+ FlagToOnehot,
+ IndusNtrlInjector,
+ RobustZScoreNorm,
+ Fillna,
+)
+from .pipeline import (
+ FeaturePipeline,
+ load_robust_zscore_params,
+ filter_stock_universe,
+ # Constants
+ ALPHA158_COLS,
+ MARKET_EXT_BASE_COLS,
+ MARKET_FLAG_COLS,
+ COLUMNS_TO_REMOVE,
+ VAE_INPUT_DIM,
+ DEFAULT_ROBUST_ZSCORE_PARAMS_PATH,
+)
+from .exporters import (
+ get_groups,
+ get_groups_from_fg,
+ pack_structs,
+ unpack_struct,
+ dump_to_parquet,
+ dump_to_pickle,
+ dump_to_numpy,
+ dump_features,
+ # Also import old names for backward compatibility
+ get_groups as select_feature_groups_from_df,
+ get_groups_from_fg as select_feature_groups,
+)
+
+__all__ = [
+ # Main classes
+ 'FeaturePipeline',
+ 'FeatureGroups',
+
+ # Processors
+ 'DiffProcessor',
+ 'FlagMarketInjector',
+ 'FlagSTInjector',
+ 'ColumnRemover',
+ 'FlagToOnehot',
+ 'IndusNtrlInjector',
+ 'RobustZScoreNorm',
+ 'Fillna',
+
+ # Loaders
+ 'load_alpha158',
+ 'load_market_ext',
+ 'load_market_flags',
+ 'load_industry_flags',
+ 'load_all_data',
+ 'load_parquet_by_date_range',
+ 'get_date_partitions',
+
+ # Utility functions
+ 'load_robust_zscore_params',
+ 'filter_stock_universe',
+
+ # Exporter functions
+ 'get_groups',
+ 'get_groups_from_fg',
+ 'pack_structs',
+ 'unpack_struct',
+ 'dump_to_parquet',
+ 'dump_to_pickle',
+ 'dump_to_numpy',
+ 'dump_features',
+
+ # Backward compatibility aliases
+ 'select_feature_groups_from_df',
+ 'select_feature_groups',
+
+ # Constants
+ 'PARQUET_ALPHA158_BETA_PATH',
+ 'PARQUET_KLINE_PATH',
+ 'PARQUET_MARKET_FLAG_PATH',
+ 'PARQUET_INDUSTRY_FLAG_PATH',
+ 'PARQUET_CON_RATING_PATH',
+ 'INDUSTRY_FLAG_COLS',
+ 'MARKET_FLAG_COLS_KLINE',
+ 'MARKET_FLAG_COLS_MARKET',
+ 'MARKET_EXT_RAW_COLS',
+ 'ALPHA158_COLS',
+ 'MARKET_EXT_BASE_COLS',
+ 'MARKET_FLAG_COLS',
+ 'COLUMNS_TO_REMOVE',
+ 'VAE_INPUT_DIM',
+ 'DEFAULT_ROBUST_ZSCORE_PARAMS_PATH',
+]
diff --git a/stock_1d/d033/alpha158_beta/src/processors/dataclass.py b/stock_1d/d033/alpha158_beta/src/processors/dataclass.py
new file mode 100644
index 0000000..043f45c
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/src/processors/dataclass.py
@@ -0,0 +1,115 @@
+"""Dataclass definitions for the feature pipeline."""
+
+from dataclasses import dataclass, field
+from typing import Optional, List
+import polars as pl
+
+# Import constants for column categorization
+# Alpha158 base columns (158 features)
+ALPHA158_BASE_COLS = [
+ 'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
+ 'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
+ 'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
+ 'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
+ 'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
+ 'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
+ 'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
+ 'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
+ 'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
+ 'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
+ 'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
+ 'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
+ 'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
+ 'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
+ 'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
+ 'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
+ 'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
+ 'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
+ 'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
+ 'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
+ 'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
+ 'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
+ 'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
+ 'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
+ 'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
+ 'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
+ 'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
+ 'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
+ 'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
+ 'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
+ 'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
+]
+
+# Market extension base columns
+MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
+
+# Market flag base columns (before processors)
+MARKET_FLAG_BASE_COLS = [
+ 'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
+ 'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'
+]
+
+
+@dataclass
+class FeatureGroups:
+ """
+ Container for separate feature groups in the pipeline.
+
+ Keeps feature groups separate throughout the pipeline to avoid
+ unnecessary merging and complex column management.
+
+ Attributes:
+ alpha158: 158 alpha158 features (+ _ntrl after industry neutralization)
+ market_ext: Market extension features (turnover, free_turnover, log_size, con_rating_strength, plus _diff/_ntrl variants)
+ market_flag: Market flag features (IsZt, IsDt, IsXD, etc.)
+ industry: 29 industry flags (converted to indus_idx by FlagToOnehot)
+ indus_idx: Single column industry index (after FlagToOnehot processing)
+ instruments: List of instrument IDs for metadata
+ dates: List of datetime values for metadata
+ """
+ # Core feature groups
+ alpha158: pl.DataFrame # 158 alpha158 features (+ _ntrl after processing)
+ market_ext: pl.DataFrame # Market extension features (+ _ntrl after processing)
+ market_flag: pl.DataFrame # Market flags (12 cols initially; 9 after ColumnRemover drops IsZt/IsDt/IsN, 11 once market_0/market_1 are injected)
+ industry: Optional[pl.DataFrame] = None # 29 industry flags -> indus_idx
+
+ # Processed industry index (separate after FlagToOnehot)
+ indus_idx: Optional[pl.DataFrame] = None # Single column after FlagToOnehot
+
+ # Metadata (extracted from dataframes for easy access)
+ instruments: List[str] = field(default_factory=list)
+ dates: List[int] = field(default_factory=list)
+
+ def extract_metadata(self) -> None:
+ """Extract instrument and date lists from the alpha158 dataframe."""
+ if self.alpha158 is not None and len(self.alpha158) > 0:
+ self.instruments = self.alpha158['instrument'].to_list()
+ self.dates = self.alpha158['datetime'].to_list()
+
+ def merge_for_processors(self) -> pl.DataFrame:
+ """
+ Merge all feature groups into a single DataFrame for processors.
+
+ This is used by processors that need access to multiple groups
+ (e.g., IndusNtrlInjector needs industry index).
+
+ Returns:
+ Merged DataFrame with all features
+ """
+ df = self.alpha158
+
+ # Merge market_ext if not already merged
+ if self.market_ext is not None and self.market_ext is not self.alpha158:
+ df = df.join(self.market_ext, on=['instrument', 'datetime'], how='left')
+
+ # Merge market_flag if not already merged
+ if self.market_flag is not None and self.market_flag is not self.alpha158:
+ df = df.join(self.market_flag, on=['instrument', 'datetime'], how='left')
+
+ # Merge industry/indus_idx if available
+ if self.indus_idx is not None:
+ df = df.join(self.indus_idx, on=['instrument', 'datetime'], how='left')
+ elif self.industry is not None:
+ df = df.join(self.industry, on=['instrument', 'datetime'], how='left')
+
+ return df
diff --git a/stock_1d/d033/alpha158_beta/src/processors/exporters.py b/stock_1d/d033/alpha158_beta/src/processors/exporters.py
new file mode 100644
index 0000000..a656442
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/src/processors/exporters.py
@@ -0,0 +1,523 @@
+"""
+Feature exporters for the alpha158_beta pipeline.
+
+This module provides functions to select and export feature groups from the
+transformed pipeline output. It can be used by both dump_features.py and
+generate_beta_embedding.py.
+
+Feature groups:
+- merged: All columns after transformation
+- alpha158: Alpha158 features + _ntrl versions
+- market_ext: Market extended features + _ntrl + _diff
+- market_flag: Market flag columns
+- vae_input: 341 features specifically curated for VAE training
+
+Struct mode:
+- When pack_struct=True, each feature group is packed into a struct column:
+ - features_alpha158 (316 fields)
+ - features_market_ext (14 fields)
+ - features_market_flag (12 fields, including the IsST placeholder)
+"""
+
+import os
+import pickle
+from pathlib import Path
+from typing import Dict, Any, List, Optional, Union
+
+import numpy as np
+import polars as pl
+
+from .pipeline import (
+ ALPHA158_COLS,
+ MARKET_EXT_BASE_COLS,
+ MARKET_FLAG_COLS,
+ COLUMNS_TO_REMOVE,
+)
+from .dataclass import FeatureGroups
+
+
+# =============================================================================
+# Helper functions for struct packing/unpacking
+# =============================================================================
+
+def pack_structs(df: pl.DataFrame) -> pl.DataFrame:
+ """
+ Pack feature columns into struct columns based on feature groups.
+
+ Creates:
+ - features_alpha158: struct with 316 fields (158 + 158 _ntrl)
+ - features_market_ext: struct with 14 fields (7 + 7 _ntrl)
+ - features_market_flag: struct with 12 fields (11 VAE flags plus the IsST placeholder)
+
+ Args:
+ df: Input DataFrame with flat columns
+
+ Returns:
+ DataFrame with struct columns: instrument, datetime, indus_idx, features_*
+ """
+ # Define column groups
+ alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
+ alpha158_all_cols = alpha158_ntrl_cols + ALPHA158_COLS
+
+ market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
+ market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
+ market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
+ market_ext_all_cols = market_ext_ntrl_cols + market_ext_with_diff
+
+ market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
+ market_flag_cols += ['market_0', 'market_1', 'IsST']
+ market_flag_cols = list(dict.fromkeys(market_flag_cols))
+
+ # Build result with struct columns
+ result_cols = ['instrument', 'datetime']
+
+ # Check if indus_idx exists
+ if 'indus_idx' in df.columns:
+ result_cols.append('indus_idx')
+
+ # Pack alpha158
+ alpha158_cols_in_df = [c for c in alpha158_all_cols if c in df.columns]
+ if alpha158_cols_in_df:
+ result_cols.append(pl.struct(alpha158_cols_in_df).alias('features_alpha158'))
+
+ # Pack market_ext
+ ext_cols_in_df = [c for c in market_ext_all_cols if c in df.columns]
+ if ext_cols_in_df:
+ result_cols.append(pl.struct(ext_cols_in_df).alias('features_market_ext'))
+
+ # Pack market_flag
+ flag_cols_in_df = [c for c in market_flag_cols if c in df.columns]
+ if flag_cols_in_df:
+ result_cols.append(pl.struct(flag_cols_in_df).alias('features_market_flag'))
+
+ return df.select(result_cols)
+
+
+def unpack_struct(df: pl.DataFrame, struct_name: str) -> pl.DataFrame:
+ """
+ Unpack a struct column back into individual columns.
+
+ Args:
+ df: DataFrame containing struct column
+ struct_name: Name of the struct column to unpack
+
+ Returns:
+ DataFrame with struct fields as individual columns
+ """
+ if struct_name not in df.columns:
+ raise ValueError(f"Struct column '{struct_name}' not found in DataFrame")
+
+ # Get the struct field names
+ struct_dtype = df.schema[struct_name]
+ if not isinstance(struct_dtype, pl.Struct):
+ raise ValueError(f"Column '{struct_name}' is not a struct type")
+
+ field_names = struct_dtype.fields
+
+ # Unpack using struct field access
+ unpacked_cols = []
+ for field in field_names:
+ col_name = field.name
+ unpacked_cols.append(
+ pl.col(struct_name).struct.field(col_name).alias(col_name)
+ )
+
+ # Select original columns + unpacked columns
+ other_cols = [c for c in df.columns if c != struct_name]
+ return df.select(other_cols + unpacked_cols)
+
+
+def dump_to_parquet(df: pl.DataFrame, path: str, verbose: bool = True) -> None:
+ """Save DataFrame to parquet file."""
+ if verbose:
+ print(f"Saving to parquet: {path}")
+ df.write_parquet(path)
+ if verbose:
+ print(f" Shape: {df.shape}")
+
+
+def dump_to_pickle(df: pl.DataFrame, path: str, verbose: bool = True) -> None:
+ """Save DataFrame to pickle file."""
+ if verbose:
+ print(f"Saving to pickle: {path}")
+ with open(path, 'wb') as f:
+ pickle.dump(df, f)
+ if verbose:
+ print(f" Shape: {df.shape}")
+
+
+def dump_to_numpy(
+ feature_groups: FeatureGroups,
+ path: str,
+ include_metadata: bool = True,
+ verbose: bool = True
+) -> None:
+ """
+ Save features to numpy format.
+
+ Saves:
+ - features.npy: The feature matrix
+ - metadata.pkl: Column names and metadata (if include_metadata=True)
+ """
+ if verbose:
+ print(f"Saving to numpy: {path}")
+
+ # Merge all groups for numpy array
+ df = feature_groups.merge_for_processors()
+
+ # Get all feature columns (exclude instrument, datetime, indus_idx)
+ feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']]
+
+ # Extract features
+ features = df.select(feature_cols).to_numpy().astype(np.float32)
+ features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Save features
+ base_path = Path(path)
+ if base_path.suffix == '':
+ base_path = Path(str(path) + '.npy')
+
+ np.save(str(base_path), features)
+
+ if verbose:
+ print(f" Features shape: {features.shape}")
+
+ # Save metadata if requested
+ if include_metadata:
+ metadata_path = str(base_path).replace('.npy', '_metadata.pkl')
+ metadata = {
+ 'feature_cols': feature_cols,
+ 'instruments': df['instrument'].to_list(),
+ 'dates': df['datetime'].to_list(),
+ 'n_features': len(feature_cols),
+ 'n_samples': len(df),
+ }
+ with open(metadata_path, 'wb') as f:
+ pickle.dump(metadata, f)
+ if verbose:
+ print(f" Metadata saved to: {metadata_path}")
+
+
+def get_groups(
+ df: pl.DataFrame,
+ groups_to_dump: List[str],
+ verbose: bool = True,
+ use_struct: bool = False,
+) -> Dict[str, Any]:
+ """
+ Select which feature groups to include in the output from a merged DataFrame.
+
+ Args:
+ df: Transformed merged DataFrame (flat or with struct columns)
+ groups_to_dump: List of groups to include ('alpha158', 'market_ext', 'market_flag', 'merged', 'vae_input')
+ verbose: Whether to print progress
+ use_struct: If True, pack feature columns into a single 'features' struct column.
+ If df already has struct columns (pack_struct=True was used in pipeline),
+ this function will automatically handle them.
+
+ Returns:
+ Dictionary with selected DataFrames and metadata
+ """
+ # Check if df already has struct columns from pipeline.pack_struct
+ has_struct_cols = any(
+ isinstance(df.schema.get(c), pl.Struct)
+ for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']
+ )
+
+ result = {}
+
+ if 'merged' in groups_to_dump:
+ if has_struct_cols:
+ # Already has struct columns from pipeline.pack_struct
+ result['merged'] = df
+ elif use_struct:
+ # Keep instrument, datetime, and pack rest into struct
+ feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']]
+ result['merged'] = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
+ else:
+ result['merged'] = df
+ if verbose:
+ print(f"Merged features: {result['merged'].shape}")
+
+ if 'alpha158' in groups_to_dump:
+ # Check if struct column already exists from pipeline.pack_struct
+ if has_struct_cols and 'features_alpha158' in df.columns:
+ result['alpha158'] = df.select(['instrument', 'datetime', 'features_alpha158'])
+ else:
+ # Select alpha158 columns from merged DataFrame
+ alpha_cols = ['instrument', 'datetime'] + [c for c in ALPHA158_COLS if c in df.columns]
+ # Also include _ntrl versions
+ alpha_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS if f"{c}_ntrl" in df.columns]
+ alpha_cols += alpha_ntrl_cols
+ df_alpha = df.select(alpha_cols)
+ if use_struct:
+ feature_cols = [c for c in df_alpha.columns if c not in ['instrument', 'datetime']]
+ df_alpha = df_alpha.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
+ result['alpha158'] = df_alpha
+ if verbose:
+ print(f"Alpha158 features: {result['alpha158'].shape}")
+
+ if 'market_ext' in groups_to_dump:
+ # Check if struct column already exists from pipeline.pack_struct
+ if has_struct_cols and 'features_market_ext' in df.columns:
+ result['market_ext'] = df.select(['instrument', 'datetime', 'features_market_ext'])
+ else:
+ # Select market_ext columns from merged DataFrame
+ market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
+ market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
+ ext_cols = ['instrument', 'datetime'] + market_ext_with_diff
+ # Also include _ntrl versions
+ ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff if f"{c}_ntrl" in df.columns]
+ ext_cols += ext_ntrl_cols
+ df_ext = df.select(ext_cols)
+ if use_struct:
+ feature_cols = [c for c in df_ext.columns if c not in ['instrument', 'datetime']]
+ df_ext = df_ext.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
+ result['market_ext'] = df_ext
+ if verbose:
+ print(f"Market ext features: {result['market_ext'].shape}")
+
+ if 'market_flag' in groups_to_dump:
+ # Check if struct column already exists from pipeline.pack_struct
+ if has_struct_cols and 'features_market_flag' in df.columns:
+ result['market_flag'] = df.select(['instrument', 'datetime', 'features_market_flag'])
+ else:
+ # Select market_flag columns from merged DataFrame
+ flag_cols = ['instrument', 'datetime']
+ flag_cols += [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE and c in df.columns]
+ flag_cols += ['market_0', 'market_1', 'IsST'] if all(c in df.columns for c in ['market_0', 'market_1', 'IsST']) else []
+ flag_cols = list(dict.fromkeys(flag_cols)) # Remove duplicates
+ df_flag = df.select(flag_cols)
+ if use_struct:
+ feature_cols = [c for c in df_flag.columns if c not in ['instrument', 'datetime']]
+ df_flag = df_flag.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
+ result['market_flag'] = df_flag
+ if verbose:
+ print(f"Market flag features: {result['market_flag'].shape}")
+
+ if 'vae_input' in groups_to_dump:
+ # Get VAE input columns from merged DataFrame
+ # VAE input = 330 normalized features + 11 market flags = 341 features
+ # Note: indus_idx is NOT included in VAE input
+
+ # Build alpha158 feature columns (158 original + 158 _ntrl = 316)
+ alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
+ alpha158_cols = ALPHA158_COLS.copy()
+
+ # Build market_ext feature columns (7 original + 7 _ntrl = 14)
+ market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
+ market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
+ market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
+ market_ext_cols = market_ext_with_diff.copy()
+
+ # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
+ norm_feature_cols = (
+ alpha158_ntrl_cols + alpha158_cols +
+ market_ext_ntrl_cols + market_ext_cols
+ )
+
+ # Market flag columns (excluding IsST)
+ market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
+ market_flag_cols += ['market_0', 'market_1']
+ market_flag_cols = list(dict.fromkeys(market_flag_cols))
+
+ # Combine all VAE input columns (341 total)
+ vae_cols = norm_feature_cols + market_flag_cols
+
+ if use_struct:
+ # Pack all features into a single struct column
+ result['vae_input'] = df.select([
+ 'instrument',
+ 'datetime',
+ pl.struct(vae_cols).alias('features')
+ ])
+ if verbose:
+ print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)")
+ else:
+ # Always keep datetime and instrument as index columns
+ # Select features with index columns first
+ result['vae_input'] = df.select(['instrument', 'datetime'] + vae_cols)
+ if verbose:
+ print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)")
+
+ return result
+
+
def get_groups_from_fg(
    feature_groups: FeatureGroups,
    groups_to_dump: List[str],
    verbose: bool = True,
    use_struct: bool = False,
) -> Dict[str, Any]:
    """
    Select which feature groups to include in the output.

    Args:
        feature_groups: Transformed FeatureGroups container
        groups_to_dump: List of groups to include ('alpha158', 'market_ext', 'market_flag', 'merged', 'vae_input')
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features' struct column

    Returns:
        Dictionary with selected DataFrames and metadata
    """

    def _pack_features(frame: pl.DataFrame, extra_index_cols=()) -> pl.DataFrame:
        # Pack every non-index column of *frame* into one 'features' struct,
        # keeping instrument/datetime as plain index columns.
        index_cols = ['instrument', 'datetime'] + list(extra_index_cols)
        feature_cols = [c for c in frame.columns if c not in index_cols]
        return frame.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])

    result = {}

    if 'merged' in groups_to_dump:
        df = feature_groups.merge_for_processors()
        if use_struct:
            # indus_idx stays out of the struct: it is index-like metadata.
            df = _pack_features(df, ['indus_idx'])
        result['merged'] = df
        if verbose:
            print(f"Merged features: {result['merged'].shape}")

    # The three plain groups only differ by attribute name and log label,
    # so handle them with one loop instead of three copy-pasted branches.
    simple_groups = {
        'alpha158': 'Alpha158 features',
        'market_ext': 'Market ext features',
        'market_flag': 'Market flag features',
    }
    for group, label in simple_groups.items():
        if group not in groups_to_dump:
            continue
        df = getattr(feature_groups, group)
        if use_struct:
            df = _pack_features(df)
        result[group] = df
        if verbose:
            print(f"{label}: {df.shape}")

    if 'vae_input' in groups_to_dump:
        # Get VAE input columns from already-transformed feature_groups
        # VAE input = 330 normalized features + 11 market flags = 341 features
        # Note: indus_idx is NOT included in VAE input
        df = feature_groups.merge_for_processors()

        # Build alpha158 feature columns (158 original + 158 _ntrl = 316)
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_cols = ALPHA158_COLS.copy()

        # Build market_ext feature columns (7 original + 7 _ntrl = 14)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_cols = market_ext_with_diff.copy()

        # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        norm_feature_cols = (
            alpha158_ntrl_cols + alpha158_cols +
            market_ext_ntrl_cols + market_ext_cols
        )

        # Market flag columns (excluding IsST)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))

        # Combine all VAE input columns (341 total)
        vae_cols = norm_feature_cols + market_flag_cols

        if use_struct:
            # Pack all features into a single struct column
            result['vae_input'] = df.select([
                'instrument',
                'datetime',
                pl.struct(vae_cols).alias('features')
            ])
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)")
        else:
            # Always keep datetime and instrument as index columns
            result['vae_input'] = df.select(['instrument', 'datetime'] + vae_cols)
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)")

    return result
+
+
def dump_features(
    df: pl.DataFrame,
    output_path: str,
    output_format: str = 'parquet',
    groups: List[str] = None,
    verbose: bool = True,
    use_struct: bool = False,
) -> None:
    """
    Dump features to file.

    Args:
        df: Transformed merged DataFrame
        output_path: Output file path
        output_format: Output format ('parquet', 'pickle', 'numpy')
        groups: Feature groups to dump (default: ['merged'])
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features' struct column

    Raises:
        NotImplementedError: If output_format is 'numpy' (use FeaturePipeline directly).
    """
    if groups is None:
        groups = ['merged']

    # Select feature groups from merged DataFrame
    outputs = get_groups(df, groups, verbose, use_struct)

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # For numpy, we use FeatureGroups (need to convert df back)
    # This is a simplified version - for numpy, use dump_to_numpy directly
    if output_format == 'numpy':
        raise NotImplementedError(
            "For numpy output, use the FeaturePipeline directly. "
            "This function handles parquet/pickle output only."
        )

    # Both formats share the same routing logic; only the writer differs.
    writer = dump_to_pickle if output_format == 'pickle' else dump_to_parquet

    if 'merged' in outputs:
        # The merged group goes to output_path as-is.
        writer(outputs['merged'], output_path, verbose=verbose)
    else:
        # One file per group, suffixed with the group name
        # (single-group and multi-group cases are identical).
        base_path = Path(output_path)
        for key, df_out in outputs.items():
            dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
            writer(df_out, dump_path, verbose=verbose)

    if verbose:
        print("Feature dump complete!")
\ No newline at end of file
diff --git a/stock_1d/d033/alpha158_beta/src/processors/loaders.py b/stock_1d/d033/alpha158_beta/src/processors/loaders.py
new file mode 100644
index 0000000..e1a22fc
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/src/processors/loaders.py
@@ -0,0 +1,279 @@
+"""Data loading functions for the feature pipeline.
+
+Memory Best Practices:
+- Always filter on partition keys (datetime) BEFORE collecting
+- Use streaming for large date ranges (>1 year)
+- Never collect full parquet tables - filter on datetime first
+"""
+
from datetime import datetime, timedelta
from typing import List, Optional, Tuple

import polars as pl
+
# Data paths — parquet dataset roots; load_parquet_by_date_range filters
# these on the 'datetime' partition key before collecting.
PARQUET_ALPHA158_BETA_PATH = "/data/parquet/dataset/stg_1day_wind_alpha158_0_7_beta_1D/"
PARQUET_KLINE_PATH = "/data/parquet/dataset/stg_1day_wind_kline_adjusted_1D/"
PARQUET_MARKET_FLAG_PATH = "/data/parquet/dataset/stg_1day_wind_market_flag_1D/"
PARQUET_INDUSTRY_FLAG_PATH = "/data/parquet/dataset/stg_1day_gds_indus_flag_cc1_1D/"
PARQUET_CON_RATING_PATH = "/data/parquet/dataset/stg_1day_gds_con_rating_1D/"

# Industry flag columns (30 one-hot columns - note: gds_CC29 is not present in the data)
INDUSTRY_FLAG_COLS = [
    'gds_CC10', 'gds_CC11', 'gds_CC12', 'gds_CC20', 'gds_CC21', 'gds_CC22',
    'gds_CC23', 'gds_CC24', 'gds_CC25', 'gds_CC26', 'gds_CC27', 'gds_CC28',
    'gds_CC30', 'gds_CC31', 'gds_CC32', 'gds_CC33', 'gds_CC34', 'gds_CC35',
    'gds_CC36', 'gds_CC37', 'gds_CC40', 'gds_CC41', 'gds_CC42', 'gds_CC43',
    'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70'
]

# Market flag columns, split by source table:
# kline_adjusted supplies the Is* flags; market_flag supplies limit/stop flags.
MARKET_FLAG_COLS_KLINE = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR']
MARKET_FLAG_COLS_MARKET = ['open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop']

# Raw kline columns used by load_market_ext to derive
# turnover, free_turnover and log_size.
MARKET_EXT_RAW_COLS = ['Turnover', 'FreeTurnover', 'MarketValue']
+
+
def get_date_partitions(start_date: str, end_date: str) -> List[str]:
    """
    Generate a list of date partitions to load from Parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format

    Returns:
        List of datetime=YYYYMMDD partition strings for weekdays only
    """
    start = datetime.strptime(start_date, "%Y-%m-%d")
    end = datetime.strptime(end_date, "%Y-%m-%d")

    partitions = []
    one_day = timedelta(days=1)
    current = start
    while current <= end:
        if current.weekday() < 5:  # Monday = 0, Friday = 4
            partitions.append(f"datetime={current.strftime('%Y%m%d')}")
        # BUG FIX: advancing with datetime(year, month, day + 1) raised
        # ValueError at every month boundary (e.g. day 31 -> day 32);
        # timedelta handles month/year rollover correctly.
        current += one_day

    return partitions
+
+
def load_parquet_by_date_range(
    base_path: str,
    start_date: str,
    end_date: str,
    columns: Optional[List[str]] = None,
    collect: bool = True,
    streaming: bool = False
) -> pl.LazyFrame | pl.DataFrame:
    """
    Load parquet data filtered by date range.

    CRITICAL: This function filters on the datetime partition key BEFORE collecting.
    This is essential for memory efficiency with large parquet datasets.

    Args:
        base_path: Base path to the parquet dataset
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        columns: Optional list of columns to select (excluding instrument/datetime)
        collect: If True, return DataFrame (default); if False, return LazyFrame
        streaming: If True, use streaming mode for large datasets (recommended for >1 year)

    Returns:
        Polars DataFrame with the loaded data (LazyFrame if collect=False).
        On any load error, prints the error and returns an empty DataFrame
        with the expected schema instead of raising.
    """
    # Dates are compared as integers (YYYYMMDD) against the partition column.
    start_int = int(start_date.replace("-", ""))
    end_int = int(end_date.replace("-", ""))

    try:
        # Start with lazy scan - DO NOT COLLECT YET
        lf = pl.scan_parquet(base_path)

        # CRITICAL: Filter on partition key FIRST, before any other operations
        # This ensures partition pruning happens at the scan level
        lf = lf.filter(pl.col('datetime') >= start_int)
        lf = lf.filter(pl.col('datetime') <= end_int)

        # Select specific columns if provided (column pruning)
        if columns:
            # Get schema to check which columns exist
            schema = lf.collect_schema()
            available_cols = ['instrument', 'datetime'] + [c for c in columns if c in schema.names()]
            lf = lf.select(available_cols)

        # Collect with optional streaming mode
        if collect:
            if streaming:
                return lf.collect(streaming=True)
            return lf.collect()
        return lf

    except Exception as e:
        print(f"Error loading from {base_path}: {e}")
        # Return empty DataFrame with expected schema
        if columns:
            return pl.DataFrame({
                'instrument': pl.Series([], dtype=pl.String),
                'datetime': pl.Series([], dtype=pl.Int32),
                # BUG FIX: the comprehension previously filtered on the
                # undefined name `c` (instead of `col`), raising NameError
                # whenever this fallback path was reached.
                **{col: pl.Series([], dtype=pl.Float64)
                   for col in columns if col not in ['instrument', 'datetime']}
            })
        return pl.DataFrame({
            'instrument': pl.Series([], dtype=pl.String),
            'datetime': pl.Series([], dtype=pl.Int32)
        })
+
+
def load_alpha158(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load the alpha158 beta factor table from parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and 158 alpha158 features
    """
    print("Loading alpha158_0_7_beta factors...")
    frame = load_parquet_by_date_range(
        PARQUET_ALPHA158_BETA_PATH, start_date, end_date, streaming=streaming
    )
    print(f"  Alpha158 shape: {frame.shape}")
    return frame
+
+
def load_market_ext(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load market extension features from parquet.

    Combines raw kline columns with analyst consensus ratings:
    - Turnover -> turnover (aliased copy)
    - FreeTurnover -> free_turnover (aliased copy)
    - MarketValue -> log_size = log(MarketValue)
    - con_rating_strength: joined from parquet, 0.0 where missing

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, turnover, free_turnover, log_size, con_rating_strength
    """
    print("Loading kline data (market ext columns)...")
    kline = load_parquet_by_date_range(
        PARQUET_KLINE_PATH, start_date, end_date, MARKET_EXT_RAW_COLS, streaming=streaming
    )
    print(f"  Kline (market ext raw) shape: {kline.shape}")

    print("Loading con_rating_strength from parquet...")
    ratings = load_parquet_by_date_range(
        PARQUET_CON_RATING_PATH, start_date, end_date, ['con_rating_strength'], streaming=streaming
    )
    print(f"  Con rating shape: {ratings.shape}")

    # Derive the renamed / log-transformed market-ext columns.
    kline = kline.with_columns(
        pl.col('Turnover').alias('turnover'),
        pl.col('FreeTurnover').alias('free_turnover'),
        pl.col('MarketValue').log().alias('log_size'),
    )
    print(f"  Kline (market ext transformed) shape: {kline.shape}")

    # Left-join ratings so instruments without coverage are kept,
    # then default their rating strength to 0.0.
    kline = (
        kline
        .join(ratings, on=['instrument', 'datetime'], how='left')
        .with_columns(pl.col('con_rating_strength').fill_null(0.0))
    )
    print(f"  Kline (with con_rating) shape: {kline.shape}")

    return kline
+
+
def load_market_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load market flag features from parquet.

    Two sources are inner-joined on (instrument, datetime):
    - kline_adjusted: IsZt, IsDt, IsN, IsXD, IsXR, IsDR
    - market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and 12 market flag columns
    """
    print("Loading market flags from kline_adjusted...")
    kline_flags = load_parquet_by_date_range(
        PARQUET_KLINE_PATH, start_date, end_date, MARKET_FLAG_COLS_KLINE, streaming=streaming
    )
    print(f"  Kline flags shape: {kline_flags.shape}")

    print("Loading market flags from market_flag table...")
    market_flags = load_parquet_by_date_range(
        PARQUET_MARKET_FLAG_PATH, start_date, end_date, MARKET_FLAG_COLS_MARKET, streaming=streaming
    )
    print(f"  Market flag shape: {market_flags.shape}")

    # Inner join keeps only (instrument, datetime) rows present in both sources.
    combined = kline_flags.join(market_flags, on=['instrument', 'datetime'], how='inner')
    print(f"  Combined flags shape: {combined.shape}")

    return combined
+
+
def load_industry_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load one-hot industry flag features from parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and the INDUSTRY_FLAG_COLS
        one-hot industry columns (30 flags; gds_CC29 is absent upstream)
    """
    print("Loading industry flags...")
    flags = load_parquet_by_date_range(
        PARQUET_INDUSTRY_FLAG_PATH, start_date, end_date, INDUSTRY_FLAG_COLS, streaming=streaming
    )
    print(f"  Industry shape: {flags.shape}")
    return flags
+
+
def load_all_data(
    start_date: str,
    end_date: str,
    streaming: bool = False
) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """
    Convenience wrapper that loads all four data sources from Parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        Tuple of (alpha158_df, market_ext_df, market_flag_df, industry_df)
    """
    print(f"Loading data from {start_date} to {end_date}...")

    # Loaders run in the same order as the returned tuple.
    return (
        load_alpha158(start_date, end_date, streaming=streaming),
        load_market_ext(start_date, end_date, streaming=streaming),
        load_market_flags(start_date, end_date, streaming=streaming),
        load_industry_flags(start_date, end_date, streaming=streaming),
    )
diff --git a/stock_1d/d033/alpha158_beta/src/processors/pipeline.py b/stock_1d/d033/alpha158_beta/src/processors/pipeline.py
new file mode 100644
index 0000000..ecf75a9
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/src/processors/pipeline.py
@@ -0,0 +1,539 @@
+"""FeaturePipeline orchestrator for the data loading and transformation pipeline."""
+
+import os
+import json
+import numpy as np
+import polars as pl
+from typing import List, Dict, Optional, Tuple
+from pathlib import Path
+
+from .dataclass import FeatureGroups
+from .loaders import (
+ load_alpha158,
+ load_market_ext,
+ load_market_flags,
+ load_industry_flags,
+ load_all_data,
+ INDUSTRY_FLAG_COLS,
+)
+from .processors import (
+ DiffProcessor,
+ FlagMarketInjector,
+ FlagSTInjector,
+ ColumnRemover,
+ FlagToOnehot,
+ IndusNtrlInjector,
+ RobustZScoreNorm,
+ Fillna
+)
+
# Constants - Import from loaders module (single source of truth)
# INDUSTRY_FLAG_COLS is now imported from .loaders above

# Alpha158 feature columns in explicit order (158 names: 9 candlestick +
# 4 price-level + 29 rolling families x 5 windows). This order also fixes
# the column order of the normalized VAE input built in transform().
ALPHA158_COLS = [
    'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
    'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
    'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
    'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
    'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
    'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
    'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
    'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
    'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
    'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
    'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
    'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
    'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
    'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
    'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
    'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
    'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
    'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
    'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
    'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
    'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
    'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
    'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
    'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
    'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
    'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
    'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
    'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
    'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
    'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
    'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
]

# Market extension base columns (DiffProcessor later adds a _diff for each)
MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

# Market flag columns (before processors)
MARKET_FLAG_COLS = [
    'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
    'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'
]

# Columns to remove after FlagMarketInjector and FlagSTInjector
COLUMNS_TO_REMOVE = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']

# Expected VAE input dimension (330 normalized features + 11 market flags)
VAE_INPUT_DIM = 341

# Default robust zscore parameters path
DEFAULT_ROBUST_ZSCORE_PARAMS_PATH = (
    "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/"
    "data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/"
)
+
+
def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame:
    """
    Restrict *df* to the given stock universe (default csiallx: A-shares
    excluding STAR/BSE).

    Delegates to qshare's filter_instruments, which reads the universe
    membership list from:
    /data/qlib/default/data_ops/target/instruments/csiallx.txt

    Args:
        df: Input DataFrame with datetime and instrument columns
        instruments: Market name for spine creation (default: 'csiallx')

    Returns:
        Filtered DataFrame with only instruments in the specified universe
    """
    # Imported lazily so importing this module does not require qshare.
    from qshare.algo.polars.spine import filter_instruments

    return filter_instruments(df, instruments=instruments)
+
+
def load_robust_zscore_params(
    params_path: str = None
) -> Dict[str, np.ndarray]:
    """
    Load pre-fitted RobustZScoreNorm parameters from numpy files.

    Reads mean_train.npy and std_train.npy from *params_path* and caches
    the result (keyed by path) on the function object, so repeated calls
    do no file I/O and return the same dictionary.

    Args:
        params_path: Directory containing mean_train.npy and std_train.npy.
            If None, uses the default path.

    Returns:
        Dictionary with 'mean_train' and 'std_train' numpy arrays

    Raises:
        FileNotFoundError: If parameter files are not found
    """
    if params_path is None:
        params_path = DEFAULT_ROBUST_ZSCORE_PARAMS_PATH

    # Module-level cache stored as a function attribute.
    cache = getattr(load_robust_zscore_params, '_cached_params', None)
    if cache is None:
        cache = {}
        load_robust_zscore_params._cached_params = cache
    if params_path in cache:
        return cache[params_path]

    print(f"Loading robust zscore parameters from: {params_path}")

    mean_path = os.path.join(params_path, 'mean_train.npy')
    std_path = os.path.join(params_path, 'std_train.npy')

    if not os.path.exists(mean_path):
        raise FileNotFoundError(f"mean_train.npy not found at {mean_path}")
    if not os.path.exists(std_path):
        raise FileNotFoundError(f"std_train.npy not found at {std_path}")

    mean_train = np.load(mean_path)
    std_train = np.load(std_path)

    print(f"Loaded parameters:")
    print(f"  mean_train shape: {mean_train.shape}")
    print(f"  std_train shape: {std_train.shape}")

    # Report the fitting window when a metadata sidecar file exists.
    metadata_path = os.path.join(params_path, 'metadata.json')
    if os.path.exists(metadata_path):
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        print(f"  fitted on: {metadata.get('fit_start_time', 'N/A')} to {metadata.get('fit_end_time', 'N/A')}")

    params = {
        'mean_train': mean_train,
        'std_train': std_train
    }
    cache[params_path] = params

    return params
+
+
+class FeaturePipeline:
+ """
+ Feature pipeline orchestrator for loading and transforming data.
+
+ The pipeline manages:
+ 1. Data loading from parquet sources
+ 2. Feature transformations via processors
+ 3. Output preparation for VAE input
+
+ Usage:
+ pipeline = FeaturePipeline(config)
+ feature_groups = pipeline.load_data(start_date, end_date)
+ transformed = pipeline.transform(feature_groups)
+ vae_input = pipeline.get_vae_input(transformed)
+ """
+
+ def __init__(
+ self,
+ config: Optional[Dict] = None,
+ robust_zscore_params_path: Optional[str] = None
+ ):
+ """
+ Initialize the FeaturePipeline.
+
+ Args:
+ config: Optional configuration dictionary with pipeline settings.
+ If None, uses default configuration.
+ robust_zscore_params_path: Path to robust zscore parameters.
+ If None, uses default path.
+ """
+ self.config = config or {}
+ self.robust_zscore_params_path = robust_zscore_params_path or DEFAULT_ROBUST_ZSCORE_PARAMS_PATH
+
+ # Cache for loaded robust zscore params
+ self._robust_zscore_params = None
+
+ # Initialize processors
+ self._init_processors()
+
+ def _init_processors(self):
+ """Initialize all processors with default configuration."""
+ # Diff processor for market_ext columns
+ self.diff_processor = DiffProcessor(MARKET_EXT_BASE_COLS)
+
+ # Market flag injector
+ self.flag_market_injector = FlagMarketInjector()
+
+ # ST flag injector
+ self.flag_st_injector = FlagSTInjector()
+
+ # Column remover
+ self.column_remover = ColumnRemover(COLUMNS_TO_REMOVE)
+
+ # Industry flag to index converter
+ self.flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
+
+ # Industry neutralization injectors (created on-demand with specific feature lists)
+ # self.indus_ntrl_injector = None # Created per-feature-group
+
+ # Robust zscore normalizer (created on-demand with loaded params)
+ # self.robust_norm = None # Created with loaded params
+
+ # Fillna processor
+ self.fillna = Fillna()
+
+ def load_data(
+ self,
+ start_date: str,
+ end_date: str,
+ filter_universe: bool = True,
+ universe_name: str = 'csiallx',
+ streaming: bool = False
+ ) -> FeatureGroups:
+ """
+ Load data from parquet sources into FeatureGroups container.
+
+ Args:
+ start_date: Start date in YYYY-MM-DD format
+ end_date: End date in YYYY-MM-DD format
+ filter_universe: Whether to filter to a specific stock universe
+ universe_name: Name of the stock universe to filter to
+ streaming: If True, use Polars streaming mode for large datasets (>1 year)
+
+ Returns:
+ FeatureGroups container with loaded data
+ """
+ print("=" * 60)
+ print(f"Loading data from {start_date} to {end_date}")
+ print("=" * 60)
+
+ # Load all data sources
+ df_alpha, df_market_ext, df_market_flag, df_industry = load_all_data(
+ start_date, end_date, streaming=streaming
+ )
+
+ # Apply stock universe filter if requested
+ if filter_universe:
+ print(f"Filtering to {universe_name} universe...")
+ df_alpha = filter_stock_universe(df_alpha, instruments=universe_name)
+ df_market_ext = filter_stock_universe(df_market_ext, instruments=universe_name)
+ df_market_flag = filter_stock_universe(df_market_flag, instruments=universe_name)
+ df_industry = filter_stock_universe(df_industry, instruments=universe_name)
+ print(f" After filter - Alpha158 shape: {df_alpha.shape}")
+
+ # Create FeatureGroups container
+ feature_groups = FeatureGroups(
+ alpha158=df_alpha,
+ market_ext=df_market_ext,
+ market_flag=df_market_flag,
+ industry=df_industry
+ )
+
+ # Extract metadata
+ feature_groups.extract_metadata()
+
+ print(f"Loaded {len(feature_groups.instruments)} samples")
+ print("=" * 60)
+
+ return feature_groups
+
    def transform(
        self,
        feature_groups: FeatureGroups,
        pack_struct: bool = False
    ) -> pl.DataFrame:
        """
        Apply feature transformation pipeline to FeatureGroups.

        The pipeline applies processors in the following order (the order is
        load-bearing: neutralization needs indus_idx from step 5, and
        normalization needs the _ntrl columns from step 6):
        1. DiffProcessor - adds diff features to market_ext
        2. FlagMarketInjector - adds market_0, market_1 to market_flag
        3. FlagSTInjector - adds IsST to market_flag
        4. ColumnRemover - removes log_size_diff, IsN, IsZt, IsDt
        5. FlagToOnehot - converts industry flags to indus_idx
        6. IndusNtrlInjector - industry neutralization for alpha158 and market_ext
        7. RobustZScoreNorm - robust z-score normalization
        8. Fillna - fill NaN values with 0

        Args:
            feature_groups: FeatureGroups container with loaded data
            pack_struct: If True, pack each feature group into a struct column
                (features_alpha158, features_market_ext, features_market_flag).
                If False (default), return flat DataFrame with all columns merged.

        Returns:
            Merged DataFrame with all transformed features (pl.DataFrame)
        """
        print("=" * 60)
        print("Starting feature transformation pipeline")
        print("=" * 60)

        # Merge all groups for processing
        df = feature_groups.merge_for_processors()

        # Step 1: Diff Processor - adds diff features for market_ext
        df = self.diff_processor.process(df)

        # Step 2: FlagMarketInjector - adds market_0, market_1
        df = self.flag_market_injector.process(df)

        # Step 3: FlagSTInjector - adds IsST
        df = self.flag_st_injector.process(df)

        # Step 4: ColumnRemover - removes specific columns
        df = self.column_remover.process(df)

        # Step 5: FlagToOnehot - converts industry flags to indus_idx
        df = self.flag_to_onehot.process(df)

        # Step 6: IndusNtrlInjector - industry neutralization
        # For alpha158 features
        indus_ntrl_alpha = IndusNtrlInjector(ALPHA158_COLS, suffix='_ntrl')
        # NOTE(review): process() takes the same merged frame as both feature
        # and industry source — presumably it reads indus_idx from the second
        # argument; confirm against IndusNtrlInjector's definition.
        df = indus_ntrl_alpha.process(df, df)  # Pass df as both feature and industry source

        # For market_ext features (with diff columns)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        # Remove columns that were dropped (log_size_diff is in COLUMNS_TO_REMOVE)
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        indus_ntrl_ext = IndusNtrlInjector(market_ext_with_diff, suffix='_ntrl')
        df = indus_ntrl_ext.process(df, df)

        # Step 7: RobustZScoreNorm - robust z-score normalization
        # Load parameters lazily and cache them on the instance
        if self._robust_zscore_params is None:
            self._robust_zscore_params = load_robust_zscore_params(self.robust_zscore_params_path)

        # Build the list of features to normalize
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]

        # Feature order for VAE: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        # This order must match the order the mean/std parameters were fitted in.
        norm_feature_cols = (
            alpha158_ntrl_cols + ALPHA158_COLS +
            market_ext_ntrl_cols + market_ext_with_diff
        )

        print(f"Applying RobustZScoreNorm to {len(norm_feature_cols)} features...")

        robust_norm = RobustZScoreNorm(
            norm_feature_cols,
            clip_range=(-3, 3),
            use_qlib_params=True,
            qlib_mean=self._robust_zscore_params['mean_train'],
            qlib_std=self._robust_zscore_params['std_train']
        )
        df = robust_norm.process(df)

        # Step 8: Fillna - fill NaN with 0
        # Get all feature columns (flags that survived ColumnRemover plus
        # the injected market_0/market_1/IsST)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1', 'IsST']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))  # Remove duplicates

        final_feature_cols = norm_feature_cols + market_flag_cols + ['indus_idx']
        df = self.fillna.process(df, final_feature_cols)

        print("=" * 60)
        print("Pipeline complete")
        print(f"  Total columns: {len(df.columns)}")
        print(f"  Rows: {len(df)}")

        # Optionally pack features into struct columns
        if pack_struct:
            df = self._pack_into_structs(df)
            print(f"  Packed into struct columns")

        print("=" * 60)
        return df
+
+ def _pack_into_structs(self, df: pl.DataFrame) -> pl.DataFrame:
+ """
+ Pack feature groups into struct columns.
+
+ Creates:
+ - features_alpha158: struct with 316 fields (158 + 158 _ntrl)
+ - features_market_ext: struct with 14 fields (7 + 7 _ntrl)
+ - features_market_flag: struct with 11 fields
+
+ Returns:
+ DataFrame with columns: instrument, datetime, indus_idx, features_*
+ """
+ # Define column groups
+ alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
+ alpha158_all_cols = alpha158_ntrl_cols + ALPHA158_COLS
+
+ market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
+ market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
+ market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
+ market_ext_all_cols = market_ext_ntrl_cols + market_ext_with_diff
+
+ market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
+ market_flag_cols += ['market_0', 'market_1', 'IsST']
+ market_flag_cols = list(dict.fromkeys(market_flag_cols))
+
+ # Build result with struct columns
+ result_cols = ['instrument', 'datetime']
+
+ # Check if indus_idx exists
+ if 'indus_idx' in df.columns:
+ result_cols.append('indus_idx')
+
+ # Pack alpha158
+ alpha158_cols_in_df = [c for c in alpha158_all_cols if c in df.columns]
+ if alpha158_cols_in_df:
+ result_cols.append(pl.struct(alpha158_cols_in_df).alias('features_alpha158'))
+
+ # Pack market_ext
+ ext_cols_in_df = [c for c in market_ext_all_cols if c in df.columns]
+ if ext_cols_in_df:
+ result_cols.append(pl.struct(ext_cols_in_df).alias('features_market_ext'))
+
+ # Pack market_flag
+ flag_cols_in_df = [c for c in market_flag_cols if c in df.columns]
+ if flag_cols_in_df:
+ result_cols.append(pl.struct(flag_cols_in_df).alias('features_market_flag'))
+
+ return df.select(result_cols)
+
+ def get_vae_input(
+ self,
+ df: pl.DataFrame | FeatureGroups,
+ exclude_isst: bool = True
+ ) -> np.ndarray:
+ """
+ Prepare VAE input features from transformed DataFrame or FeatureGroups.
+
+ VAE input structure (341 features):
+ - feature group (316): 158 alpha158 + 158 alpha158_ntrl
+ - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl
+ - feature_flag group (11): market flags (excluding IsST)
+
+ NOTE: indus_idx is NOT included in VAE input.
+
+ Args:
+ df: Transformed DataFrame or FeatureGroups container
+ exclude_isst: Whether to exclude IsST from VAE input (default: True)
+
+ Returns:
+ Numpy array of shape (n_samples, VAE_INPUT_DIM)
+ """
+ print("Preparing features for VAE...")
+
+ # Accept either DataFrame or FeatureGroups
+ if isinstance(df, FeatureGroups):
+ df = df.merge_for_processors()
+
+ # Build alpha158 feature columns
+ alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
+ alpha158_cols = ALPHA158_COLS.copy()
+
+ # Build market_ext feature columns (with diff, minus removed columns)
+ market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
+ market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
+ market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
+ market_ext_cols = market_ext_with_diff.copy()
+
+ # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
+ norm_feature_cols = (
+ alpha158_ntrl_cols + alpha158_cols +
+ market_ext_ntrl_cols + market_ext_cols
+ )
+
+ # Market flag columns (excluding IsST if requested)
+ market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
+ market_flag_cols += ['market_0', 'market_1']
+ if not exclude_isst:
+ market_flag_cols.append('IsST')
+ market_flag_cols = list(dict.fromkeys(market_flag_cols))
+
+ # Combine all VAE input columns
+ vae_cols = norm_feature_cols + market_flag_cols
+
+ print(f" norm_feature_cols: {len(norm_feature_cols)}")
+ print(f" market_flag_cols: {len(market_flag_cols)}")
+ print(f" Total VAE input columns: {len(vae_cols)}")
+
+ # Verify all columns exist
+ missing_cols = [c for c in vae_cols if c not in df.columns]
+ if missing_cols:
+ print(f"WARNING: Missing columns: {missing_cols}")
+
+ # Select features and convert to numpy
+ features_df = df.select(vae_cols)
+ features = features_df.to_numpy().astype(np.float32)
+
+ # Handle any remaining NaN/Inf values
+ features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
+
+ print(f"Feature matrix shape: {features.shape}")
+
+ # Verify dimensions
+ if features.shape[1] != VAE_INPUT_DIM:
+ print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}")
+
+ if features.shape[1] < VAE_INPUT_DIM:
+ # Pad with zeros
+ padding = np.zeros(
+ (features.shape[0], VAE_INPUT_DIM - features.shape[1]),
+ dtype=np.float32
+ )
+ features = np.concatenate([features, padding], axis=1)
+ print(f"Padded to shape: {features.shape}")
+ else:
+ # Truncate
+ features = features[:, :VAE_INPUT_DIM]
+ print(f"Truncated to shape: {features.shape}")
+
+ return features
diff --git a/stock_1d/d033/alpha158_beta/src/processors/processors.py b/stock_1d/d033/alpha158_beta/src/processors/processors.py
new file mode 100644
index 0000000..c0ee302
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/src/processors/processors.py
@@ -0,0 +1,447 @@
+"""Processor classes for feature transformation.
+
+Processors are stateless, functional components that operate on
+specific feature groups or merged DataFrames.
+"""
+
+import numpy as np
+import polars as pl
+from typing import List, Tuple, Optional, Dict
+
+
class DiffProcessor:
    """
    Diff Processor: compute first-order differences for market_ext columns.

    For each configured column, diff with period=1 is computed within each
    instrument group, ordered by datetime.
    """

    def __init__(self, columns: List[str]):
        """
        Initialize the DiffProcessor.

        Args:
            columns: List of column names to compute diffs for
        """
        self.columns = columns

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Append {col}_diff columns for every configured column present in df.

        Args:
            df: Input DataFrame with market_ext columns

        Returns:
            DataFrame with added {col}_diff columns
        """
        print("Applying Diff processor...")

        # Build every expression up front and apply them in one with_columns()
        # call; order_by='datetime' keeps each instrument's diff chronological.
        targets = [name for name in self.columns if name in df.columns]
        exprs = []
        for name in targets:
            exprs.append(
                pl.col(name)
                .diff()
                .over('instrument', order_by='datetime')
                .alias(f"{name}_diff")
            )

        return df.with_columns(exprs) if exprs else df
+
+
class FlagMarketInjector:
    """
    Flag Market Injector: Create market_0, market_1 columns based on instrument code.

    Maps to Qlib's map_market_sec logic with vocab_size=2:
    - market_0 (main board): SH60xxx, SZ00xxx
    - market_1 (STAR market / ChiNext): SH688xxx, SH689xxx, SZ300xxx, SZ301xxx

    NOTE: vocab_size=2 (not 3!) - NEEQ/BSE codes (NE4xxxx, NE8xxxx) are NOT included.
    """

    def process(self, df: pl.DataFrame, instrument_col: str = 'instrument') -> pl.DataFrame:
        """
        Add market_0, market_1 columns.

        Args:
            df: Input DataFrame with instrument column
            instrument_col: Name of the instrument column

        Returns:
            DataFrame with added market_0 and market_1 columns
        """
        print("Applying FlagMarketInjector (vocab_size=2)...")

        # Convert instrument to string and pad to 6 digits
        inst_str = pl.col(instrument_col).cast(pl.String).str.zfill(6)

        # STAR market (SH688xxx, SH689xxx) and ChiNext (SZ300xxx, SZ301xxx)
        is_sh_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689')
        is_sz_chi = inst_str.str.starts_with('300') | inst_str.str.starts_with('301')

        # BUG FIX: 688/689 codes also start with '6', so the previous
        # starts_with('6') main-board test set market_0=1 AND market_1=1 for
        # STAR-market stocks. Exclude the STAR prefixes from the main board
        # so the two flags stay mutually exclusive.
        # (SZ ChiNext starts with '3', so it never collides with the '0' test.)
        is_main = (
            (inst_str.str.starts_with('6') | inst_str.str.starts_with('0'))
            & ~is_sh_star
        )

        df = df.with_columns([
            # market_0 = main board (SH main + SZ main)
            is_main.cast(pl.Int8).alias('market_0'),
            # market_1 = STAR market + ChiNext
            (is_sh_star | is_sz_chi).cast(pl.Int8).alias('market_1')
        ])

        return df
+
+
class FlagSTInjector:
    """
    Flag ST Injector: Create IsST column from ST flags.

    Creates IsST = ST_S | ST_Y when both flags are available (either as plain
    'ST_S'/'ST_Y' columns or as 'st_flag::'-prefixed columns), otherwise
    creates a placeholder column of all zeros.
    """

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add IsST column.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with added IsST column
        """
        print("Applying FlagSTInjector (creating IsST)...")

        # BUG FIX: the previous guard accepted 'st_flag::ST_S' but then
        # unconditionally referenced pl.col('ST_S')/pl.col('ST_Y'), which
        # raised ColumnNotFoundError for the prefixed variant (and never
        # verified ST_Y at all). Resolve the actual column names first.
        if 'ST_S' in df.columns and 'ST_Y' in df.columns:
            s_col, y_col = 'ST_S', 'ST_Y'
        elif 'st_flag::ST_S' in df.columns and 'st_flag::ST_Y' in df.columns:
            s_col, y_col = 'st_flag::ST_S', 'st_flag::ST_Y'
        else:
            s_col = y_col = None

        if s_col is not None:
            # Create IsST from actual ST flags
            df = df.with_columns([
                ((pl.col(s_col).cast(pl.Boolean, strict=False) |
                  pl.col(y_col).cast(pl.Boolean, strict=False))
                 .cast(pl.Int8).alias('IsST'))
            ])
        else:
            # Create placeholder (all zeros) if ST flags not available
            df = df.with_columns([
                pl.lit(0).cast(pl.Int8).alias('IsST')
            ])

        return df
+
+
class ColumnRemover:
    """
    Column Remover: drop columns that are not needed for the VAE input.
    """

    def __init__(self, columns_to_remove: List[str]):
        """
        Initialize the ColumnRemover.

        Args:
            columns_to_remove: List of column names to remove
        """
        self.columns_to_remove = columns_to_remove

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Drop the configured columns, silently ignoring any that are absent.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with specified columns removed
        """
        print(f"Applying ColumnRemover (removing {len(self.columns_to_remove)} columns)...")

        # Restrict the drop list to columns actually present in df
        present = [name for name in self.columns_to_remove if name in df.columns]
        return df.drop(present) if present else df
+
+
class FlagToOnehot:
    """
    Flag To Onehot: collapse 29 one-hot industry columns into a single indus_idx.

    For each row, indus_idx is set to the index of the industry column whose
    value is 1 (-1 when no industry flag is set).
    """

    def __init__(self, industry_cols: List[str]):
        """
        Initialize the FlagToOnehot.

        Args:
            industry_cols: List of 29 industry flag column names
        """
        self.industry_cols = industry_cols

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Replace the one-hot industry columns with a single indus_idx column.

        Args:
            df: Input DataFrame with industry flag columns

        Returns:
            DataFrame with indus_idx column (original industry columns removed)
        """
        print("Applying FlagToOnehot (converting 29 industry flags to indus_idx)...")

        present = [c for c in self.industry_cols if c in df.columns]

        # Fold the flags into a when/then chain; -1 is the "no industry"
        # default. Because each column wraps the previous chain, columns later
        # in the list take precedence if several flags are set on one row.
        chain = pl.lit(-1)
        for position, name in enumerate(self.industry_cols):
            if name in present:
                chain = pl.when(pl.col(name) == 1).then(position).otherwise(chain)

        df = df.with_columns([chain.alias('indus_idx')])

        # The one-hot columns are now redundant
        if present:
            df = df.drop(present)

        return df
+
+
class IndusNtrlInjector:
    """
    Industry Neutralization Injector: Industry neutralization for features.

    For each feature, subtracts the cross-sectional industry mean (grouped by
    datetime and indus_idx) from the feature value. Creates new columns with
    "_ntrl" suffix while keeping original columns.

    IMPORTANT: Industry neutralization is done PER DATETIME (cross-sectional),
    not across the entire dataset. This matches qlib's cal_indus_ntrl behavior.
    """

    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        """
        Initialize the IndusNtrlInjector.

        Args:
            feature_cols: List of feature column names to neutralize
            suffix: Suffix to append to neutralized column names
        """
        self.feature_cols = feature_cols
        self.suffix = suffix

    def process(
        self,
        feature_df: pl.DataFrame,
        industry_df: pl.DataFrame
    ) -> pl.DataFrame:
        """
        Apply industry neutralization to specified features.

        Args:
            feature_df: DataFrame with feature columns to neutralize
            industry_df: DataFrame with indus_idx column (must have instrument,
                datetime). Ignored when feature_df already carries indus_idx.

        Returns:
            DataFrame with added {col}{suffix} columns
        """
        print(f"Applying IndusNtrlInjector to {len(self.feature_cols)} features...")

        # Reuse indus_idx if the caller already merged it in
        if 'indus_idx' in feature_df.columns:
            df = feature_df
        else:
            # Merge industry index into feature dataframe
            df = feature_df.join(
                industry_df.select(['instrument', 'datetime', 'indus_idx']),
                on=['instrument', 'datetime'],
                how='left',
                suffix='_indus'
            )

        # Filter to only columns that exist
        existing_cols = [c for c in self.feature_cols if c in df.columns]

        # Build all neutralization expressions first, then apply in a single
        # with_columns() to avoid one DataFrame copy per column.
        # FIX: the window already partitions by ('datetime', 'indus_idx') and
        # mean() is order-independent, so the previous order_by='datetime'
        # (and the comment claiming it was required) was redundant noise.
        ntrl_exprs = [
            (pl.col(col) - pl.col(col).mean().over(['datetime', 'indus_idx']))
            .alias(f"{col}{self.suffix}")
            for col in existing_cols
        ]

        if ntrl_exprs:
            df = df.with_columns(ntrl_exprs)

        return df
+
+
class RobustZScoreNorm:
    """
    Robust Z-Score Normalization: Per datetime normalization.

    (x - median) / (1.4826 * MAD) where MAD = median(|x - median|)
    Clip outliers at [-3, 3].

    Supports pre-fitted parameters from qlib's pickled processor; the
    parameter arrays are indexed by position in feature_cols:
        normalizer = RobustZScoreNorm(
            feature_cols=feature_cols,
            use_qlib_params=True,
            qlib_mean=zscore_proc.mean_train,
            qlib_std=zscore_proc.std_train
        )
    """

    def __init__(
        self,
        feature_cols: List[str],
        clip_range: Tuple[float, float] = (-3, 3),
        use_qlib_params: bool = False,
        qlib_mean: Optional[np.ndarray] = None,
        qlib_std: Optional[np.ndarray] = None
    ):
        """
        Initialize the RobustZScoreNorm.

        Args:
            feature_cols: List of feature column names to normalize. When
                use_qlib_params=True, qlib_mean[i]/qlib_std[i] must correspond
                to feature_cols[i].
            clip_range: Tuple of (min, max) for clipping normalized values
            use_qlib_params: Whether to use pre-fitted parameters from qlib
            qlib_mean: Pre-fitted mean array from qlib (required if use_qlib_params=True)
            qlib_std: Pre-fitted std array from qlib (required if use_qlib_params=True)

        Raises:
            ValueError: If use_qlib_params=True but parameters are missing.
        """
        self.feature_cols = feature_cols
        self.clip_range = clip_range
        self.use_qlib_params = use_qlib_params
        self.mean_train = qlib_mean
        self.std_train = qlib_std

        if use_qlib_params:
            if qlib_mean is None or qlib_std is None:
                raise ValueError("Must provide qlib_mean and qlib_std when use_qlib_params=True")
            print(f"Using pre-fitted qlib parameters (mean shape: {qlib_mean.shape}, std shape: {qlib_std.shape})")

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply robust z-score normalization.

        Args:
            df: Input DataFrame with feature columns

        Returns:
            DataFrame with normalized feature columns (in-place modification)
        """
        print(f"Applying RobustZScoreNorm to {len(self.feature_cols)} features...")

        if self.use_qlib_params:
            # Use pre-fitted parameters from qlib (fit once, apply to all dates).
            # BUG FIX: parameters are indexed by position in self.feature_cols.
            # The previous code enumerated only the columns present in df, so
            # one missing column shifted every subsequent column onto the
            # WRONG mean/std. Enumerate feature_cols and skip absent columns
            # so each column keeps its fitted parameter index.
            # All expressions are built first and applied in one with_columns()
            # to avoid a DataFrame copy per column.
            n_params = min(len(self.mean_train), len(self.std_train))
            norm_exprs = []
            for i, col in enumerate(self.feature_cols):
                if i >= n_params:
                    break
                if col not in df.columns:
                    continue
                mean_val = float(self.mean_train[i])
                std_val = float(self.std_train[i])
                norm_exprs.append(
                    ((pl.col(col) - mean_val) / (std_val + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                )

            if norm_exprs:
                df = df.with_columns(norm_exprs)
        else:
            # Compute per-datetime robust z-score (original behavior).
            # Build all expressions with temp columns first, then clean up in
            # a single drop() to minimize DataFrame copies.
            existing_cols = [c for c in self.feature_cols if c in df.columns]
            all_exprs = []
            temp_cols = []

            for col in existing_cols:
                # Compute median per datetime
                median_col = f"__median_{col}"
                temp_cols.append(median_col)
                all_exprs.append(
                    pl.col(col).median().over('datetime').alias(median_col)
                )

                # Compute absolute deviation
                abs_dev_col = f"__absdev_{col}"
                temp_cols.append(abs_dev_col)
                all_exprs.append(
                    (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col)
                )

                # Compute MAD (median of absolute deviations)
                mad_col = f"__mad_{col}"
                temp_cols.append(mad_col)
                all_exprs.append(
                    pl.col(abs_dev_col).median().over('datetime').alias(mad_col)
                )

                # Compute robust z-score and clip (overwrites original column);
                # 1e-8 guards against a zero MAD
                all_exprs.append(
                    ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                )

            # Apply all expressions in single with_columns()
            if all_exprs:
                df = df.with_columns(all_exprs)

            # Clean up all temporary columns in single drop()
            if temp_cols:
                df = df.drop(temp_cols)

        return df
+
+
class Fillna:
    """
    Fill NaN: Fill all NaN/null values with 0 for numeric columns.
    """

    def process(
        self,
        df: pl.DataFrame,
        feature_cols: List[str]
    ) -> pl.DataFrame:
        """
        Fill NaN/null values with 0 for specified columns.

        Args:
            df: Input DataFrame
            feature_cols: List of column names to fill NaN values for

        Returns:
            DataFrame with NaN/null values filled with 0
        """
        print("Applying Fillna processor...")

        # BUG FIX: the previous dtype whitelist (Float32/64, Int32/64,
        # UInt32/64) silently skipped small integer columns, including the
        # injected Int8 flags (market_0, market_1, IsST). Use the dtype
        # predicates to cover all numeric columns while still excluding
        # booleans. fill_nan is applied only to float columns — integer
        # columns cannot hold NaN (and fill_nan may reject them).
        fill_exprs = []
        for col in feature_cols:
            if col not in df.columns:
                continue
            dtype = df[col].dtype
            if dtype.is_float():
                fill_exprs.append(pl.col(col).fill_null(0.0).fill_nan(0.0))
            elif dtype.is_integer():
                fill_exprs.append(pl.col(col).fill_null(0))

        # Apply all fill expressions in a single with_columns() to avoid one
        # DataFrame copy per column
        if fill_exprs:
            df = df.with_columns(fill_exprs)

        return df