- Add extract_qlib_params.py script to extract pre-fitted mean/std parameters from Qlib's proc_list.proc and save as reusable .npy files with metadata.json - Add RobustZScoreNorm.from_version() class method to load saved parameters by version name, supporting coexistence of multiple parameter versions - Update dump_polars_dataset.py to use from_version() instead of loading parameters directly from proc_list.proc - Update generate_beta_embedding.py to use qshare's filter_instruments for stock universe filtering - Save parameters to data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/ with 330 features (158 alpha158_ntrl + 158 alpha158_raw + 7 market_ext_ntrl + 7 market_ext_raw) - Update README.md with documentation for parameter extraction and usage

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

master
parent
8bd36c1939
commit
89bd1a528e
@ -0,0 +1,753 @@
|
|||||||
|
"""
|
||||||
|
Polars-based data processors for financial feature transformation.
|
||||||
|
|
||||||
|
This module provides Polars implementations of Qlib-style data processors
|
||||||
|
used in the data_ops pipeline. Each processor follows a consistent interface:
|
||||||
|
- Takes a Polars DataFrame as input
|
||||||
|
- Returns a transformed Polars DataFrame
|
||||||
|
|
||||||
|
Processors are organized by category:
|
||||||
|
- Feature Engineering: DiffProcessor
|
||||||
|
- Flag Injection: FlagMarketInjector, FlagSTInjector
|
||||||
|
- Column Operations: ColumnRemover, FlagToOnehot
|
||||||
|
- Normalization: IndusNtrlInjector, RobustZScoreNorm
|
||||||
|
- Data Cleaning: Fillna
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Tuple, Optional, Dict
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
# Public API of this module: processor classes grouped by pipeline stage.
__all__ = [
    # Feature Engineering
    'DiffProcessor',
    # Flag Injection
    'FlagMarketInjector',
    'FlagSTInjector',
    # Column Operations
    'ColumnRemover',
    'FlagToOnehot',
    # Normalization
    'IndusNtrlInjector',
    'RobustZScoreNorm',
    # Data Cleaning
    'Fillna',
]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Feature Engineering Processors
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class DiffProcessor:
    """
    Add period-over-period difference features per instrument.

    For every configured column present in the input, computes
    ``diff(periods)`` within each instrument group (after sorting by
    instrument and datetime) and stores the result in a new column
    named ``{col}_{suffix}``.

    Attributes:
        columns: Columns to compute differences for.
        suffix: Suffix appended to the new column names (default: 'diff').
        periods: Number of periods to difference over (default: 1).

    Example:
        >>> processor = DiffProcessor(['turnover', 'log_size'])
        >>> df = processor.process(df)
        >>> # Creates: turnover_diff, log_size_diff
    """

    def __init__(
        self,
        columns: List[str],
        suffix: str = 'diff',
        periods: int = 1
    ):
        self.columns = columns
        self.suffix = suffix
        self.periods = periods

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Append diff columns for each configured column found in df.

        Args:
            df: Input DataFrame with datetime and instrument columns.

        Returns:
            DataFrame with the original columns plus the new diff columns.
        """
        # Order by instrument then datetime so differences are taken along
        # the time axis within each instrument.
        df = df.sort(['instrument', 'datetime'])

        for source in (c for c in self.columns if c in df.columns):
            target = f"{source}_{self.suffix}"
            df = df.with_columns([
                pl.col(source)
                .diff(self.periods)
                .over('instrument')
                .alias(target)
            ])

        return df

    def __repr__(self) -> str:
        return f"DiffProcessor(columns={self.columns}, suffix='{self.suffix}', periods={self.periods})"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Flag Injection Processors
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class FlagMarketInjector:
    """
    Inject market-segment flag columns derived from instrument codes.

    Classifies stocks into two segments based on their instrument codes.
    Both code formats are supported:
    - With exchange prefix: SH600000, SZ000001, SH688000, SZ300001
    - Numeric only: 600000, 000001, 688000, 300001

    Market classification:
    - market_0: Main board (SH60xxx, SZ00xxx, or 6xxxxx, 0xxxxx) - 主板
    - market_1: STAR/ChiNext (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx,
      or 688xxx, 689xxx, 300xxx, 301xxx) - 科创板/创业板

    Note: 新三板/北交所 codes (NE40xxx, NE42xxx, NE43xxx, NE8xxx) are not
    part of the classification - such rows get 0 for both flags.

    Example:
        >>> processor = FlagMarketInjector()
        >>> df = processor.process(df)
        >>> # Adds: market_0, market_1 (both int8)
    """

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Append market_0 / market_1 classification columns.

        Args:
            df: Input DataFrame with an instrument column.

        Returns:
            DataFrame with market_0 and market_1 (Int8) columns added.
        """
        # Normalize the code: stringify, then strip any SH/SZ/NE exchange
        # prefix so "SH600000" and "600000" classify identically.
        code = pl.col('instrument').cast(pl.String).str.replace_all("^(SH|SZ|NE)", "")
        lead1 = code.str.slice(0, 1)
        lead3 = code.str.slice(0, 3)

        # STAR market (688xxx/689xxx) and ChiNext (300xxx/301xxx).
        on_star = (lead3 == '688') | (lead3 == '689')
        on_chinext = (lead3 == '300') | (lead3 == '301')
        # Main board: SH codes start with '6' (excluding STAR codes, which
        # also start with '6'), SZ codes start with '0'.
        on_main = ((lead1 == '6') & ~on_star) | (lead1 == '0')

        return df.with_columns([
            on_main.cast(pl.Int8).alias('market_0'),                 # 主板
            (on_star | on_chinext).cast(pl.Int8).alias('market_1')   # 科创板 + 创业板
        ])

    def __repr__(self) -> str:
        return "FlagMarketInjector()"
|
||||||
|
|
||||||
|
|
||||||
|
class FlagSTInjector:
    """
    Inject ST (Special Treatment) flag.

    Creates IsST column from ST_S and ST_Y flags:
    - IsST = mark_st_as if ST_S or ST_Y is True (stock is ST or *ST)
    - IsST = 0 otherwise

    If ST flags are not available, creates a placeholder column (all zeros).

    Attributes:
        mark_st_as: Value to mark ST stocks as (default: 1)

    Example:
        >>> processor = FlagSTInjector()
        >>> df = processor.process(df)
        >>> # Adds: IsST (int8)
    """

    def __init__(self, mark_st_as: int = 1):
        self.mark_st_as = mark_st_as

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add IsST column from ST flags.

        Args:
            df: Input DataFrame with ST_S, ST_Y columns (or without for placeholder)

        Returns:
            DataFrame with IsST column added
        """
        # Check if ST flags are available
        if 'ST_S' in df.columns and 'ST_Y' in df.columns:
            # Bug fix: the previous implementation ignored mark_st_as and
            # always wrote 1 for ST rows. Honor the configured value; the
            # default (mark_st_as=1) is byte-identical to the old behavior,
            # and null ST flags still propagate to a null IsST.
            is_st = (pl.col('ST_S').cast(pl.Boolean, strict=False) |
                     pl.col('ST_Y').cast(pl.Boolean, strict=False))
            df = df.with_columns([
                pl.when(is_st)
                .then(self.mark_st_as)
                .otherwise(0)
                .cast(pl.Int8)
                .alias('IsST')
            ])
        else:
            # Create placeholder (all zeros) if ST flags not available, so
            # downstream consumers always find the column.
            df = df.with_columns([
                pl.lit(0).cast(pl.Int8).alias('IsST')
            ])

        return df

    def __repr__(self) -> str:
        return f"FlagSTInjector(mark_st_as={self.mark_st_as})"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Column Operation Processors
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class ColumnRemover:
    """
    Drop a configured set of columns from the DataFrame.

    Columns that are not present in the input are silently ignored.

    Attributes:
        columns_to_remove: Column names to drop when present.

    Example:
        >>> processor = ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        >>> df = processor.process(df)
        >>> # Removes the listed columns that exist
    """

    def __init__(self, columns_to_remove: List[str]):
        self.columns_to_remove = columns_to_remove

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Remove the configured columns from df.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame without the configured columns (unchanged when none exist)
        """
        # Restrict to columns that actually exist to avoid drop errors.
        present = [name for name in self.columns_to_remove if name in df.columns]
        return df.drop(present) if present else df

    def __repr__(self) -> str:
        return f"ColumnRemover(columns_to_remove={self.columns_to_remove})"
|
||||||
|
|
||||||
|
|
||||||
|
class FlagToOnehot:
    """
    Collapse one-hot flag columns into a single integer index column.

    Scans the configured flag columns and writes the index of the set flag
    into indus_idx; rows with no flag set get -1. If several flags are set,
    the flag with the highest index wins (later flags wrap earlier ones in
    the resulting when/then chain, so they are evaluated first).

    Attributes:
        flag_columns: Boolean/0-1 flag column names, in index order.

    Example:
        >>> processor = FlagToOnehot(['gds_CC10', 'gds_CC11', ...])
        >>> df = processor.process(df)
        >>> # Adds: indus_idx (index of the set flag, or -1)
    """

    def __init__(self, flag_columns: List[str]):
        self.flag_columns = flag_columns

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Replace the flag columns with a single indus_idx column.

        Args:
            df: Input DataFrame containing (some of) the flag columns.

        Returns:
            DataFrame with indus_idx added and the flag columns dropped.
        """
        # Fold the flags into a nested when/then chain with -1 as the
        # fallback for rows where no flag is set.
        index_expr = pl.lit(-1)
        for position, flag in enumerate(self.flag_columns):
            if flag in df.columns:
                index_expr = pl.when(pl.col(flag) == 1).then(position).otherwise(index_expr)

        df = df.with_columns([index_expr.alias('indus_idx')])

        # The one-hot columns are redundant once the index exists.
        obsolete = [flag for flag in self.flag_columns if flag in df.columns]
        if obsolete:
            df = df.drop(obsolete)

        return df

    def __repr__(self) -> str:
        return f"FlagToOnehot(flag_columns={len(self.flag_columns)} columns)"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Normalization Processors
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class IndusNtrlInjector:
    """
    Cross-sectional industry neutralization of features.

    For every feature, subtracts the mean of that feature within the
    (datetime, indus_idx) group, producing a new column with the configured
    suffix while keeping the original column. This per-datetime demeaning
    matches Qlib's cal_indus_ntrl behavior.

    Attributes:
        feature_cols: Feature columns to neutralize.
        suffix: Suffix for neutralized column names (default: '_ntrl').

    Example:
        >>> processor = IndusNtrlInjector(['KMID', 'KLEN'], suffix='_ntrl')
        >>> df = processor.process(df)
        >>> # Adds: KMID_ntrl, KLEN_ntrl
    """

    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        self.feature_cols = feature_cols
        self.suffix = suffix

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add industry-neutralized versions of the configured features.

        Args:
            df: Input DataFrame with the feature columns and indus_idx.

        Returns:
            DataFrame with one additional `{col}{suffix}` column for each
            configured feature present in df.
        """
        for feature in (c for c in self.feature_cols if c in df.columns):
            # Demean within each (datetime, industry) cross-section.
            group_mean = pl.col(feature).mean().over(['datetime', 'indus_idx'])
            df = df.with_columns([
                (pl.col(feature) - group_mean).alias(f"{feature}{self.suffix}")
            ])

        return df

    def __repr__(self) -> str:
        return f"IndusNtrlInjector(feature_cols={len(self.feature_cols)} columns, suffix='{self.suffix}')"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RobustZScoreNorm:
    """
    Robust z-score normalization using median/MAD.

    Applies the transformation: (x - median) / (1.4826 * MAD)
    where MAD = median(|x - median|)

    Supports two modes:
    1. Per-datetime computation (default): Calculates median/MAD for each datetime
    2. Pre-fitted parameters: Uses provided mean/std arrays (from Qlib processor)

    In pre-fitted mode the parameter arrays are aligned positionally with
    feature_cols: qlib_mean[i] / qlib_std[i] belong to feature_cols[i].

    Attributes:
        feature_cols: List of feature columns to normalize
        clip_range: Clip normalized values to this range (default: (-3, 3))
        use_qlib_params: Use pre-fitted parameters (default: False)
        qlib_mean: Pre-fitted mean array (required if use_qlib_params=True)
        qlib_std: Pre-fitted std array (required if use_qlib_params=True)

    Example:
        # Using pre-fitted Qlib parameters
        >>> processor = RobustZScoreNorm(
        ...     feature_cols=['KMID', 'KLEN'],
        ...     use_qlib_params=True,
        ...     qlib_mean=mean_array,
        ...     qlib_std=std_array
        ... )
        >>> df = processor.process(df)

        # Loading parameters from saved version
        >>> processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm")
        >>> df = processor.process(df)
    """

    feature_cols: List[str]
    clip_range: Tuple[float, float] = (-3.0, 3.0)
    use_qlib_params: bool = False
    qlib_mean: Optional[np.ndarray] = None
    qlib_std: Optional[np.ndarray] = None

    @classmethod
    def from_version(
        cls,
        version: str,
        feature_cols: Optional[List[str]] = None,
        clip_range: Tuple[float, float] = (-3.0, 3.0),
        params_dir: Optional[str] = None
    ) -> "RobustZScoreNorm":
        """
        Create a RobustZScoreNorm instance from saved parameters by version name.

        This loads pre-extracted mean_train and std_train from the parameter
        directory structure:
            {params_dir}/{version}/
            ├── mean_train.npy
            ├── std_train.npy
            └── metadata.json

        Args:
            version: Version name (e.g., "csiallx_feature2_ntrla_flag_pnlnorm")
            feature_cols: Optional list of feature columns. If None, uses the
                order from metadata.json (alpha158_ntrl + alpha158_raw +
                market_ext_ntrl + market_ext_raw)
            clip_range: Clip range for normalized values (default: (-3, 3))
            params_dir: Base directory for parameter versions. If None, uses:
                stock_1d/d033/alpha158_beta/data/robust_zscore_params/

        Returns:
            RobustZScoreNorm instance with pre-fitted parameters loaded

        Raises:
            FileNotFoundError: If the version directory, a parameter file, or
                (when feature_cols is None) metadata.json is not found
            ValueError: If mean/std shapes disagree or the feature column
                count doesn't match the parameter shape

        Example:
            >>> processor = RobustZScoreNorm.from_version(
            ...     "csiallx_feature2_ntrla_flag_pnlnorm"
            ... )
            >>> df = processor.process(df)
        """
        import json
        from pathlib import Path

        # Resolve the base parameter directory. Default goes from
        # cta_1d/src/processors/ up to alpha_lab/ and then down to
        # stock_1d/d033/alpha158_beta/data/robust_zscore_params/.
        if params_dir is None:
            base_dir = Path(__file__).parent.parent.parent.parent / \
                "stock_1d" / "d033" / "alpha158_beta" / "data" / "robust_zscore_params"
        else:
            base_dir = Path(params_dir)

        version_dir = base_dir / version
        if not version_dir.exists():
            raise FileNotFoundError(
                f"Version directory not found: {version_dir}\n"
                f"Available versions should be in: {base_dir}"
            )

        # Load mean_train.npy
        mean_path = version_dir / "mean_train.npy"
        if not mean_path.exists():
            raise FileNotFoundError(f"mean_train.npy not found: {mean_path}")
        mean_train = np.load(mean_path)

        # Load std_train.npy
        std_path = version_dir / "std_train.npy"
        if not std_path.exists():
            raise FileNotFoundError(f"std_train.npy not found: {std_path}")
        std_train = np.load(std_path)

        # Guard against mismatched parameter files.
        if mean_train.shape != std_train.shape:
            raise ValueError(
                f"mean_train shape {mean_train.shape} does not match "
                f"std_train shape {std_train.shape}"
            )

        # Build feature columns from metadata when not provided explicitly.
        if feature_cols is None:
            metadata_path = version_dir / "metadata.json"
            if not metadata_path.exists():
                # Bug fix: previously a missing metadata.json silently left
                # feature_cols as None, producing an instance that crashed
                # later in process(). Fail fast instead.
                raise FileNotFoundError(
                    f"metadata.json not found: {metadata_path}; "
                    f"pass feature_cols explicitly"
                )
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            groups = metadata.get('feature_columns', {})
            alpha158_ntrl = [f"{c}_ntrl" for c in groups.get('alpha158_ntrl', [])]
            alpha158_raw = groups.get('alpha158_raw', [])
            market_ext_ntrl = [f"{c}_ntrl" for c in groups.get('market_ext_ntrl', [])]
            market_ext_raw = groups.get('market_ext_raw', [])

            # Column order must match how the parameters were fitted.
            feature_cols = alpha158_ntrl + alpha158_raw + market_ext_ntrl + market_ext_raw

        # Validate feature column count matches parameter shape
        expected_count = len(mean_train)
        if feature_cols and len(feature_cols) != expected_count:
            raise ValueError(
                f"Feature column count ({len(feature_cols)}) does not match "
                f"parameter shape ({expected_count})"
            )

        print(f"Loaded RobustZScoreNorm parameters from version '{version}':")
        print(f"  mean_train shape: {mean_train.shape}")
        print(f"  std_train shape: {std_train.shape}")
        print(f"  feature_cols: {len(feature_cols) if feature_cols else 'not specified'}")

        return cls(
            feature_cols=feature_cols,
            clip_range=clip_range,
            use_qlib_params=True,
            qlib_mean=mean_train,
            qlib_std=std_train
        )

    def __post_init__(self):
        """Validate and normalize pre-fitted parameters after initialization."""
        if self.use_qlib_params:
            if self.qlib_mean is None or self.qlib_std is None:
                raise ValueError(
                    "Must provide qlib_mean and qlib_std when use_qlib_params=True"
                )
            # Normalize to numpy arrays so process() can index them uniformly.
            self.qlib_mean = np.asarray(self.qlib_mean)
            self.qlib_std = np.asarray(self.qlib_std)

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply robust z-score normalization.

        Args:
            df: Input DataFrame with feature columns

        Returns:
            DataFrame with normalized columns (in-place modification)
        """
        if self.use_qlib_params:
            # Pre-fitted mode. Bug fix: parameters must be indexed by the
            # column's position in self.feature_cols, NOT by its position
            # among the columns present in df — otherwise one missing column
            # silently shifts the mean/std used for every later column.
            exprs = []
            for i, col in enumerate(self.feature_cols):
                if i >= len(self.qlib_mean):
                    break
                if col not in df.columns:
                    continue
                mean_val = float(self.qlib_mean[i])
                std_val = float(self.qlib_std[i])
                # Apply z-score normalization using pre-fitted params; the
                # epsilon guards against division by a zero std.
                exprs.append(
                    ((pl.col(col) - mean_val) / (std_val + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                )
            if exprs:
                df = df.with_columns(exprs)
        else:
            # Compute per-datetime robust z-score.
            existing_cols = [c for c in self.feature_cols if c in df.columns]
            for col in existing_cols:
                # Compute median per datetime
                median_col = f"__median_{col}"
                df = df.with_columns([
                    pl.col(col).median().over('datetime').alias(median_col)
                ])

                # Compute absolute deviation
                abs_dev_col = f"__absdev_{col}"
                df = df.with_columns([
                    (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col)
                ])

                # Compute MAD (median of absolute deviations)
                mad_col = f"__mad_{col}"
                df = df.with_columns([
                    pl.col(abs_dev_col).median().over('datetime').alias(mad_col)
                ])

                # Compute robust z-score and clip; 1.4826 scales MAD to be a
                # consistent estimator of the standard deviation.
                df = df.with_columns([
                    ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                ])

                # Clean up temporary columns
                df = df.drop([median_col, abs_dev_col, mad_col])

        return df
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Data Cleaning Processors
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class Fillna:
    """
    Fill NaN/null values in numeric columns with a constant.

    Only numeric columns (Float32/64, Int32/64, UInt8/16/32/64) are touched;
    columns of any other dtype are left unchanged.

    Attributes:
        fill_value: Replacement value for NaN/null entries (default: 0.0).

    Example:
        >>> processor = Fillna(fill_value=0.0)
        >>> df = processor.process(df, ['KMID', 'KLEN'])
    """

    def __init__(self, fill_value: float = 0.0):
        self.fill_value = fill_value

    def process(self, df: pl.DataFrame, columns: List[str]) -> pl.DataFrame:
        """
        Fill NaN/null values in the given columns.

        Args:
            df: Input DataFrame.
            columns: Candidate columns to fill.

        Returns:
            DataFrame with NaN/null replaced by fill_value in the numeric
            columns that exist.
        """
        numeric_dtypes = (pl.Float32, pl.Float64, pl.Int32, pl.Int64,
                          pl.UInt32, pl.UInt64, pl.UInt16, pl.UInt8)

        for name in columns:
            if name not in df.columns:
                continue
            # Skip non-numeric columns entirely.
            if df[name].dtype in numeric_dtypes:
                df = df.with_columns([
                    pl.col(name).fill_null(self.fill_value).fill_nan(self.fill_value)
                ])

        return df

    def __repr__(self) -> str:
        return f"Fillna(fill_value={self.fill_value})"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Processor Pipeline Utilities
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def create_processor_pipeline(
    alpha158_cols: List[str],
    market_ext_base: List[str],
    market_flag_cols: List[str],
    industry_flag_cols: List[str],
    columns_to_remove: Optional[List[str]] = None,
    ntrl_suffix: str = '_ntrl'
) -> List:
    """
    Create a complete processor pipeline configuration.

    This factory function creates processors in the correct order:
    1. DiffProcessor - adds diff features
    2. FlagMarketInjector - adds market_0, market_1
    3. FlagSTInjector - adds IsST
    4. ColumnRemover - removes specified columns
    5. FlagToOnehot - converts industry flags to index
    6. IndusNtrlInjector (x2) - neutralizes alpha158 and market_ext
    7. RobustZScoreNorm - normalizes features
    8. Fillna - fills NaN values

    Args:
        alpha158_cols: List of alpha158 feature names
        market_ext_base: List of market extension base columns
        market_flag_cols: List of market flag columns
        industry_flag_cols: List of industry flag columns
        columns_to_remove: Columns to remove (default: ['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        ntrl_suffix: Suffix for neutralized columns

    Returns:
        List of processor instances in execution order
    """
    if columns_to_remove is None:
        columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']

    pipeline: List = []
    # 1. Derived diff features for the market extension columns.
    pipeline.append(DiffProcessor(market_ext_base))
    # 2-3. Flag injection: market segment, then ST status.
    pipeline.append(FlagMarketInjector())
    pipeline.append(FlagSTInjector())
    # 4. Drop columns that must not reach normalization.
    pipeline.append(ColumnRemover(columns_to_remove))
    # 5. Collapse industry one-hot flags into a single index column.
    pipeline.append(FlagToOnehot(industry_flag_cols))
    # 6. Industry-neutralize both feature groups.
    pipeline.append(IndusNtrlInjector(alpha158_cols, suffix=ntrl_suffix))
    pipeline.append(IndusNtrlInjector(market_ext_base, suffix=ntrl_suffix))
    # RobustZScoreNorm and Fillna require fitted parameters and are added
    # separately by the caller.
    return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def get_final_feature_columns(
    alpha158_cols: List[str],
    market_ext_base: List[str],
    market_flag_cols: List[str],
    columns_to_remove: Optional[List[str]] = None,
    ntrl_suffix: str = '_ntrl'
) -> Dict[str, List[str]]:
    """
    Get the final feature column structure after processing.

    Useful for determining the expected VAE input dimensions and for
    verifying feature order.

    Args:
        alpha158_cols: List of alpha158 feature names
        market_ext_base: List of market extension base columns (before Diff)
        market_flag_cols: List of market flag columns (before ColumnRemover)
        columns_to_remove: Columns removed by ColumnRemover
            (default: ['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        ntrl_suffix: Suffix for neutralized columns

    Returns:
        Dictionary with feature groups:
        - 'norm_feature_cols': Features to normalize (in order)
        - 'market_flag_cols': Market flag columns after processing
        - 'all_feature_cols': All feature columns including indus_idx
    """
    drop_list = (['log_size_diff', 'IsN', 'IsZt', 'IsDt']
                 if columns_to_remove is None else columns_to_remove)
    dropped = set(drop_list)

    # DiffProcessor doubles market_ext (base + "_diff" derivatives), then
    # ColumnRemover prunes the configured columns from both groups.
    diff_cols = [f"{name}_diff" for name in market_ext_base]
    market_ext_final = [name for name in market_ext_base + diff_cols
                        if name not in dropped]
    market_flag_final = [name for name in market_flag_cols if name not in dropped]

    # FlagMarketInjector contributes market_0/market_1; FlagSTInjector IsST.
    market_flag_final += ['market_0', 'market_1', 'IsST']

    # Normalization order mirrors Qlib: ntrl then raw, for each group.
    alpha158_ntrl = [f"{name}{ntrl_suffix}" for name in alpha158_cols]
    market_ext_ntrl = [f"{name}{ntrl_suffix}" for name in market_ext_final]
    norm_feature_cols = (alpha158_ntrl + alpha158_cols +
                         market_ext_ntrl + market_ext_final)

    return {
        'alpha158_cols': alpha158_cols,
        'alpha158_ntrl_cols': alpha158_ntrl,
        'market_ext_cols': market_ext_final,
        'market_ext_ntrl_cols': market_ext_ntrl,
        'market_flag_cols': market_flag_final,
        'norm_feature_cols': norm_feature_cols,
        'all_feature_cols': norm_feature_cols + market_flag_final + ['indus_idx'],
    }
|
||||||
@ -0,0 +1,366 @@
|
|||||||
|
{
|
||||||
|
"version": "csiallx_feature2_ntrla_flag_pnlnorm",
|
||||||
|
"created_at": "2026-03-01T12:18:01.969109",
|
||||||
|
"source_file": "2013-01-01",
|
||||||
|
"fit_start_time": "2013-01-01",
|
||||||
|
"fit_end_time": "2018-12-31",
|
||||||
|
"fields_group": [
|
||||||
|
"feature",
|
||||||
|
"feature_ext"
|
||||||
|
],
|
||||||
|
"feature_columns": {
|
||||||
|
"alpha158_ntrl": [
|
||||||
|
"KMID",
|
||||||
|
"KLEN",
|
||||||
|
"KMID2",
|
||||||
|
"KUP",
|
||||||
|
"KUP2",
|
||||||
|
"KLOW",
|
||||||
|
"KLOW2",
|
||||||
|
"KSFT",
|
||||||
|
"KSFT2",
|
||||||
|
"OPEN0",
|
||||||
|
"HIGH0",
|
||||||
|
"LOW0",
|
||||||
|
"VWAP0",
|
||||||
|
"ROC5",
|
||||||
|
"ROC10",
|
||||||
|
"ROC20",
|
||||||
|
"ROC30",
|
||||||
|
"ROC60",
|
||||||
|
"MA5",
|
||||||
|
"MA10",
|
||||||
|
"MA20",
|
||||||
|
"MA30",
|
||||||
|
"MA60",
|
||||||
|
"STD5",
|
||||||
|
"STD10",
|
||||||
|
"STD20",
|
||||||
|
"STD30",
|
||||||
|
"STD60",
|
||||||
|
"BETA5",
|
||||||
|
"BETA10",
|
||||||
|
"BETA20",
|
||||||
|
"BETA30",
|
||||||
|
"BETA60",
|
||||||
|
"RSQR5",
|
||||||
|
"RSQR10",
|
||||||
|
"RSQR20",
|
||||||
|
"RSQR30",
|
||||||
|
"RSQR60",
|
||||||
|
"RESI5",
|
||||||
|
"RESI10",
|
||||||
|
"RESI20",
|
||||||
|
"RESI30",
|
||||||
|
"RESI60",
|
||||||
|
"MAX5",
|
||||||
|
"MAX10",
|
||||||
|
"MAX20",
|
||||||
|
"MAX30",
|
||||||
|
"MAX60",
|
||||||
|
"MIN5",
|
||||||
|
"MIN10",
|
||||||
|
"MIN20",
|
||||||
|
"MIN30",
|
||||||
|
"MIN60",
|
||||||
|
"QTLU5",
|
||||||
|
"QTLU10",
|
||||||
|
"QTLU20",
|
||||||
|
"QTLU30",
|
||||||
|
"QTLU60",
|
||||||
|
"QTLD5",
|
||||||
|
"QTLD10",
|
||||||
|
"QTLD20",
|
||||||
|
"QTLD30",
|
||||||
|
"QTLD60",
|
||||||
|
"RANK5",
|
||||||
|
"RANK10",
|
||||||
|
"RANK20",
|
||||||
|
"RANK30",
|
||||||
|
"RANK60",
|
||||||
|
"RSV5",
|
||||||
|
"RSV10",
|
||||||
|
"RSV20",
|
||||||
|
"RSV30",
|
||||||
|
"RSV60",
|
||||||
|
"IMAX5",
|
||||||
|
"IMAX10",
|
||||||
|
"IMAX20",
|
||||||
|
"IMAX30",
|
||||||
|
"IMAX60",
|
||||||
|
"IMIN5",
|
||||||
|
"IMIN10",
|
||||||
|
"IMIN20",
|
||||||
|
"IMIN30",
|
||||||
|
"IMIN60",
|
||||||
|
"IMXD5",
|
||||||
|
"IMXD10",
|
||||||
|
"IMXD20",
|
||||||
|
"IMXD30",
|
||||||
|
"IMXD60",
|
||||||
|
"CORR5",
|
||||||
|
"CORR10",
|
||||||
|
"CORR20",
|
||||||
|
"CORR30",
|
||||||
|
"CORR60",
|
||||||
|
"CORD5",
|
||||||
|
"CORD10",
|
||||||
|
"CORD20",
|
||||||
|
"CORD30",
|
||||||
|
"CORD60",
|
||||||
|
"CNTP5",
|
||||||
|
"CNTP10",
|
||||||
|
"CNTP20",
|
||||||
|
"CNTP30",
|
||||||
|
"CNTP60",
|
||||||
|
"CNTN5",
|
||||||
|
"CNTN10",
|
||||||
|
"CNTN20",
|
||||||
|
"CNTN30",
|
||||||
|
"CNTN60",
|
||||||
|
"CNTD5",
|
||||||
|
"CNTD10",
|
||||||
|
"CNTD20",
|
||||||
|
"CNTD30",
|
||||||
|
"CNTD60",
|
||||||
|
"SUMP5",
|
||||||
|
"SUMP10",
|
||||||
|
"SUMP20",
|
||||||
|
"SUMP30",
|
||||||
|
"SUMP60",
|
||||||
|
"SUMN5",
|
||||||
|
"SUMN10",
|
||||||
|
"SUMN20",
|
||||||
|
"SUMN30",
|
||||||
|
"SUMN60",
|
||||||
|
"SUMD5",
|
||||||
|
"SUMD10",
|
||||||
|
"SUMD20",
|
||||||
|
"SUMD30",
|
||||||
|
"SUMD60",
|
||||||
|
"VMA5",
|
||||||
|
"VMA10",
|
||||||
|
"VMA20",
|
||||||
|
"VMA30",
|
||||||
|
"VMA60",
|
||||||
|
"VSTD5",
|
||||||
|
"VSTD10",
|
||||||
|
"VSTD20",
|
||||||
|
"VSTD30",
|
||||||
|
"VSTD60",
|
||||||
|
"WVMA5",
|
||||||
|
"WVMA10",
|
||||||
|
"WVMA20",
|
||||||
|
"WVMA30",
|
||||||
|
"WVMA60",
|
||||||
|
"VSUMP5",
|
||||||
|
"VSUMP10",
|
||||||
|
"VSUMP20",
|
||||||
|
"VSUMP30",
|
||||||
|
"VSUMP60",
|
||||||
|
"VSUMN5",
|
||||||
|
"VSUMN10",
|
||||||
|
"VSUMN20",
|
||||||
|
"VSUMN30",
|
||||||
|
"VSUMN60",
|
||||||
|
"VSUMD5",
|
||||||
|
"VSUMD10",
|
||||||
|
"VSUMD20",
|
||||||
|
"VSUMD30",
|
||||||
|
"VSUMD60"
|
||||||
|
],
|
||||||
|
"alpha158_raw": [
|
||||||
|
"KMID",
|
||||||
|
"KLEN",
|
||||||
|
"KMID2",
|
||||||
|
"KUP",
|
||||||
|
"KUP2",
|
||||||
|
"KLOW",
|
||||||
|
"KLOW2",
|
||||||
|
"KSFT",
|
||||||
|
"KSFT2",
|
||||||
|
"OPEN0",
|
||||||
|
"HIGH0",
|
||||||
|
"LOW0",
|
||||||
|
"VWAP0",
|
||||||
|
"ROC5",
|
||||||
|
"ROC10",
|
||||||
|
"ROC20",
|
||||||
|
"ROC30",
|
||||||
|
"ROC60",
|
||||||
|
"MA5",
|
||||||
|
"MA10",
|
||||||
|
"MA20",
|
||||||
|
"MA30",
|
||||||
|
"MA60",
|
||||||
|
"STD5",
|
||||||
|
"STD10",
|
||||||
|
"STD20",
|
||||||
|
"STD30",
|
||||||
|
"STD60",
|
||||||
|
"BETA5",
|
||||||
|
"BETA10",
|
||||||
|
"BETA20",
|
||||||
|
"BETA30",
|
||||||
|
"BETA60",
|
||||||
|
"RSQR5",
|
||||||
|
"RSQR10",
|
||||||
|
"RSQR20",
|
||||||
|
"RSQR30",
|
||||||
|
"RSQR60",
|
||||||
|
"RESI5",
|
||||||
|
"RESI10",
|
||||||
|
"RESI20",
|
||||||
|
"RESI30",
|
||||||
|
"RESI60",
|
||||||
|
"MAX5",
|
||||||
|
"MAX10",
|
||||||
|
"MAX20",
|
||||||
|
"MAX30",
|
||||||
|
"MAX60",
|
||||||
|
"MIN5",
|
||||||
|
"MIN10",
|
||||||
|
"MIN20",
|
||||||
|
"MIN30",
|
||||||
|
"MIN60",
|
||||||
|
"QTLU5",
|
||||||
|
"QTLU10",
|
||||||
|
"QTLU20",
|
||||||
|
"QTLU30",
|
||||||
|
"QTLU60",
|
||||||
|
"QTLD5",
|
||||||
|
"QTLD10",
|
||||||
|
"QTLD20",
|
||||||
|
"QTLD30",
|
||||||
|
"QTLD60",
|
||||||
|
"RANK5",
|
||||||
|
"RANK10",
|
||||||
|
"RANK20",
|
||||||
|
"RANK30",
|
||||||
|
"RANK60",
|
||||||
|
"RSV5",
|
||||||
|
"RSV10",
|
||||||
|
"RSV20",
|
||||||
|
"RSV30",
|
||||||
|
"RSV60",
|
||||||
|
"IMAX5",
|
||||||
|
"IMAX10",
|
||||||
|
"IMAX20",
|
||||||
|
"IMAX30",
|
||||||
|
"IMAX60",
|
||||||
|
"IMIN5",
|
||||||
|
"IMIN10",
|
||||||
|
"IMIN20",
|
||||||
|
"IMIN30",
|
||||||
|
"IMIN60",
|
||||||
|
"IMXD5",
|
||||||
|
"IMXD10",
|
||||||
|
"IMXD20",
|
||||||
|
"IMXD30",
|
||||||
|
"IMXD60",
|
||||||
|
"CORR5",
|
||||||
|
"CORR10",
|
||||||
|
"CORR20",
|
||||||
|
"CORR30",
|
||||||
|
"CORR60",
|
||||||
|
"CORD5",
|
||||||
|
"CORD10",
|
||||||
|
"CORD20",
|
||||||
|
"CORD30",
|
||||||
|
"CORD60",
|
||||||
|
"CNTP5",
|
||||||
|
"CNTP10",
|
||||||
|
"CNTP20",
|
||||||
|
"CNTP30",
|
||||||
|
"CNTP60",
|
||||||
|
"CNTN5",
|
||||||
|
"CNTN10",
|
||||||
|
"CNTN20",
|
||||||
|
"CNTN30",
|
||||||
|
"CNTN60",
|
||||||
|
"CNTD5",
|
||||||
|
"CNTD10",
|
||||||
|
"CNTD20",
|
||||||
|
"CNTD30",
|
||||||
|
"CNTD60",
|
||||||
|
"SUMP5",
|
||||||
|
"SUMP10",
|
||||||
|
"SUMP20",
|
||||||
|
"SUMP30",
|
||||||
|
"SUMP60",
|
||||||
|
"SUMN5",
|
||||||
|
"SUMN10",
|
||||||
|
"SUMN20",
|
||||||
|
"SUMN30",
|
||||||
|
"SUMN60",
|
||||||
|
"SUMD5",
|
||||||
|
"SUMD10",
|
||||||
|
"SUMD20",
|
||||||
|
"SUMD30",
|
||||||
|
"SUMD60",
|
||||||
|
"VMA5",
|
||||||
|
"VMA10",
|
||||||
|
"VMA20",
|
||||||
|
"VMA30",
|
||||||
|
"VMA60",
|
||||||
|
"VSTD5",
|
||||||
|
"VSTD10",
|
||||||
|
"VSTD20",
|
||||||
|
"VSTD30",
|
||||||
|
"VSTD60",
|
||||||
|
"WVMA5",
|
||||||
|
"WVMA10",
|
||||||
|
"WVMA20",
|
||||||
|
"WVMA30",
|
||||||
|
"WVMA60",
|
||||||
|
"VSUMP5",
|
||||||
|
"VSUMP10",
|
||||||
|
"VSUMP20",
|
||||||
|
"VSUMP30",
|
||||||
|
"VSUMP60",
|
||||||
|
"VSUMN5",
|
||||||
|
"VSUMN10",
|
||||||
|
"VSUMN20",
|
||||||
|
"VSUMN30",
|
||||||
|
"VSUMN60",
|
||||||
|
"VSUMD5",
|
||||||
|
"VSUMD10",
|
||||||
|
"VSUMD20",
|
||||||
|
"VSUMD30",
|
||||||
|
"VSUMD60"
|
||||||
|
],
|
||||||
|
"market_ext_ntrl": [
|
||||||
|
"turnover",
|
||||||
|
"free_turnover",
|
||||||
|
"log_size",
|
||||||
|
"con_rating_strength",
|
||||||
|
"turnover_diff",
|
||||||
|
"free_turnover_diff",
|
||||||
|
"con_rating_strength_diff"
|
||||||
|
],
|
||||||
|
"market_ext_raw": [
|
||||||
|
"turnover",
|
||||||
|
"free_turnover",
|
||||||
|
"log_size",
|
||||||
|
"con_rating_strength",
|
||||||
|
"turnover_diff",
|
||||||
|
"free_turnover_diff",
|
||||||
|
"con_rating_strength_diff"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"feature_count": {
|
||||||
|
"alpha158_ntrl": 158,
|
||||||
|
"alpha158_raw": 158,
|
||||||
|
"market_ext_ntrl": 7,
|
||||||
|
"market_ext_raw": 7,
|
||||||
|
"total": 330
|
||||||
|
},
|
||||||
|
"parameter_shapes": {
|
||||||
|
"mean_train": [
|
||||||
|
330
|
||||||
|
],
|
||||||
|
"std_train": [
|
||||||
|
330
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,305 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file.
|
||||||
|
|
||||||
|
This script extracts the pre-fitted mean_train and std_train parameters from
|
||||||
|
Qlib's RobustZScoreNorm processor and saves them as reusable .npy files.
|
||||||
|
|
||||||
|
The extracted parameters can be used by the Polars RobustZScoreNorm processor
|
||||||
|
via the from_version() class method.
|
||||||
|
|
||||||
|
Output structure:
|
||||||
|
stock_1d/d033/alpha158_beta/data/robust_zscore_params/
|
||||||
|
└── {version}/
|
||||||
|
├── mean_train.npy
|
||||||
|
├── std_train.npy
|
||||||
|
└── metadata.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import pickle as pkl
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Default paths
|
||||||
|
DEFAULT_PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc"
|
||||||
|
DEFAULT_VERSION = "csiallx_feature2_ntrla_flag_pnlnorm"
|
||||||
|
|
||||||
|
# Alpha158 columns in order (158 features)
|
||||||
|
ALPHA158_COLS = [
|
||||||
|
'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
|
||||||
|
'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
|
||||||
|
'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
|
||||||
|
'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
|
||||||
|
'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
|
||||||
|
'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
|
||||||
|
'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
|
||||||
|
'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
|
||||||
|
'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
|
||||||
|
'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
|
||||||
|
'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
|
||||||
|
'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
|
||||||
|
'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
|
||||||
|
'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
|
||||||
|
'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
|
||||||
|
'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
|
||||||
|
'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
|
||||||
|
'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
|
||||||
|
'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
|
||||||
|
'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
|
||||||
|
'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
|
||||||
|
'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
|
||||||
|
'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
|
||||||
|
'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
|
||||||
|
'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
|
||||||
|
'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
|
||||||
|
'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
|
||||||
|
'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
|
||||||
|
'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
|
||||||
|
'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
|
||||||
|
'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
|
||||||
|
]
|
||||||
|
assert len(ALPHA158_COLS) == 158, f"Expected 158 alpha158 cols, got {len(ALPHA158_COLS)}"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_robust_zscore_params(proc_list_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proc_list_path: Path to the proc_list.proc pickle file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing:
|
||||||
|
- mean_train: numpy array of shape (330,)
|
||||||
|
- std_train: numpy array of shape (330,)
|
||||||
|
- fit_start_time: datetime string
|
||||||
|
- fit_end_time: datetime string
|
||||||
|
- fields_group: list of field groups
|
||||||
|
"""
|
||||||
|
print(f"Loading proc_list.proc from: {proc_list_path}")
|
||||||
|
|
||||||
|
with open(proc_list_path, 'rb') as f:
|
||||||
|
proc_list = pkl.load(f)
|
||||||
|
|
||||||
|
print(f"Loaded {len(proc_list)} processors from proc_list")
|
||||||
|
|
||||||
|
# Find RobustZScoreNorm processor (typically at index 7)
|
||||||
|
zscore_proc = None
|
||||||
|
for i, proc in enumerate(proc_list):
|
||||||
|
proc_name = type(proc).__name__
|
||||||
|
print(f" [{i}] {proc_name}")
|
||||||
|
if proc_name == "RobustZScoreNorm":
|
||||||
|
zscore_proc = proc
|
||||||
|
|
||||||
|
if zscore_proc is None:
|
||||||
|
raise ValueError("RobustZScoreNorm processor not found in proc_list")
|
||||||
|
|
||||||
|
# Extract parameters
|
||||||
|
params = {
|
||||||
|
'mean_train': zscore_proc.mean_train,
|
||||||
|
'std_train': zscore_proc.std_train,
|
||||||
|
'fit_start_time': getattr(zscore_proc, 'fit_start_time', None),
|
||||||
|
'fit_end_time': getattr(zscore_proc, 'fit_end_time', None),
|
||||||
|
'fields_group': getattr(zscore_proc, 'fields_group', None),
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"\nExtracted RobustZScoreNorm parameters:")
|
||||||
|
print(f" mean_train shape: {params['mean_train'].shape}, dtype: {params['mean_train'].dtype}")
|
||||||
|
print(f" std_train shape: {params['std_train'].shape}, dtype: {params['std_train'].dtype}")
|
||||||
|
print(f" fit_start_time: {params['fit_start_time']}")
|
||||||
|
print(f" fit_end_time: {params['fit_end_time']}")
|
||||||
|
print(f" fields_group: {params['fields_group']}")
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def build_feature_column_names() -> list:
|
||||||
|
"""
|
||||||
|
Build the complete list of 330 feature column names in order.
|
||||||
|
|
||||||
|
Feature order (330 total):
|
||||||
|
1. alpha158_ntrl (158 features)
|
||||||
|
2. alpha158_raw (158 features)
|
||||||
|
3. market_ext_ntrl (7 features)
|
||||||
|
4. market_ext_raw (7 features)
|
||||||
|
|
||||||
|
market_ext columns (after processing):
|
||||||
|
- Base: turnover, free_turnover, log_size, con_rating_strength
|
||||||
|
- Diff: turnover_diff, free_turnover_diff, con_rating_strength_diff
|
||||||
|
- Note: log_size_diff is removed by ColumnRemover
|
||||||
|
"""
|
||||||
|
# Alpha158 neutralized columns (158)
|
||||||
|
alpha158_ntrl = [f"{c}_ntrl" for c in ALPHA158_COLS]
|
||||||
|
|
||||||
|
# Alpha158 raw columns (158)
|
||||||
|
alpha158_raw = ALPHA158_COLS.copy()
|
||||||
|
|
||||||
|
# market_ext columns (7 after ColumnRemover)
|
||||||
|
# After Diff: 4 base + 4 diff = 8
|
||||||
|
# After ColumnRemover (removes log_size_diff): 7 remain
|
||||||
|
market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
|
||||||
|
market_ext_diff = ['turnover_diff', 'free_turnover_diff', 'log_size_diff', 'con_rating_strength_diff']
|
||||||
|
market_ext_all = market_ext_base + market_ext_diff
|
||||||
|
market_ext_final = [c for c in market_ext_all if c != 'log_size_diff']
|
||||||
|
|
||||||
|
# market_ext neutralized columns (7)
|
||||||
|
market_ext_ntrl = [f"{c}_ntrl" for c in market_ext_final]
|
||||||
|
|
||||||
|
# market_ext raw columns (7)
|
||||||
|
market_ext_raw = market_ext_final.copy()
|
||||||
|
|
||||||
|
# Combine all feature columns in Qlib order
|
||||||
|
feature_cols = alpha158_ntrl + alpha158_raw + market_ext_ntrl + market_ext_raw
|
||||||
|
|
||||||
|
print(f"\nBuilt feature column names:")
|
||||||
|
print(f" alpha158_ntrl: {len(alpha158_ntrl)} features")
|
||||||
|
print(f" alpha158_raw: {len(alpha158_raw)} features")
|
||||||
|
print(f" market_ext_ntrl: {len(market_ext_ntrl)} features")
|
||||||
|
print(f" market_ext_raw: {len(market_ext_raw)} features")
|
||||||
|
print(f" Total: {len(feature_cols)} features")
|
||||||
|
|
||||||
|
return feature_cols
|
||||||
|
|
||||||
|
|
||||||
|
def save_parameters(
|
||||||
|
params: dict,
|
||||||
|
feature_cols: list,
|
||||||
|
output_dir: str,
|
||||||
|
version: str
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Save extracted parameters to output directory.
|
||||||
|
|
||||||
|
Creates:
|
||||||
|
- mean_train.npy
|
||||||
|
- std_train.npy
|
||||||
|
- metadata.json
|
||||||
|
"""
|
||||||
|
output_path = Path(output_dir) / version
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"\nSaving parameters to: {output_path}")
|
||||||
|
|
||||||
|
# Save mean_train.npy
|
||||||
|
mean_path = output_path / "mean_train.npy"
|
||||||
|
np.save(mean_path, params['mean_train'])
|
||||||
|
print(f" Saved mean_train to: {mean_path}")
|
||||||
|
|
||||||
|
# Save std_train.npy
|
||||||
|
std_path = output_path / "std_train.npy"
|
||||||
|
np.save(std_path, params['std_train'])
|
||||||
|
print(f" Saved std_train to: {std_path}")
|
||||||
|
|
||||||
|
# Build metadata
|
||||||
|
metadata = {
|
||||||
|
'version': version,
|
||||||
|
'created_at': datetime.now().isoformat(),
|
||||||
|
'source_file': str(params.get('fit_start_time', 'unknown')),
|
||||||
|
'fit_start_time': str(params['fit_start_time']),
|
||||||
|
'fit_end_time': str(params['fit_end_time']),
|
||||||
|
'fields_group': list(params['fields_group']) if params['fields_group'] else None,
|
||||||
|
'feature_columns': {
|
||||||
|
'alpha158_ntrl': ALPHA158_COLS, # Store base names, _ntrl is implied
|
||||||
|
'alpha158_raw': ALPHA158_COLS,
|
||||||
|
'market_ext_ntrl': [
|
||||||
|
'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
|
||||||
|
'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'
|
||||||
|
],
|
||||||
|
'market_ext_raw': [
|
||||||
|
'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
|
||||||
|
'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'
|
||||||
|
],
|
||||||
|
},
|
||||||
|
'feature_count': {
|
||||||
|
'alpha158_ntrl': 158,
|
||||||
|
'alpha158_raw': 158,
|
||||||
|
'market_ext_ntrl': 7,
|
||||||
|
'market_ext_raw': 7,
|
||||||
|
'total': 330
|
||||||
|
},
|
||||||
|
'parameter_shapes': {
|
||||||
|
'mean_train': list(params['mean_train'].shape),
|
||||||
|
'std_train': list(params['std_train'].shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save metadata.json
|
||||||
|
metadata_path = output_path / "metadata.json"
|
||||||
|
with open(metadata_path, 'w') as f:
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
print(f" Saved metadata to: {metadata_path}")
|
||||||
|
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Extract RobustZScoreNorm parameters from Qlib's proc_list.proc"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--proc-list',
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_PROC_LIST_PATH,
|
||||||
|
help=f"Path to proc_list.proc (default: {DEFAULT_PROC_LIST_PATH})"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--version',
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_VERSION,
|
||||||
|
help=f"Version name for output directory (default: {DEFAULT_VERSION})"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--output-dir',
|
||||||
|
type=str,
|
||||||
|
default=str(Path(__file__).parent.parent / "data" / "robust_zscore_params"),
|
||||||
|
help="Output directory for parameter files"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("Extract Qlib RobustZScoreNorm Parameters")
|
||||||
|
print("=" * 80)
|
||||||
|
print(f"Source: {args.proc_list}")
|
||||||
|
print(f"Version: {args.version}")
|
||||||
|
print(f"Output: {args.output_dir}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Step 1: Extract parameters from proc_list.proc
|
||||||
|
params = extract_robust_zscore_params(args.proc_list)
|
||||||
|
|
||||||
|
# Step 2: Build feature column names
|
||||||
|
feature_cols = build_feature_column_names()
|
||||||
|
|
||||||
|
# Step 3: Verify parameter shape matches feature count
|
||||||
|
if params['mean_train'].shape[0] != len(feature_cols):
|
||||||
|
print(f"\nWARNING: Parameter shape mismatch!")
|
||||||
|
print(f" Expected: {len(feature_cols)} features")
|
||||||
|
print(f" Got: {params['mean_train'].shape[0]} parameters")
|
||||||
|
else:
|
||||||
|
print(f"\n✓ Parameter shape matches feature count ({len(feature_cols)})")
|
||||||
|
|
||||||
|
# Step 4: Save parameters
|
||||||
|
output_path = save_parameters(params, feature_cols, args.output_dir, args.version)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Extraction complete!")
|
||||||
|
print("=" * 80)
|
||||||
|
print(f"Output files:")
|
||||||
|
print(f" {output_path}/mean_train.npy")
|
||||||
|
print(f" {output_path}/std_train.npy")
|
||||||
|
print(f" {output_path}/metadata.json")
|
||||||
|
print()
|
||||||
|
print("Usage in Polars RobustZScoreNorm:")
|
||||||
|
print(f' norm = RobustZScoreNorm.from_version("{args.version}")')
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in new issue