- Remove split_at_end parameter from pipeline.transform(), always return DataFrame
- Add pack_struct parameter to pack feature groups into struct columns
- Rename exporters: select_feature_groups_from_df -> get_groups, select_feature_groups -> get_groups_from_fg
- Add pack_structs() and unpack_struct() helper functions
- Remove split_from_merged() method from FeatureGroups (no longer needed)
- Rename dump_polars_dataset.py to dump_features.py with --pack-struct flag
- Update README with new CLI usage and struct column documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Branch: master
parent
26a694298d
commit
5109ac4eb3
@ -0,0 +1,265 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Script to generate and dump transformed features from the alpha158_beta pipeline.
|
||||||
|
|
||||||
|
This script provides fine-grained control over the feature generation and dumping process:
|
||||||
|
- Select which feature groups to dump (alpha158, market_ext, market_flag, merged, vae_input)
|
||||||
|
- Choose output format (parquet, pickle, numpy)
|
||||||
|
- Control date range and universe filtering
|
||||||
|
- Save intermediate pipeline outputs
|
||||||
|
- Enable streaming mode for large datasets (>1 year)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Dump all features to parquet
|
||||||
|
python dump_features.py --start-date 2025-01-01 --end-date 2025-01-31
|
||||||
|
|
||||||
|
# Dump only alpha158 features to pickle
|
||||||
|
python dump_features.py --groups alpha158 --format pickle
|
||||||
|
|
||||||
|
# Dump with custom output path
|
||||||
|
python dump_features.py --output /path/to/output.parquet
|
||||||
|
|
||||||
|
# Dump merged features with all columns
|
||||||
|
python dump_features.py --groups merged --verbose
|
||||||
|
|
||||||
|
# Use streaming mode for large date ranges (>1 year)
|
||||||
|
python dump_features.py --start-date 2020-01-01 --end-date 2023-12-31 --streaming
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
# Add src to path for imports
|
||||||
|
SCRIPT_DIR = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(SCRIPT_DIR.parent / 'src'))
|
||||||
|
|
||||||
|
from processors import (
|
||||||
|
FeaturePipeline,
|
||||||
|
FeatureGroups,
|
||||||
|
VAE_INPUT_DIM,
|
||||||
|
ALPHA158_COLS,
|
||||||
|
MARKET_EXT_BASE_COLS,
|
||||||
|
COLUMNS_TO_REMOVE,
|
||||||
|
get_groups,
|
||||||
|
dump_to_parquet,
|
||||||
|
dump_to_pickle,
|
||||||
|
dump_to_numpy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default output directory
|
||||||
|
DEFAULT_OUTPUT_DIR = SCRIPT_DIR.parent / "data"
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_grouped(outputs, output_path, dump_fn, verbose):
    """Dump a dict of feature-group DataFrames using *dump_fn*.

    Dispatch rules (shared by the pickle and parquet formats):
      - 'merged' present: write it to *output_path* as-is.
      - exactly one group: write to ``{stem}_{group}{suffix}``.
      - multiple groups: write each to its own ``{stem}_{group}{suffix}`` file.

    Args:
        outputs: Mapping of group name -> DataFrame.
        output_path: Base output file path.
        dump_fn: Writer callable with signature (df, path, verbose=...).
        verbose: Forwarded to *dump_fn*.
    """
    if 'merged' in outputs:
        dump_fn(outputs['merged'], output_path, verbose=verbose)
        return
    base_path = Path(output_path)
    for key, df_out in outputs.items():
        dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
        dump_fn(df_out, dump_path, verbose=verbose)


def generate_and_dump(
    start_date: str,
    end_date: str,
    output_path: str,
    output_format: str = 'parquet',
    groups: List[str] = None,
    universe: str = 'csiallx',
    filter_universe: bool = True,
    robust_zscore_params_path: Optional[str] = None,
    verbose: bool = True,
    pack_struct: bool = False,
    streaming: bool = False,
) -> None:
    """
    Generate features and dump to file.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        output_path: Output file path
        output_format: Output format ('parquet', 'pickle', 'numpy')
        groups: Feature groups to dump (default: ['merged'])
        universe: Stock universe name
        filter_universe: Whether to filter to stock universe
        robust_zscore_params_path: Path to robust zscore parameters
        verbose: Whether to print progress
        pack_struct: If True, pack each feature group into struct columns
            (features_alpha158, features_market_ext, features_market_flag)
        streaming: If True, use Polars streaming mode for large datasets (>1 year)
    """
    if groups is None:
        groups = ['merged']

    print("=" * 60)
    print("Feature Dump Script")
    print("=" * 60)
    print(f"Date range: {start_date} to {end_date}")
    print(f"Output format: {output_format}")
    print(f"Feature groups: {groups}")
    print(f"Universe: {universe} (filter: {filter_universe})")
    print(f"Pack struct: {pack_struct}")
    print(f"Output path: {output_path}")
    print("=" * 60)

    # Initialize pipeline
    pipeline = FeaturePipeline(
        robust_zscore_params_path=robust_zscore_params_path
    )

    # Load data
    feature_groups = pipeline.load_data(
        start_date, end_date,
        filter_universe=filter_universe,
        universe_name=universe,
        streaming=streaming
    )

    # Apply transformations - pipeline always returns a merged DataFrame now
    df_transformed = pipeline.transform(feature_groups, pack_struct=pack_struct)

    # Select requested feature groups from the merged DataFrame
    outputs = get_groups(df_transformed, groups, verbose, use_struct=False)

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Dump to file(s)
    if output_format == 'numpy':
        # NOTE(review): this passes the *pre-transform* feature_groups, not
        # `outputs`; the original comment claims "merged features" — confirm
        # dump_to_numpy's expected input against its definition.
        dump_to_numpy(feature_groups, output_path, include_metadata=True, verbose=verbose)
    elif output_format == 'pickle':
        _dump_grouped(outputs, output_path, dump_to_pickle, verbose)
    else:  # parquet
        _dump_grouped(outputs, output_path, dump_to_parquet, verbose)

    print("=" * 60)
    print("Feature dump complete!")
    print("=" * 60)
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: parse arguments, resolve defaults, run the dump."""
    cli = argparse.ArgumentParser(
        description="Generate and dump transformed features from alpha158_beta pipeline"
    )

    # Required date range
    cli.add_argument(
        "--start-date", type=str, required=True,
        help="Start date in YYYY-MM-DD format"
    )
    cli.add_argument(
        "--end-date", type=str, required=True,
        help="End date in YYYY-MM-DD format"
    )

    # Output settings
    cli.add_argument(
        "--output", "-o", type=str, default=None,
        help=f"Output file path (default: {DEFAULT_OUTPUT_DIR}/features.parquet)"
    )
    cli.add_argument(
        "--format", "-f", type=str, default='parquet',
        choices=['parquet', 'pickle', 'numpy'],
        help="Output format (default: parquet)"
    )

    # Which feature groups to export
    cli.add_argument(
        "--groups", "-g", type=str, nargs='+', default=['merged'],
        choices=['merged', 'alpha158', 'market_ext', 'market_flag', 'vae_input'],
        help="Feature groups to dump (default: merged)"
    )

    # Universe selection
    cli.add_argument(
        "--universe", type=str, default='csiallx',
        help="Stock universe name (default: csiallx)"
    )
    cli.add_argument(
        "--no-filter-universe", action="store_true",
        help="Disable stock universe filtering"
    )

    # Pre-fitted normalization parameters
    cli.add_argument(
        "--robust-zscore-params", type=str, default=None,
        help="Path to robust zscore parameters directory"
    )

    # Verbosity (verbose defaults on; --quiet overrides it)
    cli.add_argument(
        "--verbose", "-v", action="store_true", default=True,
        help="Enable verbose output (default: True)"
    )
    cli.add_argument(
        "--quiet", "-q", action="store_true",
        help="Disable verbose output"
    )

    # Struct packing
    cli.add_argument(
        "--pack-struct", "-s", action="store_true",
        help="Pack each feature group into separate struct columns (features_alpha158, features_market_ext, features_market_flag)"
    )

    # Streaming execution
    cli.add_argument(
        "--streaming", action="store_true",
        help="Use Polars streaming mode for large datasets (recommended for date ranges > 1 year)"
    )

    opts = cli.parse_args()

    # --quiet always wins over the (default-on) --verbose flag
    be_verbose = opts.verbose and not opts.quiet

    # Resolve the output path; generate_and_dump appends a group suffix
    # when needed, so the default uses the bare base name "features".
    if opts.output is None:
        resolved_output = str(DEFAULT_OUTPUT_DIR / "features.parquet")
    else:
        resolved_output = opts.output

    generate_and_dump(
        start_date=opts.start_date,
        end_date=opts.end_date,
        output_path=resolved_output,
        output_format=opts.format,
        groups=opts.groups,
        universe=opts.universe,
        filter_universe=not opts.no_filter_universe,
        robust_zscore_params_path=opts.robust_zscore_params,
        verbose=be_verbose,
        pack_struct=opts.pack_struct,
        streaming=opts.streaming,
    )
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the CLI when executed directly (not on import).
if __name__ == "__main__":
    main()
|
||||||
@ -1,364 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
"""
|
|
||||||
Script to dump raw and processed datasets using the polars-based pipeline.
|
|
||||||
|
|
||||||
This generates:
|
|
||||||
1. Raw data (before applying processors) - equivalent to qlib's handler output
|
|
||||||
2. Processed data (after applying all processors) - ready for VAE encoding
|
|
||||||
|
|
||||||
Date range: 2026-02-23 to today (2026-02-27)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import pickle as pkl
|
|
||||||
import numpy as np
|
|
||||||
import polars as pl
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
# Import processors from the new shared module
|
|
||||||
from cta_1d.src.processors import (
|
|
||||||
DiffProcessor,
|
|
||||||
FlagMarketInjector,
|
|
||||||
FlagSTInjector,
|
|
||||||
ColumnRemover,
|
|
||||||
FlagToOnehot,
|
|
||||||
IndusNtrlInjector,
|
|
||||||
RobustZScoreNorm,
|
|
||||||
Fillna,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Import constants from local module
|
|
||||||
from generate_beta_embedding import (
|
|
||||||
ALPHA158_COLS,
|
|
||||||
INDUSTRY_FLAG_COLS,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Date range
|
|
||||||
START_DATE = "2026-02-23"
|
|
||||||
END_DATE = "2026-02-27"
|
|
||||||
|
|
||||||
# Output directory
|
|
||||||
OUTPUT_DIR = Path(__file__).parent.parent / "data_polars"
|
|
||||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
    """
    Apply the full processor pipeline (equivalent to qlib's proc_list).

    This mimics the qlib proc_list:
    0. Diff: Adds diff features for market_ext columns
    1. FlagMarketInjector: Adds market_0, market_1
    2. FlagSTInjector: Adds IsST (placeholder if ST flags not available)
    3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
    4. FlagToOnehot: Converts 29 industry flags to indus_idx
    5. IndusNtrlInjector: Industry neutralization for feature
    6. IndusNtrlInjector: Industry neutralization for feature_ext
    7. RobustZScoreNorm: Normalization using pre-fitted qlib params
    8. Fillna: Fill NaN with 0

    Args:
        df: Merged polars DataFrame containing alpha158, market_ext,
            market_flag and industry-flag columns.

    Returns:
        The same DataFrame with all processors applied in order.

    NOTE: the step order is load-bearing — Diff must run before
    ColumnRemover (so log_size_diff exists to be removed), and both
    IndusNtrlInjector passes must run before RobustZScoreNorm, whose
    pre-fitted parameter vector is indexed by norm_feature_cols order.
    """
    print("=" * 60)
    print("Applying processor pipeline")
    print("=" * 60)

    # market_ext columns (4 base)
    market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

    # market_flag columns (12 total before ColumnRemover)
    market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
                        'open_limit', 'close_limit', 'low_limit',
                        'open_stop', 'close_stop', 'high_stop']

    # Step 1: Diff Processor
    print("\n[1] Applying Diff processor...")
    diff_processor = DiffProcessor(market_ext_base)
    df = diff_processor.process(df)

    # After Diff: market_ext has 8 columns (4 base + 4 diff)
    market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]

    # Step 2: FlagMarketInjector (adds market_0, market_1)
    print("[2] Applying FlagMarketInjector...")
    flag_injector = FlagMarketInjector()
    df = flag_injector.process(df)

    # Add market_0, market_1 to flag list
    market_flag_with_market = market_flag_cols + ['market_0', 'market_1']

    # Step 3: FlagSTInjector - adds IsST (placeholder if ST flags not available)
    print("[3] Applying FlagSTInjector...")
    flag_st_injector = FlagSTInjector()
    df = flag_st_injector.process(df)

    # Add IsST to flag list
    market_flag_with_st = market_flag_with_market + ['IsST']

    # Step 4: ColumnRemover
    print("[4] Applying ColumnRemover...")
    columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
    remover = ColumnRemover(columns_to_remove)
    df = remover.process(df)

    # Update the bookkeeping lists so later steps see only surviving columns
    market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
    market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]

    print(f"  Removed columns: {columns_to_remove}")
    print(f"  Remaining market_ext: {len(market_ext_cols)} columns")
    print(f"  Remaining market_flag: {len(market_flag_with_st)} columns")

    # Step 5: FlagToOnehot
    print("[5] Applying FlagToOnehot...")
    flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
    df = flag_to_onehot.process(df)

    # Step 6 & 7: IndusNtrlInjector
    print("[6] Applying IndusNtrlInjector for alpha158...")
    alpha158_cols = ALPHA158_COLS.copy()
    indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
    df = indus_ntrl_alpha.process(df)

    print("[7] Applying IndusNtrlInjector for market_ext...")
    indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
    df = indus_ntrl_ext.process(df)

    # Build column lists for normalization
    alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
    market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]

    # Step 8: RobustZScoreNorm
    print("[8] Applying RobustZScoreNorm...")
    # Order matters: the pre-fitted qlib mean/std parameters are positional,
    # so norm_feature_cols must match the order used at fit time.
    norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols

    # Load RobustZScoreNorm with pre-fitted parameters from version
    robust_norm = RobustZScoreNorm.from_version(
        "csiallx_feature2_ntrla_flag_pnlnorm",
        feature_cols=norm_feature_cols
    )

    # Verify parameter shape matches expected features (warn only; a
    # mismatch likely means the params were fitted on a different schema)
    expected_features = len(norm_feature_cols)
    if robust_norm.qlib_mean.shape[0] != expected_features:
        print(f"  WARNING: Feature count mismatch! Expected {expected_features}, "
              f"got {robust_norm.qlib_mean.shape[0]}")

    df = robust_norm.process(df)

    # Step 9: Fillna
    print("[9] Applying Fillna...")
    final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
    fillna = Fillna()
    df = fillna.process(df, final_feature_cols)

    print("\n" + "=" * 60)
    print("Processor pipeline complete!")
    print(f"  Normalized features: {len(norm_feature_cols)}")
    print(f"  Market flags: {len(market_flag_with_st)}")
    print(f"  Total features (with indus_idx): {len(final_feature_cols)}")
    print("=" * 60)

    return df
