From 89bd1a528edfe5bc855dce2a25f65ffdf52af018 Mon Sep 17 00:00:00 2001 From: guofu Date: Sun, 1 Mar 2026 12:56:44 +0800 Subject: [PATCH] Extract RobustZScoreNorm parameters and add from_version() method - Add extract_qlib_params.py script to extract pre-fitted mean/std parameters from Qlib's proc_list.proc and save as reusable .npy files with metadata.json - Add RobustZScoreNorm.from_version() class method to load saved parameters by version name, supporting multiple parameter versions coexistence - Update dump_polars_dataset.py to use from_version() instead of loading parameters directly from proc_list.proc - Update generate_beta_embedding.py to use qshare's filter_instruments for stock universe filtering - Save parameters to data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/ with 330 features (158 alpha158_ntrl + 158 alpha158_raw + 7 market_ext_ntrl + 7 market_ext_raw) - Update README.md with documentation for parameter extraction and usage Co-Authored-By: Claude Opus 4.6 --- cta_1d/src/processors/__init__.py | 753 ++++++++++++++++++ stock_1d/d033/alpha158_beta/README.md | 59 +- .../metadata.json | 366 +++++++++ .../scripts/dump_polars_dataset.py | 92 ++- .../scripts/extract_qlib_params.py | 305 +++++++ .../scripts/generate_beta_embedding.py | 45 +- 6 files changed, 1553 insertions(+), 67 deletions(-) create mode 100644 cta_1d/src/processors/__init__.py create mode 100644 stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json create mode 100644 stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py diff --git a/cta_1d/src/processors/__init__.py b/cta_1d/src/processors/__init__.py new file mode 100644 index 0000000..da8476e --- /dev/null +++ b/cta_1d/src/processors/__init__.py @@ -0,0 +1,753 @@ +""" +Polars-based data processors for financial feature transformation. + +This module provides Polars implementations of Qlib-style data processors +used in the data_ops pipeline. Each processor follows a consistent interface: +- Takes a Polars DataFrame as input +- Returns a transformed Polars DataFrame + +Processors are organized by category: +- Feature Engineering: DiffProcessor +- Flag Injection: FlagMarketInjector, FlagSTInjector +- Column Operations: ColumnRemover, FlagToOnehot +- Normalization: IndusNtrlInjector, RobustZScoreNorm +- Data Cleaning: Fillna +""" + +from dataclasses import dataclass, field +from typing import List, Tuple, Optional, Dict +import numpy as np +import polars as pl + + +__all__ = [ + # Feature Engineering + 'DiffProcessor', + # Flag Injection + 'FlagMarketInjector', + 'FlagSTInjector', + # Column Operations + 'ColumnRemover', + 'FlagToOnehot', + # Normalization + 'IndusNtrlInjector', + 'RobustZScoreNorm', + # Data Cleaning + 'Fillna', +] + + +# ============================================================================= +# Feature Engineering Processors +# ============================================================================= + +class DiffProcessor: + """ + Calculate period-over-period differences within each instrument. + + For each specified column, calculates the diff(periods) within each + instrument group, using forward-fill for NaN handling. + + Attributes: + columns: List of columns to calculate diff for + suffix: Suffix to append to diff column names (default: 'diff') + periods: Number of periods to diff (default: 1) + + Example: + >>> processor = DiffProcessor(['turnover', 'log_size']) + >>> df = processor.process(df) + >>> # Creates: turnover_diff, log_size_diff + """ + + def __init__( + self, + columns: List[str], + suffix: str = 'diff', + periods: int = 1 + ): + self.columns = columns + self.suffix = suffix + self.periods = periods + + def process(self, df: pl.DataFrame) -> pl.DataFrame: + """ + Add diff features for specified columns. + + Args: + df: Input DataFrame with datetime and instrument columns + + Returns: + DataFrame with original columns + diff columns with suffix + """ + # Sort by instrument and datetime for correct diff calculation + df = df.sort(['instrument', 'datetime']) + + # Add diff for each column + for col in self.columns: + if col in df.columns: + diff_col = f"{col}_{self.suffix}" + df = df.with_columns([ + pl.col(col) + .diff(self.periods) + .over('instrument') + .alias(diff_col) + ]) + + return df + + def __repr__(self) -> str: + return f"DiffProcessor(columns={self.columns}, suffix='{self.suffix}', periods={self.periods})" + + +# ============================================================================= +# Flag Injection Processors +# ============================================================================= + +class FlagMarketInjector: + """ + Inject market sector flags based on instrument codes. + + Classifies stocks into market segments based on their instrument codes. + Supports both formats: + - With exchange prefix: SH600000, SZ000001, SH688000, SZ300001 + - Numeric only: 600000, 000001, 688000, 300001 + + Market classification: + - market_0: Main board (SH60xxx, SZ00xxx, or 6xxxxx, 0xxxxx) - 主板 + - market_1: STAR/ChiNext (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx, + or 688xxx, 689xxx, 300xxx, 301xxx) - 科创板/创业板 + + Note: Does NOT include 新三板/北交所 (NE40xxx, NE42xxx, NE43xxx, NE8xxx) + in the classification - these stocks will have 0 for both flags. + + Example: + >>> processor = FlagMarketInjector() + >>> df = processor.process(df) + >>> # Adds: market_0, market_1 (both int8) + """ + + def process(self, df: pl.DataFrame) -> pl.DataFrame: + """ + Add market classification columns. + + Args: + df: Input DataFrame with instrument column + + Returns: + DataFrame with market_0, market_1 columns added + """ + # Convert instrument to string + inst_col = pl.col('instrument').cast(pl.String) + + # Remove exchange prefix if present (SH/SZ/NE -> numeric part) + # E.g., SH600000 -> 600000, SZ000001 -> 000001 + inst_numeric = inst_col.str.replace_all("^(SH|SZ|NE)", "") + + # Get first digit(s) for market classification + first_digit = inst_numeric.str.slice(0, 1) + first_three = inst_numeric.str.slice(0, 3) + + # market_0: Main board (60xxx, 601xxx, 603xxx, 000xxx, 001xxx, 002xxx) + # Excludes STAR market (688xxx, 689xxx) which start with '6' but are not main board + is_sh_main = (first_digit == '6') & ~(first_three == '688') & ~(first_three == '689') + is_sz_main = first_digit == '0' + + # market_1: STAR/ChiNext (688xxx, 689xxx, 300xxx, 301xxx) + is_sh_star = (first_three == '688') | (first_three == '689') + is_sz_chi = (first_three == '300') | (first_three == '301') + + df = df.with_columns([ + # market_0 = 主板 + (is_sh_main | is_sz_main).cast(pl.Int8).alias('market_0'), + # market_1 = 科创板 + 创业板 + (is_sh_star | is_sz_chi).cast(pl.Int8).alias('market_1') + ]) + + return df + + def __repr__(self) -> str: + return "FlagMarketInjector()" + + +class FlagSTInjector: + """ + Inject ST (Special Treatment) flag. + + Creates IsST column from ST_S and ST_Y flags: + - IsST = 1 if ST_S or ST_Y is True (stock is ST or *ST) + - IsST = 0 otherwise + + If ST flags are not available, creates a placeholder column (all zeros). + + Attributes: + mark_st_as: Value to mark ST stocks as (default: 1) + + Example: + >>> processor = FlagSTInjector() + >>> df = processor.process(df) + >>> # Adds: IsST (int8) + """ + + def __init__(self, mark_st_as: int = 1): + self.mark_st_as = mark_st_as + + def process(self, df: pl.DataFrame) -> pl.DataFrame: + """ + Add IsST column from ST flags. + + Args: + df: Input DataFrame with ST_S, ST_Y columns (or without for placeholder) + + Returns: + DataFrame with IsST column added + """ + # Check if ST flags are available + if 'ST_S' in df.columns and 'ST_Y' in df.columns: + # Create IsST from actual ST flags + df = df.with_columns([ + ((pl.col('ST_S').cast(pl.Boolean, strict=False) | + pl.col('ST_Y').cast(pl.Boolean, strict=False)) + .cast(pl.Int8).alias('IsST')) + ]) + else: + # Create placeholder (all zeros) if ST flags not available + df = df.with_columns([ + pl.lit(0).cast(pl.Int8).alias('IsST') + ]) + + return df + + def __repr__(self) -> str: + return f"FlagSTInjector(mark_st_as={self.mark_st_as})" + + +# ============================================================================= +# Column Operation Processors +# ============================================================================= + +class ColumnRemover: + """ + Remove specified columns from the DataFrame. + + Attributes: + columns_to_remove: List of column names to drop + + Example: + >>> processor = ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt']) + >>> df = processor.process(df) + >>> # Removes specified columns + """ + + def __init__(self, columns_to_remove: List[str]): + self.columns_to_remove = columns_to_remove + + def process(self, df: pl.DataFrame) -> pl.DataFrame: + """ + Remove specified columns. + + Args: + df: Input DataFrame + + Returns: + DataFrame without specified columns + """ + # Only remove columns that exist + cols_to_drop = [c for c in self.columns_to_remove if c in df.columns] + if cols_to_drop: + df = df.drop(cols_to_drop) + + return df + + def __repr__(self) -> str: + return f"ColumnRemover(columns_to_remove={self.columns_to_remove})" + + +class FlagToOnehot: + """ + Convert flag columns to one-hot encoded index. + + For multiple one-hot encoded columns, finds which flag is set and + returns the corresponding index. Uses -1 as default for rows with + no flags set. + + Attributes: + flag_columns: List of boolean flag column names + + Example: + >>> processor = FlagToOnehot(['gds_CC10', 'gds_CC11', ...]) + >>> df = processor.process(df) + >>> # Adds: indus_idx (index of first True flag, or -1) + """ + + def __init__(self, flag_columns: List[str]): + self.flag_columns = flag_columns + + def process(self, df: pl.DataFrame) -> pl.DataFrame: + """ + Convert flag columns to single index column. + + Args: + df: Input DataFrame with flag columns + + Returns: + DataFrame with indus_idx column added, original flags dropped + """ + # Build a when/then chain to find the industry index + # Start with -1 (no industry) as default + indus_expr = pl.lit(-1) + + for idx, col in enumerate(self.flag_columns): + if col in df.columns: + indus_expr = pl.when(pl.col(col) == 1).then(idx).otherwise(indus_expr) + + df = df.with_columns([indus_expr.alias('indus_idx')]) + + # Drop the original one-hot columns + cols_to_drop = [c for c in self.flag_columns if c in df.columns] + if cols_to_drop: + df = df.drop(cols_to_drop) + + return df + + def __repr__(self) -> str: + return f"FlagToOnehot(flag_columns={len(self.flag_columns)} columns)" + + +# ============================================================================= +# Normalization Processors +# ============================================================================= + +class IndusNtrlInjector: + """ + Industry neutralization for features. + + For each feature, subtracts the industry mean (grouped by indus_idx + within each datetime) from the feature value. Creates new columns + with the specified suffix while keeping original columns. + + This performs cross-sectional neutralization per datetime, matching + Qlib's cal_indus_ntrl behavior. + + Attributes: + feature_cols: List of feature columns to neutralize + suffix: Suffix for neutralized column names (default: '_ntrl') + + Example: + >>> processor = IndusNtrlInjector(['KMID', 'KLEN'], suffix='_ntrl') + >>> df = processor.process(df) + >>> # Adds: KMID_ntrl, KLEN_ntrl + """ + + def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'): + self.feature_cols = feature_cols + self.suffix = suffix + + def process(self, df: pl.DataFrame) -> pl.DataFrame: + """ + Apply industry neutralization to specified features. + + Args: + df: Input DataFrame with feature columns and indus_idx + + Returns: + DataFrame with neutralized columns added (_ntrl suffix) + """ + # Filter to only columns that exist + existing_cols = [c for c in self.feature_cols if c in df.columns] + + for col in existing_cols: + ntrl_col = f"{col}{self.suffix}" + # Calculate industry mean PER DATETIME and subtract from feature + # Use group_by().transform() for proper group-wise operation + df = df.with_columns([ + (pl.col(col) - pl.col(col).mean().over(['datetime', 'indus_idx'])).alias(ntrl_col) + ]) + + return df + + def __repr__(self) -> str: + return f"IndusNtrlInjector(feature_cols={len(self.feature_cols)} columns, suffix='{self.suffix}')" + + +@dataclass +class RobustZScoreNorm: + """ + Robust z-score normalization using median/MAD. + + Applies the transformation: (x - median) / (1.4826 * MAD) + where MAD = median(|x - median|) + + Supports two modes: + 1. Per-datetime computation (default): Calculates median/MAD for each datetime + 2. Pre-fitted parameters: Uses provided mean/std arrays (from Qlib processor) + + Attributes: + feature_cols: List of feature columns to normalize + clip_range: Clip normalized values to this range (default: (-3, 3)) + use_qlib_params: Use pre-fitted parameters (default: False) + qlib_mean: Pre-fitted mean array (required if use_qlib_params=True) + qlib_std: Pre-fitted std array (required if use_qlib_params=True) + + Example: + # Using pre-fitted Qlib parameters + >>> processor = RobustZScoreNorm( + ... feature_cols=['KMID', 'KLEN'], + ... use_qlib_params=True, + ... qlib_mean=mean_array, + ... qlib_std=std_array + ... ) + >>> df = processor.process(df) + + # Loading parameters from saved version + >>> processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm") + >>> df = processor.process(df) + """ + + feature_cols: List[str] + clip_range: Tuple[float, float] = (-3.0, 3.0) + use_qlib_params: bool = False + qlib_mean: Optional[np.ndarray] = None + qlib_std: Optional[np.ndarray] = None + + @classmethod + def from_version( + cls, + version: str, + feature_cols: Optional[List[str]] = None, + clip_range: Tuple[float, float] = (-3.0, 3.0), + params_dir: str = None + ) -> "RobustZScoreNorm": + """ + Create a RobustZScoreNorm instance from saved parameters by version name. + + This loads pre-extracted mean_train and std_train from the parameter + directory structure: + {params_dir}/{version}/ + ├── mean_train.npy + ├── std_train.npy + └── metadata.json + + Args: + version: Version name (e.g., "csiallx_feature2_ntrla_flag_pnlnorm") + feature_cols: Optional list of feature columns. If None, uses the + order from metadata.json (alpha158_ntrl + alpha158_raw + + market_ext_ntrl + market_ext_raw) + clip_range: Clip range for normalized values (default: (-3, 3)) + params_dir: Base directory for parameter versions. If None, uses: + stock_1d/d033/alpha158_beta/data/robust_zscore_params/ + + Returns: + RobustZScoreNorm instance with pre-fitted parameters loaded + + Raises: + FileNotFoundError: If version directory or parameter files not found + ValueError: If feature column count doesn't match parameter shape + + Example: + >>> processor = RobustZScoreNorm.from_version( + ... "csiallx_feature2_ntrla_flag_pnlnorm" + ... ) + >>> df = processor.process(df) + """ + import json + from pathlib import Path + + # Set default params_dir + if params_dir is None: + # Default to the standard location + # Go from cta_1d/src/processors/ to alpha_lab/stock_1d/d033/alpha158_beta/data/robust_zscore_params/ + params_dir = Path(__file__).parent.parent.parent.parent / \ + "stock_1d" / "d033" / "alpha158_beta" / "data" / "robust_zscore_params" + else: + params_dir = Path(params_dir) + + version_dir = params_dir / version + + # Check version directory exists + if not version_dir.exists(): + raise FileNotFoundError( + f"Version directory not found: {version_dir}\n" + f"Available versions should be in: {params_dir}" + ) + + # Load mean_train.npy + mean_path = version_dir / "mean_train.npy" + if not mean_path.exists(): + raise FileNotFoundError(f"mean_train.npy not found: {mean_path}") + mean_train = np.load(mean_path) + + # Load std_train.npy + std_path = version_dir / "std_train.npy" + if not std_path.exists(): + raise FileNotFoundError(f"std_train.npy not found: {std_path}") + std_train = np.load(std_path) + + # Load metadata.json for feature column names + metadata_path = version_dir / "metadata.json" + if metadata_path.exists(): + with open(metadata_path, 'r') as f: + metadata = json.load(f) + + # Build feature columns from metadata if not provided + if feature_cols is None: + feature_columns = metadata.get('feature_columns', {}) + alpha158_ntrl = [f"{c}_ntrl" for c in feature_columns.get('alpha158_ntrl', [])] + alpha158_raw = feature_columns.get('alpha158_raw', []) + market_ext_ntrl = [f"{c}_ntrl" for c in feature_columns.get('market_ext_ntrl', [])] + market_ext_raw = feature_columns.get('market_ext_raw', []) + + feature_cols = alpha158_ntrl + alpha158_raw + market_ext_ntrl + market_ext_raw + + # Validate feature column count matches parameter shape + expected_count = len(mean_train) + if feature_cols and len(feature_cols) != expected_count: + raise ValueError( + f"Feature column count ({len(feature_cols)}) does not match " + f"parameter shape ({expected_count})" + ) + + print(f"Loaded RobustZScoreNorm parameters from version '{version}':") + print(f" mean_train shape: {mean_train.shape}") + print(f" std_train shape: {std_train.shape}") + print(f" feature_cols: {len(feature_cols) if feature_cols else 'not specified'}") + + return cls( + feature_cols=feature_cols, + clip_range=clip_range, + use_qlib_params=True, + qlib_mean=mean_train, + qlib_std=std_train + ) + + def __post_init__(self): + """Validate parameters after initialization.""" + if self.use_qlib_params: + if self.qlib_mean is None or self.qlib_std is None: + raise ValueError( + "Must provide qlib_mean and qlib_std when use_qlib_params=True" + ) + # Convert to numpy arrays if not already + self.qlib_mean = np.asarray(self.qlib_mean) + self.qlib_std = np.asarray(self.qlib_std) + + def process(self, df: pl.DataFrame) -> pl.DataFrame: + """ + Apply robust z-score normalization. + + Args: + df: Input DataFrame with feature columns + + Returns: + DataFrame with normalized columns (in-place modification) + """ + # Filter to only columns that exist + existing_cols = [c for c in self.feature_cols if c in df.columns] + + if self.use_qlib_params: + # Use pre-fitted parameters (fit once, apply to all dates) + for i, col in enumerate(existing_cols): + if i < len(self.qlib_mean): + mean_val = float(self.qlib_mean[i]) + std_val = float(self.qlib_std[i]) + + # Apply z-score normalization using pre-fitted params + df = df.with_columns([ + ((pl.col(col) - mean_val) / (std_val + 1e-8)) + .clip(self.clip_range[0], self.clip_range[1]) + .alias(col) + ]) + else: + # Compute per-datetime robust z-score + for col in existing_cols: + # Compute median per datetime + median_col = f"__median_{col}" + df = df.with_columns([ + pl.col(col).median().over('datetime').alias(median_col) + ]) + + # Compute absolute deviation + abs_dev_col = f"__absdev_{col}" + df = df.with_columns([ + (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col) + ]) + + # Compute MAD (median of absolute deviations) + mad_col = f"__mad_{col}" + df = df.with_columns([ + pl.col(abs_dev_col).median().over('datetime').alias(mad_col) + ]) + + # Compute robust z-score and clip + df = df.with_columns([ + ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8)) + .clip(self.clip_range[0], self.clip_range[1]) + .alias(col) + ]) + + # Clean up temporary columns + df = df.drop([median_col, abs_dev_col, mad_col]) + + return df + + +# ============================================================================= +# Data Cleaning Processors +# ============================================================================= + +class Fillna: + """ + Fill NaN values with specified value. + + Fills NaN/None values in specified columns with the fill_value. + Only processes numeric columns (Float32, Float64, Int32, Int64, UInt32, UInt64). + + Attributes: + fill_value: Value to fill NaN with (default: 0.0) + + Example: + >>> processor = Fillna(fill_value=0.0) + >>> df = processor.process(df, ['KMID', 'KLEN']) + """ + + def __init__(self, fill_value: float = 0.0): + self.fill_value = fill_value + + def process(self, df: pl.DataFrame, columns: List[str]) -> pl.DataFrame: + """ + Fill NaN values in specified columns. + + Args: + df: Input DataFrame + columns: List of columns to fill NaN in + + Returns: + DataFrame with NaN values filled + """ + # Filter to only columns that exist and are numeric + existing_cols = [c for c in columns if c in df.columns] + + for col in existing_cols: + # Check column dtype + dtype = df[col].dtype + if dtype in [pl.Float32, pl.Float64, pl.Int32, pl.Int64, + pl.UInt32, pl.UInt64, pl.UInt16, pl.UInt8]: + df = df.with_columns([ + pl.col(col).fill_null(self.fill_value).fill_nan(self.fill_value) + ]) + + return df + + def __repr__(self) -> str: + return f"Fillna(fill_value={self.fill_value})" + + +# ============================================================================= +# Processor Pipeline Utilities +# ============================================================================= + +def create_processor_pipeline( + alpha158_cols: List[str], + market_ext_base: List[str], + market_flag_cols: List[str], + industry_flag_cols: List[str], + columns_to_remove: Optional[List[str]] = None, + ntrl_suffix: str = '_ntrl' +) -> List: + """ + Create a complete processor pipeline configuration. + + This factory function creates processors in the correct order: + 1. DiffProcessor - adds diff features + 2. FlagMarketInjector - adds market_0, market_1 + 3. FlagSTInjector - adds IsST + 4. ColumnRemover - removes specified columns + 5. FlagToOnehot - converts industry flags to index + 6. IndusNtrlInjector (x2) - neutralizes alpha158 and market_ext + 7. RobustZScoreNorm - normalizes features + 8. Fillna - fills NaN values + + Args: + alpha158_cols: List of alpha158 feature names + market_ext_base: List of market extension base columns + market_flag_cols: List of market flag columns + industry_flag_cols: List of industry flag columns + columns_to_remove: Columns to remove (default: ['log_size_diff', 'IsN', 'IsZt', 'IsDt']) + ntrl_suffix: Suffix for neutralized columns + + Returns: + List of processor instances in execution order + """ + if columns_to_remove is None: + columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt'] + + return [ + DiffProcessor(market_ext_base), + FlagMarketInjector(), + FlagSTInjector(), + ColumnRemover(columns_to_remove), + FlagToOnehot(industry_flag_cols), + IndusNtrlInjector(alpha158_cols, suffix=ntrl_suffix), + IndusNtrlInjector(market_ext_base, suffix=ntrl_suffix), + # RobustZScoreNorm and Fillna require fitted parameters + # and should be added separately + ] + + +def get_final_feature_columns( + alpha158_cols: List[str], + market_ext_base: List[str], + market_flag_cols: List[str], + columns_to_remove: Optional[List[str]] = None, + ntrl_suffix: str = '_ntrl' +) -> Dict[str, List[str]]: + """ + Get the final feature column structure after processing. + + This is useful for determining the expected VAE input dimensions + and verifying feature order. + + Args: + alpha158_cols: List of alpha158 feature names + market_ext_base: List of market extension base columns (before Diff) + market_flag_cols: List of market flag columns (before ColumnRemover) + columns_to_remove: Columns to remove + ntrl_suffix: Suffix for neutralized columns + + Returns: + Dictionary with feature groups: + - 'norm_feature_cols': Features to normalize (in order) + - 'market_flag_cols': Market flag columns after processing + - 'all_feature_cols': All feature columns including indus_idx + """ + if columns_to_remove is None: + columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt'] + + # After Diff: market_ext becomes base + diff + market_ext_after_diff = market_ext_base + [f"{c}_diff" for c in market_ext_base] + + # After ColumnRemover: remove specified columns + market_ext_final = [c for c in market_ext_after_diff if c not in columns_to_remove] + market_flag_final = [c for c in market_flag_cols if c not in columns_to_remove] + + # After FlagMarketInjector: add market_0, market_1 + market_flag_final.extend(['market_0', 'market_1']) + + # After FlagSTInjector: add IsST + market_flag_final.append('IsST') + + # Build normalized feature columns (in Qlib order: ntrl + raw for each group) + alpha158_ntrl = [f"{c}{ntrl_suffix}" for c in alpha158_cols] + market_ext_ntrl = [f"{c}{ntrl_suffix}" for c in market_ext_final] + + # Normalization order: alpha158_ntrl + alpha158 + market_ext_ntrl + market_ext + norm_feature_cols = alpha158_ntrl + alpha158_cols + market_ext_ntrl + market_ext_final + + return { + 'alpha158_cols': alpha158_cols, + 'alpha158_ntrl_cols': alpha158_ntrl, + 'market_ext_cols': market_ext_final, + 'market_ext_ntrl_cols': market_ext_ntrl, + 'market_flag_cols': market_flag_final, + 'norm_feature_cols': norm_feature_cols, + 'all_feature_cols': norm_feature_cols + market_flag_final + ['indus_idx'], + } diff --git a/stock_1d/d033/alpha158_beta/README.md b/stock_1d/d033/alpha158_beta/README.md index bf6fdf6..c1b678a 100644 --- a/stock_1d/d033/alpha158_beta/README.md +++ b/stock_1d/d033/alpha158_beta/README.md @@ -19,18 +19,27 @@ stock_1d/d033/alpha158_beta/ │ ├── fetch_predictions.py # Fetch original predictions from DolphinDB │ ├── predict_with_embedding.py # Generate predictions using beta embeddings │ ├── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions -│ └── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline +│ ├── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline +│ └── extract_qlib_params.py # Extract RobustZScoreNorm parameters from Qlib proc_list ├── src/ # Source modules │ └── qlib_loader.py # Qlib data loader with configurable date range ├── config/ # Configuration files │ └── handler.yaml # Modified handler with configurable end date -└── data/ # Data files (gitignored) - ├── embedding_0_7_beta.parquet - ├── predictions_beta_embedding.parquet - ├── original_predictions_0_7.parquet - ├── actual_returns.parquet - ├── raw_data_*.pkl # Raw data before preprocessing - └── processed_data_*.pkl # Processed data after preprocessing +├── data/ # Data files (gitignored) +│ ├── robust_zscore_params/ # Pre-fitted normalization parameters +│ │ └── csiallx_feature2_ntrla_flag_pnlnorm/ +│ │ ├── mean_train.npy +│ │ ├── std_train.npy +│ │ └── metadata.json +│ ├── embedding_0_7_beta.parquet +│ ├── predictions_beta_embedding.parquet +│ ├── original_predictions_0_7.parquet +│ ├── actual_returns.parquet +│ ├── raw_data_*.pkl # Raw data before preprocessing +│ └── processed_data_*.pkl # Processed data after preprocessing +└── data_polars/ # Polars-generated datasets (gitignored) + ├── raw_data_*.pkl + └── processed_data_*.pkl ``` ## Data Loading with Configurable Date Range @@ -122,7 +131,7 @@ This script: - ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt) - FlagToOnehot (converts 29 industry flags to indus_idx) - IndusNtrlInjector (industry neutralization) - - RobustZScoreNorm (using pre-fitted qlib parameters) + - RobustZScoreNorm (using pre-fitted qlib parameters via `from_version()`) - Fillna (fill NaN with 0) 4. Saves processed data to `data_polars/processed_data_*.pkl` @@ -133,6 +142,38 @@ Output structure: - Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx) - VAE input dimension: 341 (excluding indus_idx) +### RobustZScoreNorm Parameter Extraction + +The pipeline uses pre-fitted normalization parameters extracted from Qlib's `proc_list.proc` file. These parameters are stored in `data/robust_zscore_params/` and can be loaded using the `RobustZScoreNorm.from_version()` method. + +**Extract parameters from Qlib proc_list:** + +```bash +python scripts/extract_qlib_params.py --version csiallx_feature2_ntrla_flag_pnlnorm +``` + +This creates: +- `data/robust_zscore_params/{version}/mean_train.npy` - Pre-fitted mean parameters (330,) +- `data/robust_zscore_params/{version}/std_train.npy` - Pre-fitted std parameters (330,) +- `data/robust_zscore_params/{version}/metadata.json` - Feature column names and metadata + +**Use in Polars processors:** + +```python +from cta_1d.src.processors import RobustZScoreNorm + +# Load pre-fitted parameters by version name +processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm") + +# Apply normalization to DataFrame +df = processor.process(df) +``` + +**Parameter details:** +- Fit period: 2013-01-01 to 2018-12-31 +- Feature count: 330 (158 alpha158_ntrl + 158 alpha158_raw + 7 market_ext_ntrl + 7 market_ext_raw) +- Fields: ['feature', 'feature_ext'] + ## Workflow ### 1. Generate Beta Embeddings diff --git a/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json b/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json new file mode 100644 index 0000000..390a6ce --- /dev/null +++ b/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json @@ -0,0 +1,366 @@ +{ + "version": "csiallx_feature2_ntrla_flag_pnlnorm", + "created_at": "2026-03-01T12:18:01.969109", + "source_file": "2013-01-01", + "fit_start_time": "2013-01-01", + "fit_end_time": "2018-12-31", + "fields_group": [ + "feature", + "feature_ext" + ], + "feature_columns": { + "alpha158_ntrl": [ + "KMID", + "KLEN", + "KMID2", + "KUP", + "KUP2", + "KLOW", + "KLOW2", + "KSFT", + "KSFT2", + "OPEN0", + "HIGH0", + "LOW0", + "VWAP0", + "ROC5", + "ROC10", + "ROC20", + "ROC30", + "ROC60", + "MA5", + "MA10", + "MA20", + "MA30", + "MA60", + "STD5", + "STD10", + "STD20", + "STD30", + "STD60", + "BETA5", + "BETA10", + "BETA20", + "BETA30", + "BETA60", + "RSQR5", + "RSQR10", + "RSQR20", + "RSQR30", + "RSQR60", + "RESI5", + "RESI10", + "RESI20", + "RESI30", + "RESI60", + "MAX5", + "MAX10", + "MAX20", + "MAX30", + "MAX60", + "MIN5", + "MIN10", + "MIN20", + "MIN30", + "MIN60", + "QTLU5", + "QTLU10", + "QTLU20", + "QTLU30", + "QTLU60", + "QTLD5", + "QTLD10", + "QTLD20", + "QTLD30", + "QTLD60", + "RANK5", + "RANK10", + "RANK20", + "RANK30", + "RANK60", + "RSV5", + "RSV10", + "RSV20", + "RSV30", + "RSV60", + "IMAX5", + "IMAX10", + "IMAX20", + "IMAX30", + "IMAX60", + "IMIN5", + "IMIN10", + "IMIN20", + "IMIN30", + "IMIN60", + "IMXD5", + "IMXD10", + "IMXD20", + "IMXD30", + "IMXD60", + "CORR5", + "CORR10", + "CORR20", + "CORR30", + "CORR60", + "CORD5", + "CORD10", + "CORD20", + "CORD30", + "CORD60", + "CNTP5", + "CNTP10", + "CNTP20", + "CNTP30", + "CNTP60", + "CNTN5", + "CNTN10", + "CNTN20", + "CNTN30", + "CNTN60", + "CNTD5", + "CNTD10", + "CNTD20", + "CNTD30", + "CNTD60", + "SUMP5", + "SUMP10", + "SUMP20", + "SUMP30", + "SUMP60", + "SUMN5", + "SUMN10", + "SUMN20", + "SUMN30", + "SUMN60", + "SUMD5", + "SUMD10", + "SUMD20", + "SUMD30", + "SUMD60", + "VMA5", + "VMA10", + "VMA20", + "VMA30", + "VMA60", + "VSTD5", + "VSTD10", + "VSTD20", + "VSTD30", + "VSTD60", + "WVMA5", + "WVMA10", + "WVMA20", + "WVMA30", + "WVMA60", + "VSUMP5", + "VSUMP10", + "VSUMP20", + "VSUMP30", + "VSUMP60", + "VSUMN5", + "VSUMN10", + "VSUMN20", + "VSUMN30", + "VSUMN60", + "VSUMD5", + "VSUMD10", + "VSUMD20", + "VSUMD30", + "VSUMD60" + ], + "alpha158_raw": [ + "KMID", + "KLEN", + "KMID2", + "KUP", + "KUP2", + "KLOW", + "KLOW2", + "KSFT", + "KSFT2", + "OPEN0", + "HIGH0", + "LOW0", + "VWAP0", + "ROC5", + "ROC10", + "ROC20", + "ROC30", + "ROC60", + "MA5", + "MA10", + "MA20", + "MA30", + "MA60", + "STD5", + "STD10", + "STD20", + "STD30", + "STD60", + "BETA5", + "BETA10", + "BETA20", + "BETA30", + "BETA60", + "RSQR5", + "RSQR10", + "RSQR20", + "RSQR30", + "RSQR60", + "RESI5", + "RESI10", + "RESI20", + "RESI30", + "RESI60", + "MAX5", + "MAX10", + "MAX20", + "MAX30", + "MAX60", + "MIN5", + "MIN10", + "MIN20", + "MIN30", + "MIN60", + "QTLU5", + "QTLU10", + "QTLU20", + "QTLU30", + "QTLU60", + "QTLD5", + "QTLD10", + "QTLD20", + "QTLD30", + "QTLD60", + "RANK5", + "RANK10", + "RANK20", + "RANK30", + "RANK60", + "RSV5", + "RSV10", + "RSV20", + "RSV30", + "RSV60", + "IMAX5", + "IMAX10", + "IMAX20", + "IMAX30", + "IMAX60", + "IMIN5", + "IMIN10", + "IMIN20", + "IMIN30", + "IMIN60", + "IMXD5", + "IMXD10", + "IMXD20", + "IMXD30", + "IMXD60", + "CORR5", + "CORR10", + "CORR20", + "CORR30", + "CORR60", + "CORD5", + "CORD10", + "CORD20", + "CORD30", + "CORD60", + "CNTP5", + "CNTP10", + "CNTP20", + "CNTP30", + "CNTP60", + "CNTN5", + "CNTN10", + "CNTN20", + "CNTN30", + "CNTN60", + "CNTD5", + "CNTD10", + "CNTD20", + "CNTD30", + "CNTD60", + "SUMP5", + "SUMP10", + "SUMP20", + "SUMP30", + "SUMP60", + "SUMN5", + "SUMN10", + "SUMN20", + "SUMN30", + "SUMN60", + "SUMD5", + "SUMD10", + "SUMD20", + "SUMD30", + "SUMD60", + "VMA5", + "VMA10", + "VMA20", + "VMA30", + "VMA60", + "VSTD5", + "VSTD10", + "VSTD20", + "VSTD30", + "VSTD60", + "WVMA5", + "WVMA10", + "WVMA20", + "WVMA30", + "WVMA60", + "VSUMP5", + "VSUMP10", + "VSUMP20", + "VSUMP30", + "VSUMP60", + "VSUMN5", + "VSUMN10", + "VSUMN20", + "VSUMN30", + "VSUMN60", + "VSUMD5", + "VSUMD10", + "VSUMD20", + "VSUMD30", + "VSUMD60" + ], + "market_ext_ntrl": [ + "turnover", + "free_turnover", + "log_size", + "con_rating_strength", + "turnover_diff", + "free_turnover_diff", + "con_rating_strength_diff" + ], + "market_ext_raw": [ + "turnover", + "free_turnover", + "log_size", + "con_rating_strength", + "turnover_diff", + "free_turnover_diff", + "con_rating_strength_diff" + ] + }, + "feature_count": { + "alpha158_ntrl": 158, + "alpha158_raw": 158, + "market_ext_ntrl": 7, + "market_ext_raw": 7, + "total": 330 + }, + "parameter_shapes": { + "mean_train": [ + 330 + ], + "std_train": [ + 330 + ] + } +} \ No newline at end of file diff --git a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py index cc928df..7b961d0 100644 --- a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py +++ b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py @@ -20,18 +20,20 @@ from datetime import datetime # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) -from generate_beta_embedding import ( - load_all_data, - merge_data_sources, - filter_stock_universe, +# Import processors from the new shared module +from cta_1d.src.processors import ( DiffProcessor, FlagMarketInjector, + FlagSTInjector, ColumnRemover, FlagToOnehot, IndusNtrlInjector, RobustZScoreNorm, Fillna, - load_qlib_processor_params, +) + +# Import constants from local module +from generate_beta_embedding import ( ALPHA158_COLS, INDUSTRY_FLAG_COLS, ) @@ -52,7 +54,7 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame: This mimics the qlib proc_list: 0. Diff: Adds diff features for market_ext columns 1. FlagMarketInjector: Adds market_0, market_1 - 2. FlagSTInjector: SKIPPED (fails even in gold-standard) + 2. FlagSTInjector: Adds IsST (placeholder if ST flags not available) 3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt 4. FlagToOnehot: Converts 29 industry flags to indus_idx 5. IndusNtrlInjector: Industry neutralization for feature @@ -88,9 +90,13 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame: # Add market_0, market_1 to flag list market_flag_with_market = market_flag_cols + ['market_0', 'market_1'] - # Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard) - print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...") - market_flag_with_st = market_flag_with_market # No IsST added + # Step 3: FlagSTInjector - adds IsST (placeholder if ST flags not available) + print("[3] Applying FlagSTInjector...") + flag_st_injector = FlagSTInjector() + df = flag_st_injector.process(df) + + # Add IsST to flag list + market_flag_with_st = market_flag_with_market + ['IsST'] # Step 4: ColumnRemover print("[4] Applying ColumnRemover...") @@ -129,21 +135,18 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame: print("[8] Applying RobustZScoreNorm...") norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols - qlib_params = load_qlib_processor_params() + # Load RobustZScoreNorm with pre-fitted parameters from version + robust_norm = RobustZScoreNorm.from_version( + "csiallx_feature2_ntrla_flag_pnlnorm", + feature_cols=norm_feature_cols + ) - # Verify parameter shape + # Verify parameter shape matches expected features expected_features = len(norm_feature_cols) - if qlib_params['mean_train'].shape[0] != expected_features: + if robust_norm.qlib_mean.shape[0] != expected_features: print(f" WARNING: Feature count mismatch! Expected {expected_features}, " - f"got {qlib_params['mean_train'].shape[0]}") - - robust_norm = RobustZScoreNorm( - norm_feature_cols, - clip_range=(-3, 3), - use_qlib_params=True, - qlib_mean=qlib_params['mean_train'], - qlib_std=qlib_params['std_train'] - ) + f"got {robust_norm.qlib_mean.shape[0]}") + df = robust_norm.process(df) # Step 9: Fillna @@ -166,6 +169,9 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame": """ Convert polars DataFrame to pandas DataFrame with MultiIndex columns. This matches the format of qlib's output. + + IMPORTANT: Qlib's IndusNtrlInjector outputs columns in order [_ntrl] + [raw], + so we need to reorder columns to match this expected order. """ import pandas as pd @@ -185,6 +191,7 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame": df = df.drop(columns=existing_raw_cols) # Build MultiIndex columns based on column name patterns + # IMPORTANT: Qlib order is [_ntrl columns] + [raw columns] for each group columns_with_group = [] # Define column sets @@ -195,23 +202,29 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame": feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'} + # First pass: collect _ntrl columns (these come first in qlib order) + ntrl_alpha158_cols = [] + ntrl_market_ext_cols = [] + raw_alpha158_cols = [] + raw_market_ext_cols = [] + flag_cols = [] + indus_idx_col = None + for col in df.columns: if col == 'indus_idx': - columns_with_group.append(('indus_idx', col)) + indus_idx_col = col elif col in feature_flag_cols: - columns_with_group.append(('feature_flag', col)) + flag_cols.append(col) elif col.endswith('_ntrl'): base_name = col[:-5] # Remove _ntrl suffix (5 characters) if base_name in alpha158_base: - columns_with_group.append(('feature', col)) + ntrl_alpha158_cols.append(col) elif base_name in market_ext_all: - columns_with_group.append(('feature_ext', col)) - else: - columns_with_group.append(('feature', col)) # Default to feature + ntrl_market_ext_cols.append(col) elif col in alpha158_base: - columns_with_group.append(('feature', col)) + raw_alpha158_cols.append(col) elif col in market_ext_all: - columns_with_group.append(('feature_ext', col)) + raw_market_ext_cols.append(col) elif col in INDUSTRY_FLAG_COLS: columns_with_group.append(('indus_flag', col)) elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}: @@ -221,6 +234,27 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame": print(f" Warning: Unknown column '{col}', assigning to 'other' group") columns_with_group.append(('other', col)) + # Build columns in qlib order: [_ntrl] + [raw] for each feature group + # Feature group: alpha158_ntrl + alpha158 + for col in sorted(ntrl_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x.replace('_ntrl', '')) if x.replace('_ntrl', '') in ALPHA158_COLS else 999): + columns_with_group.append(('feature', col)) + for col in sorted(raw_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x) if x in ALPHA158_COLS else 999): + columns_with_group.append(('feature', col)) + + # Feature_ext group: market_ext_ntrl + market_ext + for col in ntrl_market_ext_cols: + columns_with_group.append(('feature_ext', col)) + for col in raw_market_ext_cols: + columns_with_group.append(('feature_ext', col)) + + # Feature_flag group + for col in flag_cols: + columns_with_group.append(('feature_flag', col)) + + # Indus_idx + if indus_idx_col: + columns_with_group.append(('indus_idx', indus_idx_col)) + # Create MultiIndex columns multi_cols = pd.MultiIndex.from_tuples(columns_with_group) df.columns = multi_cols diff --git a/stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py b/stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py new file mode 100644 index 0000000..1d73846 --- /dev/null +++ b/stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python +""" +Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file. + +This script extracts the pre-fitted mean_train and std_train parameters from +Qlib's RobustZScoreNorm processor and saves them as reusable .npy files. + +The extracted parameters can be used by the Polars RobustZScoreNorm processor +via the from_version() class method. + +Output structure: + stock_1d/d033/alpha158_beta/data/robust_zscore_params/ + └── {version}/ + ├── mean_train.npy + ├── std_train.npy + └── metadata.json +""" + +import os +import sys +import json +import pickle as pkl +from pathlib import Path +from datetime import datetime +import numpy as np + +# Default paths +DEFAULT_PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc" +DEFAULT_VERSION = "csiallx_feature2_ntrla_flag_pnlnorm" + +# Alpha158 columns in order (158 features) +ALPHA158_COLS = [ + 'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2', + 'OPEN0', 'HIGH0', 'LOW0', 'VWAP0', + 'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60', + 'MA5', 'MA10', 'MA20', 'MA30', 'MA60', + 'STD5', 'STD10', 'STD20', 'STD30', 'STD60', + 'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60', + 'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60', + 'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60', + 'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60', + 'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60', + 'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60', + 'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60', + 'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60', + 'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60', + 'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60', + 'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60', + 'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60', + 'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60', + 'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60', + 'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60', + 'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60', + 'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60', + 'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60', + 'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60', + 'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60', + 'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60', + 'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60', + 'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60', + 'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60', + 'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60', + 'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60' +] +assert len(ALPHA158_COLS) == 158, f"Expected 158 alpha158 cols, got {len(ALPHA158_COLS)}" + + +def extract_robust_zscore_params(proc_list_path: str) -> dict: + """ + Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file. + + Args: + proc_list_path: Path to the proc_list.proc pickle file + + Returns: + Dictionary containing: + - mean_train: numpy array of shape (330,) + - std_train: numpy array of shape (330,) + - fit_start_time: datetime string + - fit_end_time: datetime string + - fields_group: list of field groups + """ + print(f"Loading proc_list.proc from: {proc_list_path}") + + with open(proc_list_path, 'rb') as f: + proc_list = pkl.load(f) + + print(f"Loaded {len(proc_list)} processors from proc_list") + + # Find RobustZScoreNorm processor (typically at index 7) + zscore_proc = None + for i, proc in enumerate(proc_list): + proc_name = type(proc).__name__ + print(f" [{i}] {proc_name}") + if proc_name == "RobustZScoreNorm": + zscore_proc = proc + + if zscore_proc is None: + raise ValueError("RobustZScoreNorm processor not found in proc_list") + + # Extract parameters + params = { + 'mean_train': zscore_proc.mean_train, + 'std_train': zscore_proc.std_train, + 'fit_start_time': getattr(zscore_proc, 'fit_start_time', None), + 'fit_end_time': getattr(zscore_proc, 'fit_end_time', None), + 'fields_group': getattr(zscore_proc, 'fields_group', None), + } + + print(f"\nExtracted RobustZScoreNorm parameters:") + print(f" mean_train shape: {params['mean_train'].shape}, dtype: {params['mean_train'].dtype}") + print(f" std_train shape: {params['std_train'].shape}, dtype: {params['std_train'].dtype}") + print(f" fit_start_time: {params['fit_start_time']}") + print(f" fit_end_time: {params['fit_end_time']}") + print(f" fields_group: {params['fields_group']}") + + return params + + +def build_feature_column_names() -> list: + """ + Build the complete list of 330 feature column names in order. + + Feature order (330 total): + 1. alpha158_ntrl (158 features) + 2. alpha158_raw (158 features) + 3. market_ext_ntrl (7 features) + 4. market_ext_raw (7 features) + + market_ext columns (after processing): + - Base: turnover, free_turnover, log_size, con_rating_strength + - Diff: turnover_diff, free_turnover_diff, con_rating_strength_diff + - Note: log_size_diff is removed by ColumnRemover + """ + # Alpha158 neutralized columns (158) + alpha158_ntrl = [f"{c}_ntrl" for c in ALPHA158_COLS] + + # Alpha158 raw columns (158) + alpha158_raw = ALPHA158_COLS.copy() + + # market_ext columns (7 after ColumnRemover) + # After Diff: 4 base + 4 diff = 8 + # After ColumnRemover (removes log_size_diff): 7 remain + market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] + market_ext_diff = ['turnover_diff', 'free_turnover_diff', 'log_size_diff', 'con_rating_strength_diff'] + market_ext_all = market_ext_base + market_ext_diff + market_ext_final = [c for c in market_ext_all if c != 'log_size_diff'] + + # market_ext neutralized columns (7) + market_ext_ntrl = [f"{c}_ntrl" for c in market_ext_final] + + # market_ext raw columns (7) + market_ext_raw = market_ext_final.copy() + + # Combine all feature columns in Qlib order + feature_cols = alpha158_ntrl + alpha158_raw + market_ext_ntrl + market_ext_raw + + print(f"\nBuilt feature column names:") + print(f" alpha158_ntrl: {len(alpha158_ntrl)} features") + print(f" alpha158_raw: {len(alpha158_raw)} features") + print(f" market_ext_ntrl: {len(market_ext_ntrl)} features") + print(f" market_ext_raw: {len(market_ext_raw)} features") + print(f" Total: {len(feature_cols)} features") + + return feature_cols + + +def save_parameters( + params: dict, + feature_cols: list, + output_dir: str, + version: str +): + """ + Save extracted parameters to output directory. + + Creates: + - mean_train.npy + - std_train.npy + - metadata.json + """ + output_path = Path(output_dir) / version + output_path.mkdir(parents=True, exist_ok=True) + + print(f"\nSaving parameters to: {output_path}") + + # Save mean_train.npy + mean_path = output_path / "mean_train.npy" + np.save(mean_path, params['mean_train']) + print(f" Saved mean_train to: {mean_path}") + + # Save std_train.npy + std_path = output_path / "std_train.npy" + np.save(std_path, params['std_train']) + print(f" Saved std_train to: {std_path}") + + # Build metadata + metadata = { + 'version': version, + 'created_at': datetime.now().isoformat(), + 'source_file': str(params.get('fit_start_time', 'unknown')), + 'fit_start_time': str(params['fit_start_time']), + 'fit_end_time': str(params['fit_end_time']), + 'fields_group': list(params['fields_group']) if params['fields_group'] else None, + 'feature_columns': { + 'alpha158_ntrl': ALPHA158_COLS, # Store base names, _ntrl is implied + 'alpha158_raw': ALPHA158_COLS, + 'market_ext_ntrl': [ + 'turnover', 'free_turnover', 'log_size', 'con_rating_strength', + 'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff' + ], + 'market_ext_raw': [ + 'turnover', 'free_turnover', 'log_size', 'con_rating_strength', + 'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff' + ], + }, + 'feature_count': { + 'alpha158_ntrl': 158, + 'alpha158_raw': 158, + 'market_ext_ntrl': 7, + 'market_ext_raw': 7, + 'total': 330 + }, + 'parameter_shapes': { + 'mean_train': list(params['mean_train'].shape), + 'std_train': list(params['std_train'].shape) + } + } + + # Save metadata.json + metadata_path = output_path / "metadata.json" + with open(metadata_path, 'w') as f: + json.dump(metadata, f, indent=2) + print(f" Saved metadata to: {metadata_path}") + + return output_path + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Extract RobustZScoreNorm parameters from Qlib's proc_list.proc" + ) + parser.add_argument( + '--proc-list', + type=str, + default=DEFAULT_PROC_LIST_PATH, + help=f"Path to proc_list.proc (default: {DEFAULT_PROC_LIST_PATH})" + ) + parser.add_argument( + '--version', + type=str, + default=DEFAULT_VERSION, + help=f"Version name for output directory (default: {DEFAULT_VERSION})" + ) + parser.add_argument( + '--output-dir', + type=str, + default=str(Path(__file__).parent.parent / "data" / "robust_zscore_params"), + help="Output directory for parameter files" + ) + + args = parser.parse_args() + + print("=" * 80) + print("Extract Qlib RobustZScoreNorm Parameters") + print("=" * 80) + print(f"Source: {args.proc_list}") + print(f"Version: {args.version}") + print(f"Output: {args.output_dir}") + print() + + # Step 1: Extract parameters from proc_list.proc + params = extract_robust_zscore_params(args.proc_list) + + # Step 2: Build feature column names + feature_cols = build_feature_column_names() + + # Step 3: Verify parameter shape matches feature count + if params['mean_train'].shape[0] != len(feature_cols): + print(f"\nWARNING: Parameter shape mismatch!") + print(f" Expected: {len(feature_cols)} features") + print(f" Got: {params['mean_train'].shape[0]} parameters") + else: + print(f"\n✓ Parameter shape matches feature count ({len(feature_cols)})") + + # Step 4: Save parameters + output_path = save_parameters(params, feature_cols, args.output_dir, args.version) + + print("\n" + "=" * 80) + print("Extraction complete!") + print("=" * 80) + print(f"Output files:") + print(f" {output_path}/mean_train.npy") + print(f" {output_path}/std_train.npy") + print(f" {output_path}/metadata.json") + print() + print("Usage in Polars RobustZScoreNorm:") + print(f' norm = RobustZScoreNorm.from_version("{args.version}")') + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py index 6902d31..d8316c3 100644 --- a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py +++ b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py @@ -74,38 +74,25 @@ INDUSTRY_FLAG_COLS = [ 'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70' ] -# Stock universe filter: csiallx = All A-shares excluding BSE/NEEQ and STAR market -# This matches the original qlib handler configuration -# - Include: SH600xxx, SH601xxx, SH603xxx, SH605xxx (Shanghai Main Board) -# - Include: SZ000xxx, SZ001xxx, SZ002xxx, SZ003xxx (Shenzhen Main Board) -# - Include: SZ300xxx, SZ301xxx (ChiNext) -# - Exclude: SH688xxx, SH689xxx (STAR Market/科创板) -# - Exclude: 4xxxxx, 8xxxxx (BSE/NEEQ/北交所/新三板) -def filter_stock_universe(df: pl.DataFrame) -> pl.DataFrame: + +def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame: """ - Filter dataframe to csiallx stock universe (A-shares only). + Filter dataframe to csiallx stock universe (A-shares excluding STAR/BSE) using qshare spine functions. + + This uses qshare's filter_instruments which loads the instrument list from: + /data/qlib/default/data_ops/target/instruments/csiallx.txt + + Args: + df: Input DataFrame with datetime and instrument columns + instruments: Market name for spine creation (default: 'csiallx') - This filter matches the original qlib handler configuration which excludes: - - BSE/NEEQ stocks (4xxxxx, 8xxxxx) - - STAR Market stocks (688xxx, 689xxx) + Returns: + Filtered DataFrame with only instruments in the specified universe """ - inst_str = pl.col('instrument').cast(pl.String).str.zfill(6) - - # Define inclusion patterns - is_sh_main = inst_str.str.starts_with('60') | inst_str.str.starts_with('61') - is_sz_main = inst_str.str.starts_with('0') - is_chi_next = inst_str.str.starts_with('300') | inst_str.str.starts_with('301') - - # Define exclusion patterns (explicitly exclude these) - is_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689') - is_bseeq = inst_str.str.starts_with('4') | inst_str.str.starts_with('8') - - # Filter: include main boards and ChiNext, exclude STAR and BSE/NEEQ - df = df.filter( - (is_sh_main | is_sz_main | is_chi_next) & - (~is_star) & - (~is_bseeq) - ) + from qshare.algo.polars.spine import filter_instruments + + # Use qshare's filter_instruments with csiallx market name + df = filter_instruments(df, instruments=instruments) return df