|
||||||
|
|
||||||
|
|
||||||
def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
    """
    Convert polars DataFrame to pandas DataFrame with MultiIndex columns.
    This matches the format of qlib's output.

    IMPORTANT: Qlib's IndusNtrlInjector outputs columns in order [_ntrl] + [raw],
    so we need to reorder columns to match this expected order.

    Args:
        df_polars: Processed polars DataFrame; may carry 'datetime' and
            'instrument' as plain columns (promoted to a MultiIndex here).

    Returns:
        pandas DataFrame indexed by (datetime, instrument) with a two-level
        column MultiIndex: (group, column), groups being 'feature',
        'feature_ext', 'feature_flag', 'indus_flag', 'st_flag', 'indus_idx',
        and 'other' for anything unclassified.
    """
    import pandas as pd

    # Convert to pandas
    df = df_polars.to_pandas()

    # Check if datetime and instrument are columns
    if 'datetime' in df.columns and 'instrument' in df.columns:
        # Set MultiIndex
        df = df.set_index(['datetime', 'instrument'])
    # If they're already not in columns, assume they're already the index

    # Drop raw columns that shouldn't be in processed data
    raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue']
    existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns]
    if existing_raw_cols:
        df = df.drop(columns=existing_raw_cols)

    # Build MultiIndex columns based on column name patterns
    # IMPORTANT: Qlib order is [_ntrl columns] + [raw columns] for each group
    columns_with_group = []

    # Define column sets
    alpha158_base = set(ALPHA158_COLS)
    market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'}
    # NOTE(review): 'log_size_diff' is deliberately absent here — presumably
    # because ColumnRemover drops it upstream; confirm no log_size_diff_ntrl
    # column can survive to this point, else it falls into 'other'.
    market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'}
    market_ext_all = market_ext_base | market_ext_diff
    feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit',
                         'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'}

    # First pass: classify every column into ntrl/raw/flag buckets
    # (_ntrl columns come first in qlib order)
    ntrl_alpha158_cols = []
    ntrl_market_ext_cols = []
    raw_alpha158_cols = []
    raw_market_ext_cols = []
    flag_cols = []
    indus_idx_col = None

    for col in df.columns:
        if col == 'indus_idx':
            indus_idx_col = col
        elif col in feature_flag_cols:
            flag_cols.append(col)
        elif col.endswith('_ntrl'):
            base_name = col[:-5]  # Remove _ntrl suffix (5 characters)
            if base_name in alpha158_base:
                ntrl_alpha158_cols.append(col)
            elif base_name in market_ext_all:
                ntrl_market_ext_cols.append(col)
        elif col in alpha158_base:
            raw_alpha158_cols.append(col)
        elif col in market_ext_all:
            raw_market_ext_cols.append(col)
        elif col in INDUSTRY_FLAG_COLS:
            columns_with_group.append(('indus_flag', col))
        elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}:
            columns_with_group.append(('st_flag', col))
        else:
            # Unknown column - print warning
            print(f"  Warning: Unknown column '{col}', assigning to 'other' group")
            columns_with_group.append(('other', col))

    # Build columns in qlib order: [_ntrl] + [raw] for each feature group
    # Feature group: alpha158_ntrl + alpha158
    # (key=999 parks any column not found in ALPHA158_COLS at the end)
    for col in sorted(ntrl_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x.replace('_ntrl', '')) if x.replace('_ntrl', '') in ALPHA158_COLS else 999):
        columns_with_group.append(('feature', col))
    for col in sorted(raw_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x) if x in ALPHA158_COLS else 999):
        columns_with_group.append(('feature', col))

    # Feature_ext group: market_ext_ntrl + market_ext
    # (kept in encounter order — presumably already qlib-ordered; verify)
    for col in ntrl_market_ext_cols:
        columns_with_group.append(('feature_ext', col))
    for col in raw_market_ext_cols:
        columns_with_group.append(('feature_ext', col))

    # Feature_flag group
    for col in flag_cols:
        columns_with_group.append(('feature_flag', col))

    # Indus_idx
    if indus_idx_col:
        columns_with_group.append(('indus_idx', indus_idx_col))

    # Create MultiIndex columns
    multi_cols = pd.MultiIndex.from_tuples(columns_with_group)
    df.columns = multi_cols

    return df
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Dump raw and processed datasets for START_DATE..END_DATE to OUTPUT_DIR.

    Produces two pickled pandas DataFrames with qlib-style MultiIndex
    columns: raw data (before processors) and processed data (after the
    full processor pipeline), then verifies market_0/market_1 presence.

    NOTE(review): `load_all_data` and `merge_data_sources` are not among
    this file's visible imports — confirm they are defined/imported in a
    part of the file not shown here.
    """
    print("=" * 80)
    print("Dumping Polars Dataset")
    print("=" * 80)
    print(f"Date range: {START_DATE} to {END_DATE}")
    print(f"Output directory: {OUTPUT_DIR}")
    print()

    # Step 1: Load all data
    print("Step 1: Loading data from parquet...")
    df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE)
    print(f"  Alpha158 shape: {df_alpha.shape}")
    print(f"  Kline (market_ext) shape: {df_kline.shape}")
    print(f"  Flags shape: {df_flag.shape}")
    print(f"  Industry shape: {df_industry.shape}")

    # Step 2: Merge data sources
    print("\nStep 2: Merging data sources...")
    df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
    print(f"  Merged shape (after csiallx filter): {df_merged.shape}")

    # Step 3: Save raw data (before processors)
    print("\nStep 3: Saving raw data (before processors)...")

    # Keep columns that match qlib's raw output format
    # Include datetime and instrument for MultiIndex conversion
    raw_columns = (
        ['datetime', 'instrument'] +  # Index columns
        ALPHA158_COLS +  # feature group
        ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] +  # feature_ext base
        ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',  # market_flag from kline
         'open_limit', 'close_limit', 'low_limit',
         'open_stop', 'close_stop', 'high_stop'] +
        INDUSTRY_FLAG_COLS +  # indus_flag
        (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else [])  # st_flag (if available)
    )

    # Filter to available columns (raw_columns may list columns the merge
    # did not produce for this date range)
    available_raw_cols = [c for c in raw_columns if c in df_merged.columns]
    print(f"  Selecting {len(available_raw_cols)} columns for raw data...")
    df_raw_polars = df_merged.select(available_raw_cols)

    # Convert to pandas with MultiIndex
    df_raw_pd = convert_to_multiindex_df(df_raw_polars)

    raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(raw_output_path, "wb") as f:
        pkl.dump(df_raw_pd, f)
    print(f"  Saved raw data to: {raw_output_path}")
    print(f"  Raw data shape: {df_raw_pd.shape}")
    print(f"  Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}")

    # Step 4: Apply processor pipeline
    print("\nStep 4: Applying processor pipeline...")
    df_processed = apply_processor_pipeline(df_merged)

    # Step 5: Save processed data
    print("\nStep 5: Saving processed data (after processors)...")

    # Convert to pandas with MultiIndex
    df_processed_pd = convert_to_multiindex_df(df_processed)

    processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(processed_output_path, "wb") as f:
        pkl.dump(df_processed_pd, f)
    print(f"  Saved processed data to: {processed_output_path}")
    print(f"  Processed data shape: {df_processed_pd.shape}")
    print(f"  Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}")

    # Count columns per group
    print("\n  Column counts by group:")
    for grp in df_processed_pd.columns.get_level_values(0).unique().tolist():
        count = (df_processed_pd.columns.get_level_values(0) == grp).sum()
        print(f"    {grp}: {count} columns")

    # Step 6: Verify column counts
    print("\n" + "=" * 80)
    print("Verification")
    print("=" * 80)

    # Sanity check: FlagMarketInjector should have produced market_0/market_1
    feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag']
    has_market_0 = 'market_0' in feature_flag_cols
    has_market_1 = 'market_1' in feature_flag_cols

    print(f"  feature_flag columns: {feature_flag_cols}")
    print(f"  Has market_0: {has_market_0}")
    print(f"  Has market_1: {has_market_1}")

    if has_market_0 and has_market_1:
        print("\n  SUCCESS: market_0 and market_1 columns are present!")
    else:
        print("\n  WARNING: market_0 or market_1 columns are missing!")

    print("\n" + "=" * 80)
    print("Dataset dump complete!")
    print("=" * 80)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
Source package for alpha158_beta experiments.
|
||||||
|
|
||||||
|
This package provides modules for data loading, feature transformation,
|
||||||
|
and model training for alpha158 beta factor experiments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .processors import (
|
||||||
|
FeaturePipeline,
|
||||||
|
FeatureGroups,
|
||||||
|
DiffProcessor,
|
||||||
|
FlagMarketInjector,
|
||||||
|
FlagSTInjector,
|
||||||
|
ColumnRemover,
|
||||||
|
FlagToOnehot,
|
||||||
|
IndusNtrlInjector,
|
||||||
|
RobustZScoreNorm,
|
||||||
|
Fillna,
|
||||||
|
load_alpha158,
|
||||||
|
load_market_ext,
|
||||||
|
load_market_flags,
|
||||||
|
load_industry_flags,
|
||||||
|
load_all_data,
|
||||||
|
load_robust_zscore_params,
|
||||||
|
filter_stock_universe,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Main pipeline
|
||||||
|
'FeaturePipeline',
|
||||||
|
'FeatureGroups',
|
||||||
|
|
||||||
|
# Processors
|
||||||
|
'DiffProcessor',
|
||||||
|
'FlagMarketInjector',
|
||||||
|
'FlagSTInjector',
|
||||||
|
'ColumnRemover',
|
||||||
|
'FlagToOnehot',
|
||||||
|
'IndusNtrlInjector',
|
||||||
|
'RobustZScoreNorm',
|
||||||
|
'Fillna',
|
||||||
|
|
||||||
|
# Loaders
|
||||||
|
'load_alpha158',
|
||||||
|
'load_market_ext',
|
||||||
|
'load_market_flags',
|
||||||
|
'load_industry_flags',
|
||||||
|
'load_all_data',
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
'load_robust_zscore_params',
|
||||||
|
'filter_stock_universe',
|
||||||
|
]
|
||||||
@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Processors package for the alpha158_beta feature pipeline.
|
||||||
|
|
||||||
|
This package provides a modular, polars-native data loading and transformation
|
||||||
|
pipeline for generating VAE input features from alpha158 beta factors.
|
||||||
|
|
||||||
|
Main components:
|
||||||
|
- FeatureGroups: Dataclass container for separate feature groups
|
||||||
|
- FeaturePipeline: Main orchestrator class for the full pipeline
|
||||||
|
- Processors: Individual transformation classes (DiffProcessor, IndusNtrlInjector, etc.)
|
||||||
|
- Loaders: Data loading functions for parquet sources
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from processors import FeaturePipeline, FeatureGroups
|
||||||
|
|
||||||
|
pipeline = FeaturePipeline()
|
||||||
|
feature_groups = pipeline.load_data(start_date, end_date)
|
||||||
|
transformed = pipeline.transform(feature_groups)
|
||||||
|
vae_input = pipeline.get_vae_input(transformed)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .dataclass import FeatureGroups
|
||||||
|
from .loaders import (
|
||||||
|
load_alpha158,
|
||||||
|
load_market_ext,
|
||||||
|
load_market_flags,
|
||||||
|
load_industry_flags,
|
||||||
|
load_all_data,
|
||||||
|
load_parquet_by_date_range,
|
||||||
|
get_date_partitions,
|
||||||
|
# Constants
|
||||||
|
PARQUET_ALPHA158_BETA_PATH,
|
||||||
|
PARQUET_KLINE_PATH,
|
||||||
|
PARQUET_MARKET_FLAG_PATH,
|
||||||
|
PARQUET_INDUSTRY_FLAG_PATH,
|
||||||
|
PARQUET_CON_RATING_PATH,
|
||||||
|
INDUSTRY_FLAG_COLS,
|
||||||
|
MARKET_FLAG_COLS_KLINE,
|
||||||
|
MARKET_FLAG_COLS_MARKET,
|
||||||
|
MARKET_EXT_RAW_COLS,
|
||||||
|
)
|
||||||
|
from .processors import (
|
||||||
|
DiffProcessor,
|
||||||
|
FlagMarketInjector,
|
||||||
|
FlagSTInjector,
|
||||||
|
ColumnRemover,
|
||||||
|
FlagToOnehot,
|
||||||
|
IndusNtrlInjector,
|
||||||
|
RobustZScoreNorm,
|
||||||
|
Fillna,
|
||||||
|
)
|
||||||
|
from .pipeline import (
|
||||||
|
FeaturePipeline,
|
||||||
|
load_robust_zscore_params,
|
||||||
|
filter_stock_universe,
|
||||||
|
# Constants
|
||||||
|
ALPHA158_COLS,
|
||||||
|
MARKET_EXT_BASE_COLS,
|
||||||
|
MARKET_FLAG_COLS,
|
||||||
|
COLUMNS_TO_REMOVE,
|
||||||
|
VAE_INPUT_DIM,
|
||||||
|
DEFAULT_ROBUST_ZSCORE_PARAMS_PATH,
|
||||||
|
)
|
||||||
|
from .exporters import (
|
||||||
|
get_groups,
|
||||||
|
get_groups_from_fg,
|
||||||
|
pack_structs,
|
||||||
|
unpack_struct,
|
||||||
|
dump_to_parquet,
|
||||||
|
dump_to_pickle,
|
||||||
|
dump_to_numpy,
|
||||||
|
dump_features,
|
||||||
|
# Also import old names for backward compatibility
|
||||||
|
get_groups as select_feature_groups_from_df,
|
||||||
|
get_groups_from_fg as select_feature_groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicit public API of the alpha158_beta feature-pipeline package.
# Names listed here are what `from <pkg> import *` exposes and what callers
# should treat as stable; everything else is an implementation detail.
__all__ = [
    # Main classes
    'FeaturePipeline',
    'FeatureGroups',

    # Processors
    'DiffProcessor',
    'FlagMarketInjector',
    'FlagSTInjector',
    'ColumnRemover',
    'FlagToOnehot',
    'IndusNtrlInjector',
    'RobustZScoreNorm',
    'Fillna',

    # Loaders
    'load_alpha158',
    'load_market_ext',
    'load_market_flags',
    'load_industry_flags',
    'load_all_data',
    'load_parquet_by_date_range',
    'get_date_partitions',

    # Utility functions
    'load_robust_zscore_params',
    'filter_stock_universe',

    # Exporter functions
    'get_groups',
    'get_groups_from_fg',
    'pack_structs',
    'unpack_struct',
    'dump_to_parquet',
    'dump_to_pickle',
    'dump_to_numpy',
    'dump_features',

    # Backward compatibility aliases (old exporter names, see .exporters import)
    'select_feature_groups_from_df',
    'select_feature_groups',

    # Constants
    'PARQUET_ALPHA158_BETA_PATH',
    'PARQUET_KLINE_PATH',
    'PARQUET_MARKET_FLAG_PATH',
    'PARQUET_INDUSTRY_FLAG_PATH',
    'PARQUET_CON_RATING_PATH',
    'INDUSTRY_FLAG_COLS',
    'MARKET_FLAG_COLS_KLINE',
    'MARKET_FLAG_COLS_MARKET',
    'MARKET_EXT_RAW_COLS',
    'ALPHA158_COLS',
    'MARKET_EXT_BASE_COLS',
    'MARKET_FLAG_COLS',
    'COLUMNS_TO_REMOVE',
    'VAE_INPUT_DIM',
    'DEFAULT_ROBUST_ZSCORE_PARAMS_PATH',
]
|
||||||
@ -0,0 +1,115 @@
|
|||||||
|
"""Dataclass definitions for the feature pipeline."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, List
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
# Import constants for column categorization
|
||||||
|
# Alpha158 base columns (158 features)
|
||||||
|
ALPHA158_BASE_COLS = [
|
||||||
|
'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
|
||||||
|
'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
|
||||||
|
'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
|
||||||
|
'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
|
||||||
|
'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
|
||||||
|
'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
|
||||||
|
'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
|
||||||
|
'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
|
||||||
|
'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
|
||||||
|
'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
|
||||||
|
'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
|
||||||
|
'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
|
||||||
|
'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
|
||||||
|
'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
|
||||||
|
'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
|
||||||
|
'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
|
||||||
|
'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
|
||||||
|
'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
|
||||||
|
'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
|
||||||
|
'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
|
||||||
|
'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
|
||||||
|
'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
|
||||||
|
'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
|
||||||
|
'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
|
||||||
|
'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
|
||||||
|
'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
|
||||||
|
'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
|
||||||
|
'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
|
||||||
|
'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
|
||||||
|
'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
|
||||||
|
'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
|
||||||
|
]
|
||||||
|
|
||||||
|
# Market extension base columns: per-stock liquidity / size / analyst-rating
# features, before DiffProcessor adds *_diff and IndusNtrlInjector adds *_ntrl.
MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

# Market flag base columns (before processors).
# NOTE(review): names suggest the first six are K-line event flags (limit-up,
# limit-down, new listing, ex-dividend/rights) and the last six are intraday
# limit/halt state flags — confirm against the upstream data dictionary.
MARKET_FLAG_BASE_COLS = [
    'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
    'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FeatureGroups:
    """
    Container for separate feature groups in the pipeline.

    Keeps feature groups separate throughout the pipeline to avoid
    unnecessary merging and complex column management.

    Attributes:
        alpha158: 158 alpha158 features (+ _ntrl after industry neutralization)
        market_ext: Market extension features (Turnover, FreeTurnover, MarketValue, etc.)
        market_flag: Market flag features (IsZt, IsDt, IsXD, etc.)
        industry: 29 industry flags (converted to indus_idx by FlagToOnehot)
        indus_idx: Single column industry index (after FlagToOnehot processing)
        instruments: List of instrument IDs for metadata
        dates: List of datetime values for metadata
    """
    # Core feature groups
    alpha158: pl.DataFrame  # 158 alpha158 features (+ _ntrl after processing)
    market_ext: pl.DataFrame  # Market extension features (+ _ntrl after processing)
    market_flag: pl.DataFrame  # Market flags (12 cols initially, 11 after ColumnRemover)
    industry: Optional[pl.DataFrame] = None  # 29 industry flags -> indus_idx

    # Processed industry index (separate after FlagToOnehot)
    indus_idx: Optional[pl.DataFrame] = None  # Single column after FlagToOnehot

    # Metadata (extracted from dataframes for easy access)
    instruments: List[str] = field(default_factory=list)
    # NOTE(review): extract_metadata() fills this with the raw values of the
    # 'datetime' column, which need not be ints — the List[int] hint looks
    # stale; confirm against the loaders' datetime dtype before relying on it.
    dates: List[int] = field(default_factory=list)

    def extract_metadata(self) -> None:
        """Extract instrument and date lists from the alpha158 dataframe.

        No-op when alpha158 is missing or empty; otherwise overwrites
        ``instruments`` and ``dates`` in place.
        """
        if self.alpha158 is not None and len(self.alpha158) > 0:
            self.instruments = self.alpha158['instrument'].to_list()
            self.dates = self.alpha158['datetime'].to_list()

    def merge_for_processors(self) -> pl.DataFrame:
        """
        Merge all feature groups into a single DataFrame for processors.

        This is used by processors that need access to multiple groups
        (e.g., IndusNtrlInjector needs industry index).

        All joins are left joins on (instrument, datetime), anchored on the
        alpha158 frame. The ``is not self.alpha158`` identity checks skip a
        group when it aliases the alpha158 frame object — presumably a guard
        against upstream code reusing one frame for several slots.

        Returns:
            Merged DataFrame with all features
        """
        df = self.alpha158

        # Merge market_ext if not already merged
        if self.market_ext is not None and self.market_ext is not self.alpha158:
            df = df.join(self.market_ext, on=['instrument', 'datetime'], how='left')

        # Merge market_flag if not already merged
        if self.market_flag is not None and self.market_flag is not self.alpha158:
            df = df.join(self.market_flag, on=['instrument', 'datetime'], how='left')

        # Merge industry/indus_idx if available — the processed single-column
        # index takes precedence over the raw one-hot flags.
        if self.indus_idx is not None:
            df = df.join(self.indus_idx, on=['instrument', 'datetime'], how='left')
        elif self.industry is not None:
            df = df.join(self.industry, on=['instrument', 'datetime'], how='left')

        return df
|
||||||
@ -0,0 +1,523 @@
|
|||||||
|
"""
|
||||||
|
Feature exporters for the alpha158_beta pipeline.
|
||||||
|
|
||||||
|
This module provides functions to select and export feature groups from the
|
||||||
|
transformed pipeline output. It can be used by both dump_features.py and
|
||||||
|
generate_beta_embedding.py.
|
||||||
|
|
||||||
|
Feature groups:
|
||||||
|
- merged: All columns after transformation
|
||||||
|
- alpha158: Alpha158 features + _ntrl versions
|
||||||
|
- market_ext: Market extended features + _ntrl + _diff
|
||||||
|
- market_flag: Market flag columns
|
||||||
|
- vae_input: 341 features specifically curated for VAE training
|
||||||
|
|
||||||
|
Struct mode:
|
||||||
|
- When pack_struct=True, each feature group is packed into a struct column:
|
||||||
|
- features_alpha158 (316 fields)
|
||||||
|
- features_market_ext (14 fields)
|
||||||
|
- features_market_flag (11 fields)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from .pipeline import (
|
||||||
|
ALPHA158_COLS,
|
||||||
|
MARKET_EXT_BASE_COLS,
|
||||||
|
MARKET_FLAG_COLS,
|
||||||
|
COLUMNS_TO_REMOVE,
|
||||||
|
)
|
||||||
|
from .dataclass import FeatureGroups
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Helper functions for struct packing/unpacking
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def pack_structs(df: pl.DataFrame) -> pl.DataFrame:
    """
    Pack flat feature columns into per-group struct columns.

    Produces up to three struct columns, each containing only the group
    columns actually present in *df*:
      - features_alpha158: struct with 316 fields (158 + 158 _ntrl)
      - features_market_ext: struct with 14 fields (7 + 7 _ntrl)
      - features_market_flag: struct with 11 fields

    Args:
        df: Input DataFrame with flat columns

    Returns:
        DataFrame with columns: instrument, datetime, [indus_idx], features_*
    """
    # Candidate columns per group, in canonical order (_ntrl first).
    alpha_group = [f"{c}_ntrl" for c in ALPHA158_COLS] + ALPHA158_COLS

    ext_base = [
        c for c in MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        if c not in COLUMNS_TO_REMOVE
    ]
    ext_group = [f"{c}_ntrl" for c in ext_base] + ext_base

    flag_group = list(dict.fromkeys(
        [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        + ['market_0', 'market_1', 'IsST']
    ))

    # Index columns pass through unchanged; indus_idx is optional.
    selection = ['instrument', 'datetime']
    if 'indus_idx' in df.columns:
        selection.append('indus_idx')

    # Pack each group that has at least one column present in df.
    for struct_name, group in (
        ('features_alpha158', alpha_group),
        ('features_market_ext', ext_group),
        ('features_market_flag', flag_group),
    ):
        present = [c for c in group if c in df.columns]
        if present:
            selection.append(pl.struct(present).alias(struct_name))

    return df.select(selection)
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_struct(df: pl.DataFrame, struct_name: str) -> pl.DataFrame:
    """
    Expand one struct column back into flat columns.

    The struct's fields are appended after the remaining columns, in the
    order they appear in the struct's schema.

    Args:
        df: DataFrame containing struct column
        struct_name: Name of the struct column to unpack

    Returns:
        DataFrame with struct fields as individual columns

    Raises:
        ValueError: If *struct_name* is absent or is not a struct column.
    """
    if struct_name not in df.columns:
        raise ValueError(f"Struct column '{struct_name}' not found in DataFrame")

    dtype = df.schema[struct_name]
    if not isinstance(dtype, pl.Struct):
        raise ValueError(f"Column '{struct_name}' is not a struct type")

    # One expression per struct field, keeping the field's own name.
    flat_exprs = [
        pl.col(struct_name).struct.field(fld.name).alias(fld.name)
        for fld in dtype.fields
    ]

    kept = [c for c in df.columns if c != struct_name]
    return df.select(kept + flat_exprs)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_to_parquet(df: pl.DataFrame, path: str, verbose: bool = True) -> None:
    """Write *df* to a parquet file at *path*, optionally reporting progress."""
    if verbose:
        print(f"Saving to parquet: {path}")

    df.write_parquet(path)

    if verbose:
        print(f" Shape: {df.shape}")
|
||||||
|
|
||||||
|
|
||||||
|
def dump_to_pickle(df: pl.DataFrame, path: str, verbose: bool = True) -> None:
    """Serialize *df* to a pickle file at *path*, optionally reporting progress."""
    if verbose:
        print(f"Saving to pickle: {path}")

    with open(path, 'wb') as sink:
        pickle.dump(df, sink)

    if verbose:
        print(f" Shape: {df.shape}")
|
||||||
|
|
||||||
|
|
||||||
|
def dump_to_numpy(
    feature_groups: FeatureGroups,
    path: str,
    include_metadata: bool = True,
    verbose: bool = True
) -> None:
    """
    Export merged features as a float32 numpy array.

    Saves:
        - features.npy: The feature matrix (NaN/inf replaced with 0.0)
        - metadata.pkl: Column names and metadata (if include_metadata=True)
    """
    if verbose:
        print(f"Saving to numpy: {path}")

    # All groups are merged so the matrix covers every feature column.
    merged = feature_groups.merge_for_processors()
    feature_cols = [
        c for c in merged.columns
        if c not in ['instrument', 'datetime', 'indus_idx']
    ]

    matrix = merged.select(feature_cols).to_numpy().astype(np.float32)
    matrix = np.nan_to_num(matrix, nan=0.0, posinf=0.0, neginf=0.0)

    # Default to a .npy suffix when the caller gave a bare path.
    out_path = Path(path)
    if out_path.suffix == '':
        out_path = Path(str(path) + '.npy')

    np.save(str(out_path), matrix)

    if verbose:
        print(f" Features shape: {matrix.shape}")

    if not include_metadata:
        return

    # Sidecar pickle describing the matrix layout and row identities.
    metadata_path = str(out_path).replace('.npy', '_metadata.pkl')
    metadata = {
        'feature_cols': feature_cols,
        'instruments': merged['instrument'].to_list(),
        'dates': merged['datetime'].to_list(),
        'n_features': len(feature_cols),
        'n_samples': len(merged),
    }
    with open(metadata_path, 'wb') as sink:
        pickle.dump(metadata, sink)
    if verbose:
        print(f" Metadata saved to: {metadata_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def _vae_input_columns() -> List[str]:
    """Build the ordered VAE input column list (341 columns).

    Order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext, market flags].
    indus_idx and IsST are deliberately excluded.
    """
    alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]

    market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
    market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
    market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]

    norm_feature_cols = (
        alpha158_ntrl_cols + list(ALPHA158_COLS) +
        market_ext_ntrl_cols + market_ext_with_diff
    )

    # Market flag columns (excluding IsST)
    market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
    market_flag_cols += ['market_0', 'market_1']
    market_flag_cols = list(dict.fromkeys(market_flag_cols))

    return norm_feature_cols + market_flag_cols


def get_groups(
    df: pl.DataFrame,
    groups_to_dump: List[str],
    verbose: bool = True,
    use_struct: bool = False,
) -> Dict[str, Any]:
    """
    Select which feature groups to include in the output from a merged DataFrame.

    Args:
        df: Transformed merged DataFrame (flat or with struct columns)
        groups_to_dump: List of groups to include
            ('alpha158', 'market_ext', 'market_flag', 'merged', 'vae_input')
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features'
            struct column. If df already has struct columns (pack_struct=True
            was used in the pipeline), those are passed through as-is.

    Returns:
        Dictionary mapping group name -> selected DataFrame
    """
    index_cols = ['instrument', 'datetime']

    def _pack(frame: pl.DataFrame, extra_exclude=()) -> pl.DataFrame:
        # Collapse every non-index column into one 'features' struct column.
        skip = set(index_cols) | set(extra_exclude)
        feature_cols = [c for c in frame.columns if c not in skip]
        return frame.select(index_cols + [pl.struct(feature_cols).alias('features')])

    # Struct columns already exist when pipeline.transform ran with pack_struct=True.
    has_struct_cols = any(
        isinstance(df.schema.get(c), pl.Struct)
        for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']
    )

    result = {}

    if 'merged' in groups_to_dump:
        if has_struct_cols or not use_struct:
            # Either already packed by the pipeline, or flat output requested.
            result['merged'] = df
        else:
            # indus_idx is intentionally dropped from the packed merged output.
            result['merged'] = _pack(df, extra_exclude=('indus_idx',))
        if verbose:
            print(f"Merged features: {result['merged'].shape}")

    if 'alpha158' in groups_to_dump:
        if has_struct_cols and 'features_alpha158' in df.columns:
            result['alpha158'] = df.select(index_cols + ['features_alpha158'])
        else:
            # Base columns plus their _ntrl counterparts, when present.
            cols = index_cols + [c for c in ALPHA158_COLS if c in df.columns]
            cols += [f"{c}_ntrl" for c in ALPHA158_COLS if f"{c}_ntrl" in df.columns]
            sub = df.select(cols)
            result['alpha158'] = _pack(sub) if use_struct else sub
        if verbose:
            print(f"Alpha158 features: {result['alpha158'].shape}")

    if 'market_ext' in groups_to_dump:
        if has_struct_cols and 'features_market_ext' in df.columns:
            result['market_ext'] = df.select(index_cols + ['features_market_ext'])
        else:
            market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
            market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
            cols = index_cols + market_ext_with_diff
            cols += [f"{c}_ntrl" for c in market_ext_with_diff if f"{c}_ntrl" in df.columns]
            sub = df.select(cols)
            result['market_ext'] = _pack(sub) if use_struct else sub
        if verbose:
            print(f"Market ext features: {result['market_ext'].shape}")

    if 'market_flag' in groups_to_dump:
        if has_struct_cols and 'features_market_flag' in df.columns:
            result['market_flag'] = df.select(index_cols + ['features_market_flag'])
        else:
            flag_cols = index_cols.copy()
            flag_cols += [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE and c in df.columns]
            # FIX: include each processor-generated flag that is actually present.
            # The previous all-or-nothing check dropped every one of these when
            # any single column was missing, inconsistent with pack_structs().
            flag_cols += [c for c in ('market_0', 'market_1', 'IsST') if c in df.columns]
            flag_cols = list(dict.fromkeys(flag_cols))  # Remove duplicates
            sub = df.select(flag_cols)
            result['market_flag'] = _pack(sub) if use_struct else sub
        if verbose:
            print(f"Market flag features: {result['market_flag'].shape}")

    if 'vae_input' in groups_to_dump:
        # VAE input = 330 normalized features + 11 market flags = 341 features.
        vae_cols = _vae_input_columns()
        if use_struct:
            result['vae_input'] = df.select(
                index_cols + [pl.struct(vae_cols).alias('features')]
            )
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)")
        else:
            # Always keep datetime and instrument as leading index columns.
            result['vae_input'] = df.select(index_cols + vae_cols)
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)")

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_groups_from_fg(
    feature_groups: FeatureGroups,
    groups_to_dump: List[str],
    verbose: bool = True,
    use_struct: bool = False,
) -> Dict[str, Any]:
    """
    Select which feature groups to include in the output.

    Args:
        feature_groups: Transformed FeatureGroups container
        groups_to_dump: List of groups to include
            ('alpha158', 'market_ext', 'market_flag', 'merged', 'vae_input')
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features' struct column

    Returns:
        Dictionary with selected DataFrames and metadata
    """
    def _to_struct(frame: pl.DataFrame, exclude=()) -> pl.DataFrame:
        # Collapse everything except the index (and *exclude*) into one struct.
        skip = {'instrument', 'datetime', *exclude}
        packed = [c for c in frame.columns if c not in skip]
        return frame.select(
            ['instrument', 'datetime', pl.struct(packed).alias('features')]
        )

    result = {}

    if 'merged' in groups_to_dump:
        merged = feature_groups.merge_for_processors()
        if use_struct:
            # indus_idx is left out of the packed merged output.
            merged = _to_struct(merged, exclude=('indus_idx',))
        result['merged'] = merged
        if verbose:
            print(f"Merged features: {result['merged'].shape}")

    # The three simple groups only differ in source frame and label.
    for name, frame, label in (
        ('alpha158', feature_groups.alpha158, 'Alpha158 features'),
        ('market_ext', feature_groups.market_ext, 'Market ext features'),
        ('market_flag', feature_groups.market_flag, 'Market flag features'),
    ):
        if name not in groups_to_dump:
            continue
        if use_struct:
            frame = _to_struct(frame)
        result[name] = frame
        if verbose:
            print(f"{label}: {frame.shape}")

    if 'vae_input' in groups_to_dump:
        # VAE input = 330 normalized features + 11 market flags = 341 features.
        # Note: indus_idx is NOT included in VAE input.
        merged = feature_groups.merge_for_processors()

        # Market ext base + _diff columns, minus anything removed downstream.
        ext_cols = [
            c for c in MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
            if c not in COLUMNS_TO_REMOVE
        ]

        # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        norm_cols = (
            [f"{c}_ntrl" for c in ALPHA158_COLS]
            + list(ALPHA158_COLS)
            + [f"{c}_ntrl" for c in ext_cols]
            + ext_cols
        )

        # Market flag columns (excluding IsST), deduplicated, order-preserving.
        flag_cols = list(dict.fromkeys(
            [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
            + ['market_0', 'market_1']
        ))

        vae_cols = norm_cols + flag_cols

        if use_struct:
            result['vae_input'] = merged.select([
                'instrument',
                'datetime',
                pl.struct(vae_cols).alias('features')
            ])
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)")
        else:
            # Always keep datetime and instrument as leading index columns.
            result['vae_input'] = merged.select(['instrument', 'datetime'] + vae_cols)
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)")

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def dump_features(
    df: pl.DataFrame,
    output_path: str,
    output_format: str = 'parquet',
    groups: List[str] = None,
    verbose: bool = True,
    use_struct: bool = False,
) -> None:
    """
    Dump selected feature groups from a transformed DataFrame to file.

    Args:
        df: Transformed merged DataFrame
        output_path: Output file path. The 'merged' group (when selected) is
            written to this exact path; every other group is written next to it
            as "<stem>_<group><suffix>".
        output_format: Output format ('parquet' or 'pickle'). Any other value
            falls back to parquet, matching historical behavior.
        groups: Feature groups to dump (default: ['merged'])
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features' struct column

    Raises:
        NotImplementedError: If output_format == 'numpy' — use the
            FeaturePipeline / dump_to_numpy directly for numpy output.
    """
    # Reject unsupported formats before doing any work or creating directories.
    if output_format == 'numpy':
        raise NotImplementedError(
            "For numpy output, use the FeaturePipeline directly. "
            "This function handles parquet/pickle output only."
        )

    if groups is None:
        groups = ['merged']

    # Select feature groups from merged DataFrame
    outputs = get_groups(df, groups, verbose, use_struct)

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # The pickle and parquet paths previously duplicated identical branch
    # logic; a single writer dispatch keeps the behavior and halves the code.
    writer = dump_to_pickle if output_format == 'pickle' else dump_to_parquet

    if 'merged' in outputs:
        # NOTE(review): when 'merged' is selected alongside other groups, only
        # the merged output is written — preserved from the original behavior;
        # confirm this is intentional.
        writer(outputs['merged'], output_path, verbose=verbose)
    else:
        # Each remaining group gets its own "<stem>_<group><suffix>" file
        # (a single group naturally degenerates to one file).
        base_path = Path(output_path)
        for key, df_out in outputs.items():
            dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
            writer(df_out, dump_path, verbose=verbose)

    if verbose:
        print("Feature dump complete!")
|
||||||
@ -0,0 +1,279 @@
|
|||||||
|
"""Data loading functions for the feature pipeline.
|
||||||
|
|
||||||
|
Memory Best Practices:
|
||||||
|
- Always filter on partition keys (datetime) BEFORE collecting
|
||||||
|
- Use streaming for large date ranges (>1 year)
|
||||||
|
- Never collect full parquet tables - filter on datetime first
|
||||||
|
"""
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
# Data paths — hive-partitioned parquet datasets, partitioned by `datetime=YYYYMMDD`.
PARQUET_ALPHA158_BETA_PATH = "/data/parquet/dataset/stg_1day_wind_alpha158_0_7_beta_1D/"
PARQUET_KLINE_PATH = "/data/parquet/dataset/stg_1day_wind_kline_adjusted_1D/"
PARQUET_MARKET_FLAG_PATH = "/data/parquet/dataset/stg_1day_wind_market_flag_1D/"
PARQUET_INDUSTRY_FLAG_PATH = "/data/parquet/dataset/stg_1day_gds_indus_flag_cc1_1D/"
PARQUET_CON_RATING_PATH = "/data/parquet/dataset/stg_1day_gds_con_rating_1D/"

# Industry flag columns (30 one-hot columns - note: gds_CC29 is not present in the data)
INDUSTRY_FLAG_COLS = [
    'gds_CC10', 'gds_CC11', 'gds_CC12', 'gds_CC20', 'gds_CC21', 'gds_CC22',
    'gds_CC23', 'gds_CC24', 'gds_CC25', 'gds_CC26', 'gds_CC27', 'gds_CC28',
    'gds_CC30', 'gds_CC31', 'gds_CC32', 'gds_CC33', 'gds_CC34', 'gds_CC35',
    'gds_CC36', 'gds_CC37', 'gds_CC40', 'gds_CC41', 'gds_CC42', 'gds_CC43',
    'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70'
]

# Market flag columns, split by source table:
# - kline_adjusted provides the Is* flags
# - market_flag provides the limit/stop flags
MARKET_FLAG_COLS_KLINE = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR']
MARKET_FLAG_COLS_MARKET = ['open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop']

# Market extension raw columns (as named in the kline table; renamed/derived
# into turnover / free_turnover / log_size by load_market_ext)
MARKET_EXT_RAW_COLS = ['Turnover', 'FreeTurnover', 'MarketValue']
|
||||||
|
|
||||||
|
|
||||||
|
def get_date_partitions(start_date: str, end_date: str) -> List[str]:
    """
    Generate a list of date partitions to load from Parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format (inclusive)
        end_date: End date in YYYY-MM-DD format (inclusive)

    Returns:
        List of ``datetime=YYYYMMDD`` partition strings for weekdays only
        (Monday through Friday); weekends are skipped.
    """
    start = datetime.strptime(start_date, "%Y-%m-%d")
    end = datetime.strptime(end_date, "%Y-%m-%d")

    partitions = []
    # BUG FIX: the previous implementation advanced the date with
    # datetime(current.year, current.month, current.day + 1), which raises
    # ValueError at the end of every month (e.g. Jan 31 -> "day 32").
    # Iterating over proleptic-Gregorian ordinals rolls over months/years
    # correctly without needing timedelta.
    for ordinal in range(start.toordinal(), end.toordinal() + 1):
        day = datetime.fromordinal(ordinal)
        if day.weekday() < 5:  # Monday = 0, Friday = 4
            partitions.append(f"datetime={day.strftime('%Y%m%d')}")

    return partitions
|
||||||
|
|
||||||
|
|
||||||
|
def load_parquet_by_date_range(
    base_path: str,
    start_date: str,
    end_date: str,
    columns: Optional[List[str]] = None,
    collect: bool = True,
    streaming: bool = False
) -> pl.LazyFrame | pl.DataFrame:
    """
    Load parquet data filtered by date range.

    CRITICAL: This function filters on the datetime partition key BEFORE collecting.
    This is essential for memory efficiency with large parquet datasets.

    Args:
        base_path: Base path to the parquet dataset
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        columns: Optional list of columns to select (excluding instrument/datetime)
        collect: If True, return DataFrame (default); if False, return LazyFrame
        streaming: If True, use streaming mode for large datasets (recommended for >1 year)

    Returns:
        Polars DataFrame with the loaded data (LazyFrame if collect=False).
        On any load error, prints the error and returns an EMPTY DataFrame
        with the expected schema instead of raising.
    """
    # Partition values are stored as integers like 20250131.
    start_int = int(start_date.replace("-", ""))
    end_int = int(end_date.replace("-", ""))

    try:
        # Start with lazy scan - DO NOT COLLECT YET
        lf = pl.scan_parquet(base_path)

        # CRITICAL: Filter on partition key FIRST, before any other operations.
        # This ensures partition pruning happens at the scan level.
        lf = lf.filter(pl.col('datetime') >= start_int)
        lf = lf.filter(pl.col('datetime') <= end_int)

        # Select specific columns if provided (column pruning). Columns absent
        # from the schema are silently dropped rather than raising.
        if columns:
            schema = lf.collect_schema()
            available_cols = ['instrument', 'datetime'] + [c for c in columns if c in schema.names()]
            lf = lf.select(available_cols)

        # Collect with optional streaming mode
        if collect:
            if streaming:
                return lf.collect(streaming=True)
            return lf.collect()
        return lf

    except Exception as e:
        # Best-effort fallback: report and return an empty frame so callers
        # can keep going (joins against it simply produce no rows).
        print(f"Error loading from {base_path}: {e}")
        if columns:
            return pl.DataFrame({
                'instrument': pl.Series([], dtype=pl.String),
                'datetime': pl.Series([], dtype=pl.Int32),
                # BUG FIX: the comprehension previously filtered on an
                # undefined name `c` instead of `col`, raising NameError
                # inside this error handler and masking the original error.
                **{col: pl.Series([], dtype=pl.Float64) for col in columns if col not in ['instrument', 'datetime']}
            })
        return pl.DataFrame({
            'instrument': pl.Series([], dtype=pl.String),
            'datetime': pl.Series([], dtype=pl.Int32)
        })
|
||||||
|
|
||||||
|
|
||||||
|
def load_alpha158(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load alpha158 beta factors from parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and the alpha158 feature columns
    """
    print("Loading alpha158_0_7_beta factors...")
    frame = load_parquet_by_date_range(
        PARQUET_ALPHA158_BETA_PATH, start_date, end_date, streaming=streaming
    )
    print(f" Alpha158 shape: {frame.shape}")
    return frame
|
||||||
|
|
||||||
|
|
||||||
|
def load_market_ext(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load market extension features from parquet.

    Loads Turnover, FreeTurnover, MarketValue from kline data and transforms:
    - Turnover -> turnover (rename)
    - FreeTurnover -> free_turnover (rename)
    - MarketValue -> log_size = log(MarketValue)
    - con_rating_strength: left-joined from its own table, nulls filled with 0.0

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, turnover, free_turnover,
        log_size, con_rating_strength (plus the raw source columns)
    """
    print("Loading kline data (market ext columns)...")
    market_ext = load_parquet_by_date_range(
        PARQUET_KLINE_PATH, start_date, end_date, MARKET_EXT_RAW_COLS, streaming=streaming
    )
    print(f" Kline (market ext raw) shape: {market_ext.shape}")

    print("Loading con_rating_strength from parquet...")
    ratings = load_parquet_by_date_range(
        PARQUET_CON_RATING_PATH, start_date, end_date, ['con_rating_strength'], streaming=streaming
    )
    print(f" Con rating shape: {ratings.shape}")

    # Derive the model-facing column names from the raw kline columns.
    market_ext = market_ext.with_columns([
        pl.col('Turnover').alias('turnover'),
        pl.col('FreeTurnover').alias('free_turnover'),
        pl.col('MarketValue').log().alias('log_size'),
    ])
    print(f" Kline (market ext transformed) shape: {market_ext.shape}")

    # Attach con_rating_strength; rows without a rating get 0.0.
    market_ext = market_ext.join(ratings, on=['instrument', 'datetime'], how='left')
    market_ext = market_ext.with_columns([
        pl.col('con_rating_strength').fill_null(0.0)
    ])
    print(f" Kline (with con_rating) shape: {market_ext.shape}")

    return market_ext
|
||||||
|
|
||||||
|
|
||||||
|
def load_market_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load market flag features from parquet.

    Combines two sources with an inner join on (instrument, datetime):
    - From kline_adjusted: IsZt, IsDt, IsN, IsXD, IsXR, IsDR
    - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and the 12 market flag columns
    """
    print("Loading market flags from kline_adjusted...")
    kline_flags = load_parquet_by_date_range(
        PARQUET_KLINE_PATH, start_date, end_date, MARKET_FLAG_COLS_KLINE, streaming=streaming
    )
    print(f" Kline flags shape: {kline_flags.shape}")

    print("Loading market flags from market_flag table...")
    market_flags = load_parquet_by_date_range(
        PARQUET_MARKET_FLAG_PATH, start_date, end_date, MARKET_FLAG_COLS_MARKET, streaming=streaming
    )
    print(f" Market flag shape: {market_flags.shape}")

    # Inner join: keep only (instrument, datetime) pairs present in both sources.
    combined = kline_flags.join(market_flags, on=['instrument', 'datetime'], how='inner')
    print(f" Combined flags shape: {combined.shape}")

    return combined
|
||||||
|
|
||||||
|
|
||||||
|
def load_industry_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load one-hot industry flag features from parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and the INDUSTRY_FLAG_COLS
        one-hot industry columns that exist in the dataset
    """
    print("Loading industry flags...")
    frame = load_parquet_by_date_range(
        PARQUET_INDUSTRY_FLAG_PATH, start_date, end_date, INDUSTRY_FLAG_COLS, streaming=streaming
    )
    print(f" Industry shape: {frame.shape}")
    return frame
|
||||||
|
|
||||||
|
|
||||||
|
def load_all_data(
    start_date: str,
    end_date: str,
    streaming: bool = False
) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """
    Load all data sources from Parquet.

    Convenience wrapper that loads the four data sources in order
    (alpha158, market ext, market flags, industry flags) and returns
    them as separate DataFrames.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        Tuple of (alpha158_df, market_ext_df, market_flag_df, industry_df)
    """
    print(f"Loading data from {start_date} to {end_date}...")

    # Tuple elements evaluate left-to-right, preserving the loaders' log order.
    return (
        load_alpha158(start_date, end_date, streaming=streaming),
        load_market_ext(start_date, end_date, streaming=streaming),
        load_market_flags(start_date, end_date, streaming=streaming),
        load_industry_flags(start_date, end_date, streaming=streaming),
    )
|
||||||
@ -0,0 +1,539 @@
|
|||||||
|
"""FeaturePipeline orchestrator for the data loading and transformation pipeline."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
from typing import List, Dict, Optional, Tuple
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .dataclass import FeatureGroups
|
||||||
|
from .loaders import (
|
||||||
|
load_alpha158,
|
||||||
|
load_market_ext,
|
||||||
|
load_market_flags,
|
||||||
|
load_industry_flags,
|
||||||
|
load_all_data,
|
||||||
|
INDUSTRY_FLAG_COLS,
|
||||||
|
)
|
||||||
|
from .processors import (
|
||||||
|
DiffProcessor,
|
||||||
|
FlagMarketInjector,
|
||||||
|
FlagSTInjector,
|
||||||
|
ColumnRemover,
|
||||||
|
FlagToOnehot,
|
||||||
|
IndusNtrlInjector,
|
||||||
|
RobustZScoreNorm,
|
||||||
|
Fillna
|
||||||
|
)
|
||||||
|
|
||||||
|
# Constants - Import from loaders module (single source of truth)
# INDUSTRY_FLAG_COLS is now imported from .loaders above

# Alpha158 feature columns in explicit order. NOTE: the ordering matters —
# `transform` builds the normalization/VAE feature layout directly from this list.
ALPHA158_COLS = [
    'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
    'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
    'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
    'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
    'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
    'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
    'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
    'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
    'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
    'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
    'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
    'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
    'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
    'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
    'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
    'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
    'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
    'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
    'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
    'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
    'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
    'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
    'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
    'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
    'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
    'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
    'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
    'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
    'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
    'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
    'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
]

# Market extension base columns (produced by loaders.load_market_ext)
MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

# Market flag columns (before processors)
MARKET_FLAG_COLS = [
    'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
    'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'
]

# Columns to remove after FlagMarketInjector and FlagSTInjector
COLUMNS_TO_REMOVE = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']

# Expected VAE input dimension
# (presumably 316 alpha158 + 14 market_ext + 11 flag features — see
# _pack_into_structs docstring; TODO confirm the breakdown)
VAE_INPUT_DIM = 341

# Default robust zscore parameters path (directory containing
# mean_train.npy / std_train.npy, used by load_robust_zscore_params)
DEFAULT_ROBUST_ZSCORE_PARAMS_PATH = (
    "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/"
    "data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/"
)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame:
    """
    Filter dataframe to csiallx stock universe (A-shares excluding STAR/BSE).

    This uses qshare's filter_instruments which loads the instrument list from:
    /data/qlib/default/data_ops/target/instruments/csiallx.txt

    Args:
        df: Input DataFrame with datetime and instrument columns
        instruments: Market name for spine creation (default: 'csiallx')

    Returns:
        Filtered DataFrame with only instruments in the specified universe
    """
    # Imported lazily so merely importing this module does not require qshare.
    from qshare.algo.polars.spine import filter_instruments
    return filter_instruments(df, instruments=instruments)
|
||||||
|
|
||||||
|
|
||||||
|
def load_robust_zscore_params(
    params_path: str = None
) -> Dict[str, np.ndarray]:
    """
    Load pre-fitted RobustZScoreNorm parameters from numpy files.

    Reads ``mean_train.npy`` and ``std_train.npy`` from *params_path* and
    caches the result per path (function-attribute cache) so repeated calls
    do no file I/O.

    Args:
        params_path: Directory containing mean_train.npy and std_train.npy.
            Defaults to DEFAULT_ROBUST_ZSCORE_PARAMS_PATH when None.

    Returns:
        Dictionary with 'mean_train' and 'std_train' numpy arrays

    Raises:
        FileNotFoundError: If either parameter file is missing
    """
    if params_path is None:
        params_path = DEFAULT_ROBUST_ZSCORE_PARAMS_PATH

    # Lazily create the module-lifetime cache on the function object.
    cache = getattr(load_robust_zscore_params, '_cached_params', None)
    if cache is None:
        cache = load_robust_zscore_params._cached_params = {}
    if params_path in cache:
        return cache[params_path]

    print(f"Loading robust zscore parameters from: {params_path}")

    # Check both files up front (mean first, matching the original error order).
    npy_paths = {
        'mean_train': os.path.join(params_path, 'mean_train.npy'),
        'std_train': os.path.join(params_path, 'std_train.npy'),
    }
    for name, file_path in npy_paths.items():
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"{name}.npy not found at {file_path}")

    arrays = {name: np.load(file_path) for name, file_path in npy_paths.items()}

    print("Loaded parameters:")
    print(f" mean_train shape: {arrays['mean_train'].shape}")
    print(f" std_train shape: {arrays['std_train'].shape}")

    # Report the fit window when metadata is available alongside the arrays.
    metadata_path = os.path.join(params_path, 'metadata.json')
    if os.path.exists(metadata_path):
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        print(f" fitted on: {metadata.get('fit_start_time', 'N/A')} to {metadata.get('fit_end_time', 'N/A')}")

    params = {
        'mean_train': arrays['mean_train'],
        'std_train': arrays['std_train']
    }
    cache[params_path] = params
    return params
|
||||||
|
|
||||||
|
|
||||||
|
class FeaturePipeline:
    """
    Feature pipeline orchestrator for loading and transforming data.

    The pipeline manages:
    1. Data loading from parquet sources (``load_data``)
    2. Feature transformations via processors (``transform``)
    3. Output preparation for VAE input (``get_vae_input``)

    Usage:
        pipeline = FeaturePipeline(config)
        feature_groups = pipeline.load_data(start_date, end_date)
        transformed = pipeline.transform(feature_groups)
        vae_input = pipeline.get_vae_input(transformed)
    """
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Optional[Dict] = None,
|
||||||
|
robust_zscore_params_path: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the FeaturePipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Optional configuration dictionary with pipeline settings.
|
||||||
|
If None, uses default configuration.
|
||||||
|
robust_zscore_params_path: Path to robust zscore parameters.
|
||||||
|
If None, uses default path.
|
||||||
|
"""
|
||||||
|
self.config = config or {}
|
||||||
|
self.robust_zscore_params_path = robust_zscore_params_path or DEFAULT_ROBUST_ZSCORE_PARAMS_PATH
|
||||||
|
|
||||||
|
# Cache for loaded robust zscore params
|
||||||
|
self._robust_zscore_params = None
|
||||||
|
|
||||||
|
# Initialize processors
|
||||||
|
self._init_processors()
|
||||||
|
|
||||||
|
    def _init_processors(self):
        """Initialize all processors with default configuration.

        Only processors with static configuration are created here;
        data-dependent ones (industry neutralization, robust z-score)
        are constructed inside ``transform`` when needed.
        """
        # Diff processor for market_ext columns
        self.diff_processor = DiffProcessor(MARKET_EXT_BASE_COLS)

        # Market flag injector (adds market_0 / market_1 in transform step 2)
        self.flag_market_injector = FlagMarketInjector()

        # ST flag injector (adds IsST in transform step 3)
        self.flag_st_injector = FlagSTInjector()

        # Column remover (drops COLUMNS_TO_REMOVE in transform step 4)
        self.column_remover = ColumnRemover(COLUMNS_TO_REMOVE)

        # Industry flag to index converter (one-hot flags -> indus_idx)
        self.flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)

        # Industry neutralization injectors (created on-demand with specific feature lists)
        # self.indus_ntrl_injector = None  # Created per-feature-group in transform()

        # Robust zscore normalizer (created on-demand with loaded params)
        # self.robust_norm = None  # Created with loaded params in transform()

        # Fillna processor (fills NaN with 0 as the final transform step)
        self.fillna = Fillna()
|
||||||
|
|
||||||
|
def load_data(
|
||||||
|
self,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
filter_universe: bool = True,
|
||||||
|
universe_name: str = 'csiallx',
|
||||||
|
streaming: bool = False
|
||||||
|
) -> FeatureGroups:
|
||||||
|
"""
|
||||||
|
Load data from parquet sources into FeatureGroups container.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_date: Start date in YYYY-MM-DD format
|
||||||
|
end_date: End date in YYYY-MM-DD format
|
||||||
|
filter_universe: Whether to filter to a specific stock universe
|
||||||
|
universe_name: Name of the stock universe to filter to
|
||||||
|
streaming: If True, use Polars streaming mode for large datasets (>1 year)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FeatureGroups container with loaded data
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Loading data from {start_date} to {end_date}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Load all data sources
|
||||||
|
df_alpha, df_market_ext, df_market_flag, df_industry = load_all_data(
|
||||||
|
start_date, end_date, streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply stock universe filter if requested
|
||||||
|
if filter_universe:
|
||||||
|
print(f"Filtering to {universe_name} universe...")
|
||||||
|
df_alpha = filter_stock_universe(df_alpha, instruments=universe_name)
|
||||||
|
df_market_ext = filter_stock_universe(df_market_ext, instruments=universe_name)
|
||||||
|
df_market_flag = filter_stock_universe(df_market_flag, instruments=universe_name)
|
||||||
|
df_industry = filter_stock_universe(df_industry, instruments=universe_name)
|
||||||
|
print(f" After filter - Alpha158 shape: {df_alpha.shape}")
|
||||||
|
|
||||||
|
# Create FeatureGroups container
|
||||||
|
feature_groups = FeatureGroups(
|
||||||
|
alpha158=df_alpha,
|
||||||
|
market_ext=df_market_ext,
|
||||||
|
market_flag=df_market_flag,
|
||||||
|
industry=df_industry
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract metadata
|
||||||
|
feature_groups.extract_metadata()
|
||||||
|
|
||||||
|
print(f"Loaded {len(feature_groups.instruments)} samples")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
return feature_groups
|
||||||
|
|
||||||
|
    def transform(
        self,
        feature_groups: FeatureGroups,
        pack_struct: bool = False
    ) -> pl.DataFrame:
        """
        Apply feature transformation pipeline to FeatureGroups.

        The pipeline applies processors in the following order:
        1. DiffProcessor - adds diff features to market_ext
        2. FlagMarketInjector - adds market_0, market_1 to market_flag
        3. FlagSTInjector - adds IsST to market_flag
        4. ColumnRemover - removes log_size_diff, IsN, IsZt, IsDt
        5. FlagToOnehot - converts industry flags to indus_idx
        6. IndusNtrlInjector - industry neutralization for alpha158 and market_ext
        7. RobustZScoreNorm - robust z-score normalization
        8. Fillna - fill NaN values with 0

        Args:
            feature_groups: FeatureGroups container with loaded data
            pack_struct: If True, pack each feature group into a struct column
                (features_alpha158, features_market_ext, features_market_flag).
                If False (default), return flat DataFrame with all columns merged.

        Returns:
            Merged DataFrame with all transformed features (pl.DataFrame)
        """
        print("=" * 60)
        print("Starting feature transformation pipeline")
        print("=" * 60)

        # Merge all groups into one wide frame for processing
        df = feature_groups.merge_for_processors()

        # Step 1: Diff Processor - adds diff features for market_ext
        df = self.diff_processor.process(df)

        # Step 2: FlagMarketInjector - adds market_0, market_1
        df = self.flag_market_injector.process(df)

        # Step 3: FlagSTInjector - adds IsST
        df = self.flag_st_injector.process(df)

        # Step 4: ColumnRemover - removes specific columns
        df = self.column_remover.process(df)

        # Step 5: FlagToOnehot - converts industry flags to indus_idx
        df = self.flag_to_onehot.process(df)

        # Step 6: IndusNtrlInjector - industry neutralization
        # For alpha158 features
        indus_ntrl_alpha = IndusNtrlInjector(ALPHA158_COLS, suffix='_ntrl')
        df = indus_ntrl_alpha.process(df, df)  # Pass df as both feature and industry source

        # For market_ext features (with diff columns)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        # Remove columns that were dropped in step 4 (e.g. log_size_diff)
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        indus_ntrl_ext = IndusNtrlInjector(market_ext_with_diff, suffix='_ntrl')
        df = indus_ntrl_ext.process(df, df)

        # Step 7: RobustZScoreNorm - robust z-score normalization
        # Load pre-fitted parameters once and cache them on the instance
        if self._robust_zscore_params is None:
            self._robust_zscore_params = load_robust_zscore_params(self.robust_zscore_params_path)

        # Build the list of features to normalize
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]

        # Feature order for VAE: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        # NOTE: this order must match the fitted mean/std parameter layout.
        norm_feature_cols = (
            alpha158_ntrl_cols + ALPHA158_COLS +
            market_ext_ntrl_cols + market_ext_with_diff
        )

        print(f"Applying RobustZScoreNorm to {len(norm_feature_cols)} features...")

        robust_norm = RobustZScoreNorm(
            norm_feature_cols,
            clip_range=(-3, 3),
            use_qlib_params=True,
            qlib_mean=self._robust_zscore_params['mean_train'],
            qlib_std=self._robust_zscore_params['std_train']
        )
        df = robust_norm.process(df)

        # Step 8: Fillna - fill NaN with 0
        # Get all feature columns (flags that survived removal + injected flags)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1', 'IsST']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))  # Remove duplicates

        final_feature_cols = norm_feature_cols + market_flag_cols + ['indus_idx']
        df = self.fillna.process(df, final_feature_cols)

        print("=" * 60)
        print("Pipeline complete")
        print(f" Total columns: {len(df.columns)}")
        print(f" Rows: {len(df)}")

        # Optionally pack features into struct columns
        if pack_struct:
            df = self._pack_into_structs(df)
            print(f" Packed into struct columns")

        print("=" * 60)
        return df
|
||||||
|
|
||||||
|
def _pack_into_structs(self, df: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""
|
||||||
|
Pack feature groups into struct columns.
|
||||||
|
|
||||||
|
Creates:
|
||||||
|
- features_alpha158: struct with 316 fields (158 + 158 _ntrl)
|
||||||
|
- features_market_ext: struct with 14 fields (7 + 7 _ntrl)
|
||||||
|
- features_market_flag: struct with 11 fields
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with columns: instrument, datetime, indus_idx, features_*
|
||||||
|
"""
|
||||||
|
# Define column groups
|
||||||
|
alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
|
||||||
|
alpha158_all_cols = alpha158_ntrl_cols + ALPHA158_COLS
|
||||||
|
|
||||||
|
market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
|
||||||
|
market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
|
||||||
|
market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
|
||||||
|
market_ext_all_cols = market_ext_ntrl_cols + market_ext_with_diff
|
||||||
|
|
||||||
|
market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
|
||||||
|
market_flag_cols += ['market_0', 'market_1', 'IsST']
|
||||||
|
market_flag_cols = list(dict.fromkeys(market_flag_cols))
|
||||||
|
|
||||||
|
# Build result with struct columns
|
||||||
|
result_cols = ['instrument', 'datetime']
|
||||||
|
|
||||||
|
# Check if indus_idx exists
|
||||||
|
if 'indus_idx' in df.columns:
|
||||||
|
result_cols.append('indus_idx')
|
||||||
|
|
||||||
|
# Pack alpha158
|
||||||
|
alpha158_cols_in_df = [c for c in alpha158_all_cols if c in df.columns]
|
||||||
|
if alpha158_cols_in_df:
|
||||||
|
result_cols.append(pl.struct(alpha158_cols_in_df).alias('features_alpha158'))
|
||||||
|
|
||||||
|
# Pack market_ext
|
||||||
|
ext_cols_in_df = [c for c in market_ext_all_cols if c in df.columns]
|
||||||
|
if ext_cols_in_df:
|
||||||
|
result_cols.append(pl.struct(ext_cols_in_df).alias('features_market_ext'))
|
||||||
|
|
||||||
|
# Pack market_flag
|
||||||
|
flag_cols_in_df = [c for c in market_flag_cols if c in df.columns]
|
||||||
|
if flag_cols_in_df:
|
||||||
|
result_cols.append(pl.struct(flag_cols_in_df).alias('features_market_flag'))
|
||||||
|
|
||||||
|
return df.select(result_cols)
|
||||||
|
|
||||||
|
def get_vae_input(
|
||||||
|
self,
|
||||||
|
df: pl.DataFrame | FeatureGroups,
|
||||||
|
exclude_isst: bool = True
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Prepare VAE input features from transformed DataFrame or FeatureGroups.
|
||||||
|
|
||||||
|
VAE input structure (341 features):
|
||||||
|
- feature group (316): 158 alpha158 + 158 alpha158_ntrl
|
||||||
|
- feature_ext group (14): 7 market_ext + 7 market_ext_ntrl
|
||||||
|
- feature_flag group (11): market flags (excluding IsST)
|
||||||
|
|
||||||
|
NOTE: indus_idx is NOT included in VAE input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: Transformed DataFrame or FeatureGroups container
|
||||||
|
exclude_isst: Whether to exclude IsST from VAE input (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Numpy array of shape (n_samples, VAE_INPUT_DIM)
|
||||||
|
"""
|
||||||
|
print("Preparing features for VAE...")
|
||||||
|
|
||||||
|
# Accept either DataFrame or FeatureGroups
|
||||||
|
if isinstance(df, FeatureGroups):
|
||||||
|
df = df.merge_for_processors()
|
||||||
|
|
||||||
|
# Build alpha158 feature columns
|
||||||
|
alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
|
||||||
|
alpha158_cols = ALPHA158_COLS.copy()
|
||||||
|
|
||||||
|
# Build market_ext feature columns (with diff, minus removed columns)
|
||||||
|
market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
|
||||||
|
market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
|
||||||
|
market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
|
||||||
|
market_ext_cols = market_ext_with_diff.copy()
|
||||||
|
|
||||||
|
# VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
|
||||||
|
norm_feature_cols = (
|
||||||
|
alpha158_ntrl_cols + alpha158_cols +
|
||||||
|
market_ext_ntrl_cols + market_ext_cols
|
||||||
|
)
|
||||||
|
|
||||||
|
# Market flag columns (excluding IsST if requested)
|
||||||
|
market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
|
||||||
|
market_flag_cols += ['market_0', 'market_1']
|
||||||
|
if not exclude_isst:
|
||||||
|
market_flag_cols.append('IsST')
|
||||||
|
market_flag_cols = list(dict.fromkeys(market_flag_cols))
|
||||||
|
|
||||||
|
# Combine all VAE input columns
|
||||||
|
vae_cols = norm_feature_cols + market_flag_cols
|
||||||
|
|
||||||
|
print(f" norm_feature_cols: {len(norm_feature_cols)}")
|
||||||
|
print(f" market_flag_cols: {len(market_flag_cols)}")
|
||||||
|
print(f" Total VAE input columns: {len(vae_cols)}")
|
||||||
|
|
||||||
|
# Verify all columns exist
|
||||||
|
missing_cols = [c for c in vae_cols if c not in df.columns]
|
||||||
|
if missing_cols:
|
||||||
|
print(f"WARNING: Missing columns: {missing_cols}")
|
||||||
|
|
||||||
|
# Select features and convert to numpy
|
||||||
|
features_df = df.select(vae_cols)
|
||||||
|
features = features_df.to_numpy().astype(np.float32)
|
||||||
|
|
||||||
|
# Handle any remaining NaN/Inf values
|
||||||
|
features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
|
||||||
|
print(f"Feature matrix shape: {features.shape}")
|
||||||
|
|
||||||
|
# Verify dimensions
|
||||||
|
if features.shape[1] != VAE_INPUT_DIM:
|
||||||
|
print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}")
|
||||||
|
|
||||||
|
if features.shape[1] < VAE_INPUT_DIM:
|
||||||
|
# Pad with zeros
|
||||||
|
padding = np.zeros(
|
||||||
|
(features.shape[0], VAE_INPUT_DIM - features.shape[1]),
|
||||||
|
dtype=np.float32
|
||||||
|
)
|
||||||
|
features = np.concatenate([features, padding], axis=1)
|
||||||
|
print(f"Padded to shape: {features.shape}")
|
||||||
|
else:
|
||||||
|
# Truncate
|
||||||
|
features = features[:, :VAE_INPUT_DIM]
|
||||||
|
print(f"Truncated to shape: {features.shape}")
|
||||||
|
|
||||||
|
return features
|
||||||
@ -0,0 +1,447 @@
|
|||||||
|
"""Processor classes for feature transformation.
|
||||||
|
|
||||||
|
Processors are stateless, functional components that operate on
|
||||||
|
specific feature groups or merged DataFrames.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
from typing import List, Tuple, Optional, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class DiffProcessor:
    """
    Diff Processor: add period-1 difference features for market_ext columns.

    For each configured column, a ``{col}_diff`` column is appended holding
    the first-order difference computed independently per instrument.
    """

    def __init__(self, columns: List[str]):
        """
        Initialize the DiffProcessor.

        Args:
            columns: List of column names to compute diffs for
        """
        self.columns = columns

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add diff features for the configured columns.

        Args:
            df: Input DataFrame with market_ext columns

        Returns:
            DataFrame with added {col}_diff columns
        """
        print("Applying Diff processor...")

        # Build every expression up front so polars materializes a single
        # new frame in one with_columns() pass instead of one per column.
        # order_by='datetime' keeps each instrument's rows in chronological
        # order before differencing.
        wanted = [name for name in self.columns if name in df.columns]
        diff_exprs = [
            pl.col(name)
            .diff()
            .over('instrument', order_by='datetime')
            .alias(f"{name}_diff")
            for name in wanted
        ]

        return df.with_columns(diff_exprs) if diff_exprs else df
|
||||||
|
|
||||||
|
|
||||||
|
class FlagMarketInjector:
    """
    Flag Market Injector: Create market_0, market_1 columns based on instrument code.

    Maps to Qlib's map_market_sec logic with vocab_size=2:
    - market_0 (主板): SH60xxx, SZ00xxx
    - market_1 (科创板/创业板): SH688xxx, SH689xxx, SZ300xxx, SZ301xxx

    NOTE: vocab_size=2 (not 3!) - 新三板/北交所 (NE4xxxx, NE8xxxx) are NOT included.
    """

    def process(self, df: pl.DataFrame, instrument_col: str = 'instrument') -> pl.DataFrame:
        """
        Add market_0, market_1 columns.

        Args:
            df: Input DataFrame with instrument column
            instrument_col: Name of the instrument column

        Returns:
            DataFrame with added market_0 and market_1 columns
        """
        print("Applying FlagMarketInjector (vocab_size=2)...")

        # Convert instrument to string and pad to 6 digits
        # (assumes numeric instrument codes such as 600000 / 000001 — verify
        # against the data source if the column carries exchange prefixes).
        inst_str = pl.col(instrument_col).cast(pl.String).str.zfill(6)

        # Board classification by code prefix
        is_sh_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689')  # SH688xxx, SH689xxx
        is_sz_chi = inst_str.str.starts_with('300') | inst_str.str.starts_with('301')   # SZ300xxx, SZ301xxx
        # BUGFIX: a bare starts_with('6') also matches STAR-board 688/689
        # codes, which set market_0 AND market_1 simultaneously. Exclude the
        # STAR prefixes so market_0 is the main board only, per the mapping
        # documented above.
        is_sh_main = inst_str.str.starts_with('6') & ~is_sh_star  # SH600xxx, SH601xxx, etc.
        is_sz_main = inst_str.str.starts_with('0')                # SZ000xxx

        df = df.with_columns([
            # market_0 = 主板 (SH main + SZ main)
            (is_sh_main | is_sz_main).cast(pl.Int8).alias('market_0'),
            # market_1 = 科创板 + 创业板 (SH star + SZ ChiNext)
            (is_sh_star | is_sz_chi).cast(pl.Int8).alias('market_1')
        ])

        return df
|
||||||
|
|
||||||
|
|
||||||
|
class FlagSTInjector:
    """
    Flag ST Injector: Create IsST column from ST flags.

    Creates IsST = ST_S | ST_Y when ST flags are available (either as bare
    columns or namespaced as ``st_flag::ST_S`` / ``st_flag::ST_Y``);
    otherwise creates a placeholder column of all zeros.
    """

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add IsST column.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with added IsST column
        """
        print("Applying FlagSTInjector (creating IsST)...")

        def _resolve(base: str) -> str | None:
            """Return the actual column name for *base*, or None if absent."""
            for candidate in (base, f'st_flag::{base}'):
                if candidate in df.columns:
                    return candidate
            return None

        # BUGFIX: the original guard accepted 'st_flag::ST_S' but the branch
        # always read the bare 'ST_S'/'ST_Y' columns, raising
        # ColumnNotFoundError when only the prefixed names exist — and ST_Y
        # was read without ever being checked. Resolve each flag name and OR
        # together whichever flags are actually present.
        st_cols = [c for c in (_resolve('ST_S'), _resolve('ST_Y')) if c is not None]

        if st_cols:
            # Create IsST from the available ST flags
            is_st = pl.col(st_cols[0]).cast(pl.Boolean, strict=False)
            for extra in st_cols[1:]:
                is_st = is_st | pl.col(extra).cast(pl.Boolean, strict=False)
            df = df.with_columns([is_st.cast(pl.Int8).alias('IsST')])
        else:
            # Create placeholder (all zeros) if ST flags not available
            df = df.with_columns([
                pl.lit(0).cast(pl.Int8).alias('IsST')
            ])

        return df
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnRemover:
    """
    Column Remover: Drop specific columns.

    Removes columns that are not needed for the VAE input.
    """

    def __init__(self, columns_to_remove: List[str]):
        """
        Initialize the ColumnRemover.

        Args:
            columns_to_remove: List of column names to remove
        """
        self.columns_to_remove = columns_to_remove

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Remove the configured columns.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with specified columns removed
        """
        print(f"Applying ColumnRemover (removing {len(self.columns_to_remove)} columns)...")

        # Only drop columns that actually exist in the frame.
        present = [name for name in self.columns_to_remove if name in df.columns]
        return df.drop(present) if present else df
|
||||||
|
|
||||||
|
|
||||||
|
class FlagToOnehot:
    """
    Flag To Onehot: Convert 29 one-hot industry columns to single indus_idx.

    For each row, finds which industry column is True/1 and sets indus_idx
    to that column's index. Rows with no flag set get -1; if several flags
    are set (not expected for one-hot input), the highest index wins.
    """

    def __init__(self, industry_cols: List[str]):
        """
        Initialize the FlagToOnehot.

        Args:
            industry_cols: List of 29 industry flag column names
        """
        self.industry_cols = industry_cols

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Convert industry flags to a single indus_idx column.

        Args:
            df: Input DataFrame with industry flag columns

        Returns:
            DataFrame with indus_idx column (original industry columns removed)
        """
        print("Applying FlagToOnehot (converting 29 industry flags to indus_idx)...")

        # Fold the flags into a nested when/then chain; -1 is the default
        # for rows where no industry flag is set.
        index_expr = pl.lit(-1)
        for position, flag_col in enumerate(self.industry_cols):
            if flag_col not in df.columns:
                continue
            index_expr = (
                pl.when(pl.col(flag_col) == 1)
                .then(position)
                .otherwise(index_expr)
            )

        df = df.with_columns([index_expr.alias('indus_idx')])

        # The one-hot source columns are no longer needed.
        present = [name for name in self.industry_cols if name in df.columns]
        if present:
            df = df.drop(present)

        return df
|
||||||
|
|
||||||
|
|
||||||
|
class IndusNtrlInjector:
    """
    Industry Neutralization Injector: Industry neutralization for features.

    For every configured feature, subtracts the industry mean (grouped by
    indus_idx) from the feature value and stores the result in a new
    "{col}_ntrl" column, keeping the original column untouched.

    IMPORTANT: Industry neutralization is done PER DATETIME (cross-sectional),
    not across the entire dataset. This matches qlib's cal_indus_ntrl behavior.
    """

    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        """
        Initialize the IndusNtrlInjector.

        Args:
            feature_cols: List of feature column names to neutralize
            suffix: Suffix to append to neutralized column names
        """
        self.feature_cols = feature_cols
        self.suffix = suffix

    def process(
        self,
        feature_df: pl.DataFrame,
        industry_df: pl.DataFrame
    ) -> pl.DataFrame:
        """
        Apply industry neutralization to the configured features.

        Args:
            feature_df: DataFrame with feature columns to neutralize
            industry_df: DataFrame with indus_idx column (must have instrument, datetime)

        Returns:
            DataFrame with added {col}_ntrl columns
        """
        print(f"Applying IndusNtrlInjector to {len(self.feature_cols)} features...")

        if 'indus_idx' in feature_df.columns:
            working = feature_df
        else:
            # Bring the industry index onto the feature rows.
            working = feature_df.join(
                industry_df.select(['instrument', 'datetime', 'indus_idx']),
                on=['instrument', 'datetime'],
                how='left',
                suffix='_indus'
            )

        targets = [name for name in self.feature_cols if name in working.columns]

        # Build every expression first so polars applies them in a single
        # with_columns() pass (one new frame instead of one per column).
        # Grouping on ['datetime', 'indus_idx'] makes the mean cross-sectional.
        ntrl_exprs = []
        for name in targets:
            group_mean = pl.col(name).mean().over(
                ['datetime', 'indus_idx'], order_by='datetime'
            )
            ntrl_exprs.append((pl.col(name) - group_mean).alias(f"{name}{self.suffix}"))

        return working.with_columns(ntrl_exprs) if ntrl_exprs else working
|
||||||
|
|
||||||
|
|
||||||
|
class RobustZScoreNorm:
    """
    Robust Z-Score Normalization: Per datetime normalization.

    (x - median) / (1.4826 * MAD) where MAD = median(|x - median|)
    Clip outliers at [-3, 3].

    Supports pre-fitted parameters from qlib's pickled processor:
        normalizer = RobustZScoreNorm(
            feature_cols=feature_cols,
            use_qlib_params=True,
            qlib_mean=zscore_proc.mean_train,
            qlib_std=zscore_proc.std_train
        )
    """

    def __init__(
        self,
        feature_cols: List[str],
        clip_range: Tuple[float, float] = (-3, 3),
        use_qlib_params: bool = False,
        qlib_mean: Optional[np.ndarray] = None,
        qlib_std: Optional[np.ndarray] = None
    ):
        """
        Initialize the RobustZScoreNorm.

        Args:
            feature_cols: List of feature column names to normalize
            clip_range: Tuple of (min, max) for clipping normalized values
            use_qlib_params: Whether to use pre-fitted parameters from qlib
            qlib_mean: Pre-fitted mean array from qlib (required if use_qlib_params=True)
            qlib_std: Pre-fitted std array from qlib (required if use_qlib_params=True)
        """
        self.feature_cols = feature_cols
        self.clip_range = clip_range
        self.use_qlib_params = use_qlib_params
        self.mean_train = qlib_mean
        self.std_train = qlib_std

        if use_qlib_params:
            if qlib_mean is None or qlib_std is None:
                raise ValueError("Must provide qlib_mean and qlib_std when use_qlib_params=True")
            print(f"Using pre-fitted qlib parameters (mean shape: {qlib_mean.shape}, std shape: {qlib_std.shape})")

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply robust z-score normalization.

        Args:
            df: Input DataFrame with feature columns

        Returns:
            DataFrame with normalized feature columns (in-place modification)
        """
        print(f"Applying RobustZScoreNorm to {len(self.feature_cols)} features...")

        # Filter to only columns that exist
        existing_cols = [c for c in self.feature_cols if c in df.columns]

        if self.use_qlib_params:
            # Use pre-fitted parameters from qlib (fit once, apply to all dates).
            # Build all normalization expressions FIRST, then apply in a single
            # with_columns() to avoid one frame copy per column.
            # NOTE: columns beyond len(mean_train) are silently left
            # un-normalized, matching the positional pairing of qlib params.
            norm_exprs = []
            for i, col in enumerate(existing_cols):
                if i < len(self.mean_train):
                    mean_val = float(self.mean_train[i])
                    std_val = float(self.std_train[i])
                    norm_exprs.append(
                        ((pl.col(col) - mean_val) / (std_val + 1e-8))
                        .clip(self.clip_range[0], self.clip_range[1])
                        .alias(col)
                    )

            if norm_exprs:
                df = df.with_columns(norm_exprs)
        elif existing_cols:
            # Compute per-datetime robust z-score (cross-sectional).
            #
            # BUGFIX: previously the temp columns (__median_*, __absdev_*,
            # __mad_*) were created AND referenced inside one with_columns()
            # call. Polars evaluates all expressions of a single
            # with_columns() in parallel against the *input* frame, so every
            # reference to a not-yet-existing temp column raised
            # ColumnNotFoundError. The stages must be applied sequentially.
            median_cols = {c: f"__median_{c}" for c in existing_cols}
            absdev_cols = {c: f"__absdev_{c}" for c in existing_cols}
            mad_cols = {c: f"__mad_{c}" for c in existing_cols}

            # Stage 1: per-datetime median of each feature
            df = df.with_columns([
                pl.col(c).median().over('datetime').alias(median_cols[c])
                for c in existing_cols
            ])
            # Stage 2: absolute deviation from the median
            df = df.with_columns([
                (pl.col(c) - pl.col(median_cols[c])).abs().alias(absdev_cols[c])
                for c in existing_cols
            ])
            # Stage 3: MAD = per-datetime median of the absolute deviations
            df = df.with_columns([
                pl.col(absdev_cols[c]).median().over('datetime').alias(mad_cols[c])
                for c in existing_cols
            ])
            # Stage 4: robust z-score, clipped, overwriting the original column
            df = df.with_columns([
                ((pl.col(c) - pl.col(median_cols[c])) / (1.4826 * pl.col(mad_cols[c]) + 1e-8))
                .clip(self.clip_range[0], self.clip_range[1])
                .alias(c)
                for c in existing_cols
            ])

            # Clean up all temporary columns in a single drop()
            df = df.drop(
                list(median_cols.values())
                + list(absdev_cols.values())
                + list(mad_cols.values())
            )

        return df
|
||||||
|
|
||||||
|
|
||||||
|
class Fillna:
    """
    Fill NaN: Fill all NaN/null values with 0 for numeric columns.
    """

    def process(
        self,
        df: pl.DataFrame,
        feature_cols: List[str]
    ) -> pl.DataFrame:
        """
        Fill NaN/null values with 0 for the specified columns.

        Args:
            df: Input DataFrame
            feature_cols: List of column names to fill NaN values for

        Returns:
            DataFrame with NaN/null values filled with 0
        """
        print("Applying Fillna processor...")

        float_dtypes = (pl.Float32, pl.Float64)
        int_dtypes = (pl.Int32, pl.Int64, pl.UInt32, pl.UInt64)

        # Build all fill expressions FIRST, then apply in a single
        # with_columns() to avoid one frame copy per column.
        fill_exprs = []
        for col in feature_cols:
            if col not in df.columns:
                continue
            dtype = df[col].dtype
            if dtype in float_dtypes:
                # Floats can carry both null and NaN; zero out both.
                fill_exprs.append(pl.col(col).fill_null(0.0).fill_nan(0.0))
            elif dtype in int_dtypes:
                # BUGFIX: fill_nan is invalid on integer dtypes (NaN exists
                # only for floats), and fill_null(0.0) with a float literal
                # upcast integer columns. Fill nulls with an integer 0 to
                # preserve the column dtype.
                fill_exprs.append(pl.col(col).fill_null(0))

        if fill_exprs:
            df = df.with_columns(fill_exprs)

        return df
|
||||||
Loading…
Reference in new issue