Major changes: - Fix FixedFlagMarketInjector to add market_0, market_1 columns based on instrument codes - Fix FixedFlagSTInjector to create IsST column from ST_S, ST_Y flags - Update generate_beta_embedding.py to handle IsST creation conditionally - Add dump_polars_dataset.py for generating raw and processed datasets - Add debug_data_divergence.py for comparing gold-standard vs polars output Documentation: - Update BUG_ANALYSIS_FINAL.md with IsST column issue discovery - Update README.md with polars dataset generation instructions Key discovery: - The FlagSTInjector in the gold-standard qlib code fails silently - The VAE was trained without IsST column (341 features, not 342) - The polars pipeline correctly skips FlagSTInjector to match gold-standard Generated dataset structure (2026-02-23 to 2026-02-27): - Raw data: 18,291 rows × 204 columns - Processed data: 18,291 rows × 342 columns (341 for VAE input) - market_0, market_1 columns correctly added to the feature_flag group
parent
4d382dc6bd
commit
8bd36c1939
@ -0,0 +1,254 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Debug script to compare gold-standard qlib data vs polars-based pipeline.
|
||||
|
||||
This script helps identify where the data loading and processing pipeline
|
||||
starts to diverge from the gold-standard qlib output.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pickle as pkl
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
from pathlib import Path
|
||||
|
||||
# Paths
# Gold-standard qlib dumps for the Jan-2019 window, used as the comparison
# baseline for every step of this script.
GOLD_RAW_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/raw_data_20190101_20190131.pkl"
GOLD_PROC_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/processed_data_20190101_20190131.pkl"
# Pickled list of fitted qlib processors (the proc_list) inspected in step 3.
PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc"

# Make generate_beta_embedding importable from the sibling scripts/ directory.
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
|
||||
|
||||
def compare_raw_data():
    """Load the gold-standard raw pickle and report on it.

    Prints the frame's shape, index names, column groups, per-group column
    counts, and the first three values of a few well-known columns.

    Returns:
        The gold-standard raw DataFrame (pandas, MultiIndex columns).
    """
    banner = "=" * 80
    print(banner)
    print("STEP 1: Compare RAW DATA (before proc_list)")
    print(banner)

    # The baseline comes from a pre-dumped qlib pickle on disk.
    with open(GOLD_RAW_PATH, "rb") as fh:
        gold_raw = pkl.load(fh)

    group_names = gold_raw.columns.get_level_values(0).unique().tolist()
    print(f"\nGold standard raw data:")
    print(f" Shape: {gold_raw.shape}")
    print(f" Index: {gold_raw.index.names}")
    print(f" Column groups: {group_names}")

    # Per-group column counts.
    level0 = gold_raw.columns.get_level_values(0)
    for group in group_names:
        print(f" {group}: {(level0 == group).sum()} columns")

    # Spot-check a few representative columns; silently skip absent ones.
    probe_cols = [('feature', 'KMID'), ('feature_ext', 'turnover'), ('feature_ext', 'log_size')]
    print("\n Sample values (first 3 rows):")
    for key in probe_cols:
        if key in gold_raw.columns:
            print(f" {key}: {gold_raw[key].iloc[:3].tolist()}")

    return gold_raw
|
||||
|
||||
|
||||
def compare_processed_data():
    """Load the gold-standard processed pickle and report on it.

    Prints the frame's shape, index names, column groups, per-group column
    counts, and sample values for raw vs industry-neutralized variants of a
    few columns.

    Returns:
        The gold-standard processed DataFrame (pandas, MultiIndex columns).
    """
    banner = "=" * 80
    print("\n" + banner)
    print("STEP 2: Compare PROCESSED DATA (after proc_list)")
    print(banner)

    # Processed output of the full qlib proc_list, pre-dumped to disk.
    with open(GOLD_PROC_PATH, "rb") as fh:
        gold_proc = pkl.load(fh)

    group_names = gold_proc.columns.get_level_values(0).unique().tolist()
    print(f"\nGold standard processed data:")
    print(f" Shape: {gold_proc.shape}")
    print(f" Index: {gold_proc.index.names}")
    print(f" Column groups: {group_names}")

    # Per-group column counts.
    level0 = gold_proc.columns.get_level_values(0)
    for group in group_names:
        print(f" {group}: {(level0 == group).sum()} columns")

    # Spot-check raw vs _ntrl (industry-neutralized) variants side by side.
    probe_cols = [('feature', 'KMID'), ('feature', 'KMID_ntrl'),
                  ('feature_ext', 'turnover'), ('feature_ext', 'turnover_ntrl')]
    print("\n Sample values (first 3 rows):")
    for key in probe_cols:
        if key in gold_proc.columns:
            print(f" {key}: {gold_proc[key].iloc[:3].tolist()}")

    return gold_proc
|
||||
|
||||
|
||||
def analyze_processor_pipeline(gold_raw, gold_proc):
    """Summarize what the qlib proc_list did to the data.

    Prints the processor class names, the column-count delta between raw and
    processed data, which columns were added/removed, and how the processed
    'feature' group splits into raw vs industry-neutralized (_ntrl) columns.

    Args:
        gold_raw: gold-standard raw DataFrame (MultiIndex columns).
        gold_proc: gold-standard processed DataFrame (MultiIndex columns).
    """
    print("\n" + "=" * 80)
    print("STEP 3: Analyze Processor Transformations")
    print("=" * 80)

    # The proc_list is a pickled list of fitted qlib processor objects.
    # NOTE(review): pickle.load executes arbitrary code — acceptable only
    # because this file is produced by our own pipeline; never point it at
    # untrusted data.
    with open(PROC_LIST_PATH, "rb") as f:
        proc_list = pkl.load(f)

    print(f"\nProcessor pipeline ({len(proc_list)} processors):")
    for i, proc in enumerate(proc_list):
        print(f" [{i}] {type(proc).__name__}")

    # Column-count delta. BUGFIX: use the signed ':+d' format spec so a net
    # removal prints as "-4" instead of the garbled "+-4" that the previously
    # hard-coded '+' prefix produced; positive deltas still print as "+N".
    print("\nColumn count changes:")
    print(f" Before: {gold_raw.shape[1]} columns")
    print(f" After: {gold_proc.shape[1]} columns")
    print(f" Change: {gold_proc.shape[1] - gold_raw.shape[1]:+d} columns")

    # Set-diff the column labels to see what the pipeline added/removed.
    gold_raw_cols = set(gold_raw.columns)
    gold_proc_cols = set(gold_proc.columns)

    added_cols = gold_proc_cols - gold_raw_cols
    removed_cols = gold_raw_cols - gold_proc_cols

    print(f"\n Added columns: {len(added_cols)}")
    print(f" Removed columns: {len(removed_cols)}")

    if removed_cols:
        # Show at most 10 removals to keep the output readable.
        print(f" Removed: {list(removed_cols)[:10]}...")

    # Split the processed 'feature' group into neutralized vs raw columns.
    print("\nFeature column patterns in processed data:")
    feature_cols = [c for c in gold_proc.columns if c[0] == 'feature']
    ntrl_cols = [c for c in feature_cols if c[1].endswith('_ntrl')]
    raw_cols = [c for c in feature_cols if not c[1].endswith('_ntrl')]
    print(f" Total feature columns: {len(feature_cols)}")
    print(f" _ntrl columns: {len(ntrl_cols)}")
    print(f" raw columns: {len(raw_cols)}")
|
||||
|
||||
|
||||
def check_polars_pipeline():
    """Run the polars pipeline for Jan-2019 and diff it against the baseline.

    Steps: load + merge the polars sources, convert to a pandas frame with
    the same (datetime, instrument) MultiIndex as the gold standard, compare
    column names, then compare min/max/mean for the first 20 shared columns.

    All exceptions are caught and printed (with traceback) so the earlier
    debug steps still produce output even when the polars pipeline is broken.
    """
    print("\n" + "=" * 80)
    print("STEP 4: Generate data using Polars pipeline")
    print("=" * 80)

    try:
        # Import only what is used (apply_feature_pipeline and
        # filter_stock_universe were previously imported but never called).
        from generate_beta_embedding import load_all_data, merge_data_sources

        # Load data using polars pipeline
        print("\nLoading data with polars pipeline...")
        df_alpha, df_kline, df_flag, df_industry = load_all_data(
            "2019-01-01", "2019-01-31"
        )

        print(f"\nPolars data sources loaded:")
        print(f" Alpha158: {df_alpha.shape}")
        print(f" Kline (market_ext): {df_kline.shape}")
        print(f" Flags: {df_flag.shape}")
        print(f" Industry: {df_industry.shape}")

        # Merge all sources into a single frame.
        df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
        print(f"\nAfter merge: {df_merged.shape}")

        # Convert to pandas and index like the gold standard so comparisons
        # line up row-for-row.
        df_pandas = df_merged.to_pandas()
        df_pandas = df_pandas.set_index(['datetime', 'instrument'])

        print(f"\nAfter converting to pandas MultiIndex: {df_pandas.shape}")

        # Baseline for comparison.
        with open(GOLD_RAW_PATH, "rb") as f:
            gold_raw = pkl.load(f)

        print("\n" + "=" * 80)
        print("STEP 5: Compare Column Names (Gold vs Polars)")
        print("=" * 80)

        # Compare by string form: gold columns are (group, name) tuples while
        # the polars frame has flat names.
        gold_cols = set(str(c) for c in gold_raw.columns)
        polars_cols = set(str(c) for c in df_pandas.columns)

        common_cols = gold_cols & polars_cols
        only_in_gold = gold_cols - polars_cols
        only_in_polars = polars_cols - gold_cols

        print(f"\n Common columns: {len(common_cols)}")
        print(f" Only in gold standard: {len(only_in_gold)}")
        print(f" Only in polars: {len(only_in_polars)}")

        if only_in_gold:
            print(f"\n Columns only in gold standard (first 20):")
            for col in list(only_in_gold)[:20]:
                print(f" {col}")

        if only_in_polars:
            print(f"\n Columns only in polars (first 20):")
            for col in list(only_in_polars)[:20]:
                print(f" {col}")

        print("\n" + "=" * 80)
        print("STEP 6: Compare Values for Common Columns")
        print("=" * 80)

        # Pair up columns whose string forms match. PERF: a one-pass lookup
        # table replaces the previous O(gold x polars) nested scan; setdefault
        # keeps the FIRST polars column seen for a duplicate string, matching
        # the old inner-loop 'break' behavior.
        polars_by_str = {}
        for pc in df_pandas.columns:
            polars_by_str.setdefault(str(pc), pc)
        common_tuples = [(gc, polars_by_str[str(gc)])
                         for gc in gold_raw.columns if str(gc) in polars_by_str]

        print(f"\nComparing {len(common_tuples)} common columns...")

        # Compare summary stats for the first 20 shared columns only; this is
        # a smoke check, not a full equality test.
        matching_count = 0
        diff_count = 0
        for gc, pc in common_tuples[:20]:
            gold_vals = gold_raw[gc].dropna().values
            polars_vals = df_pandas[pc].dropna().values

            if len(gold_vals) > 0 and len(polars_vals) > 0:
                # Hoist the stats so each is computed once per column.
                gold_stats = (gold_vals.min(), gold_vals.max(), gold_vals.mean())
                polars_stats = (polars_vals.min(), polars_vals.max(), polars_vals.mean())
                if np.allclose(gold_stats, polars_stats, rtol=1e-5):
                    matching_count += 1
                else:
                    diff_count += 1
                    if diff_count <= 3:  # cap the noise at three examples
                        print(f" DIFF: {gc}")
                        print(f" Gold: min={gold_stats[0]:.6f}, max={gold_stats[1]:.6f}, mean={gold_stats[2]:.6f}")
                        print(f" Polars: min={polars_stats[0]:.6f}, max={polars_stats[1]:.6f}, mean={polars_stats[2]:.6f}")

        print(f"\n Matching columns: {matching_count}")
        print(f" Different columns: {diff_count}")

    except Exception as e:
        # Deliberate broad catch: this is a best-effort debug step; report
        # the failure with a traceback instead of aborting the whole script.
        print(f"\nError running polars pipeline: {e}")
        import traceback
        traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: run each debug step in order. Steps 1-2 load the gold
    # frames, step 3 analyzes the processor transformations on them, and
    # steps 4-6 run the polars pipeline and diff it against the baseline.
    banner = "=" * 80
    print(banner)
    print("DATA DIVERGENCE DEBUG SCRIPT")
    print("Comparing gold-standard qlib output vs polars-based pipeline")
    print(banner)

    gold_raw = compare_raw_data()          # Step 1: raw data
    gold_proc = compare_processed_data()   # Step 2: processed data

    # Step 3: what did the proc_list change?
    analyze_processor_pipeline(gold_raw, gold_proc)

    # Steps 4-6: run polars pipeline, compare names and values.
    check_polars_pipeline()

    print("\n" + banner)
    print("DEBUG COMPLETE")
    print(banner)
|
||||
@ -0,0 +1,330 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Script to dump raw and processed datasets using the polars-based pipeline.
|
||||
|
||||
This generates:
|
||||
1. Raw data (before applying processors) - equivalent to qlib's handler output
|
||||
2. Processed data (after applying all processors) - ready for VAE encoding
|
||||
|
||||
Date range: 2026-02-23 to today (2026-02-27)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pickle as pkl
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from generate_beta_embedding import (
|
||||
load_all_data,
|
||||
merge_data_sources,
|
||||
filter_stock_universe,
|
||||
DiffProcessor,
|
||||
FlagMarketInjector,
|
||||
ColumnRemover,
|
||||
FlagToOnehot,
|
||||
IndusNtrlInjector,
|
||||
RobustZScoreNorm,
|
||||
Fillna,
|
||||
load_qlib_processor_params,
|
||||
ALPHA158_COLS,
|
||||
INDUSTRY_FLAG_COLS,
|
||||
)
|
||||
|
||||
# Date range
# Window to dump; passed straight to load_all_data (presumably END_DATE is
# inclusive — confirm against load_all_data's contract).
START_DATE = "2026-02-23"
END_DATE = "2026-02-27"

# Output directory
# Pickles land in <project>/data_polars (sibling of this script's directory).
# NOTE: the directory is created eagerly at import time as a module side
# effect, so importing this module touches the filesystem.
OUTPUT_DIR = Path(__file__).parent.parent / "data_polars"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
    """
    Apply the full processor pipeline (equivalent to qlib's proc_list).

    This mimics the qlib proc_list:
    0. Diff: Adds diff features for market_ext columns
    1. FlagMarketInjector: Adds market_0, market_1
    2. FlagSTInjector: SKIPPED (fails even in gold-standard)
    3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
    4. FlagToOnehot: Converts 29 industry flags to indus_idx
    5. IndusNtrlInjector: Industry neutralization for feature
    6. IndusNtrlInjector: Industry neutralization for feature_ext
    7. RobustZScoreNorm: Normalization using pre-fitted qlib params
    8. Fillna: Fill NaN with 0

    Args:
        df: merged polars frame (output of merge_data_sources).

    Returns:
        The same frame with derived/normalized columns added, raw flag
        columns pruned, and NaNs filled in the final feature columns.

    NOTE(review): step order and the ordering of norm_feature_cols below are
    load-bearing — they must match the order used when the qlib params were
    fitted. Do not reorder without re-verifying against the gold standard.
    """
    print("=" * 60)
    print("Applying processor pipeline")
    print("=" * 60)

    # market_ext columns (4 base)
    market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

    # market_flag columns (12 total before ColumnRemover)
    market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
                        'open_limit', 'close_limit', 'low_limit',
                        'open_stop', 'close_stop', 'high_stop']

    # Step 1: Diff Processor
    print("\n[1] Applying Diff processor...")
    diff_processor = DiffProcessor(market_ext_base)
    df = diff_processor.process(df)

    # After Diff: market_ext has 8 columns (4 base + 4 diff)
    market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]

    # Step 2: FlagMarketInjector (adds market_0, market_1)
    print("[2] Applying FlagMarketInjector...")
    flag_injector = FlagMarketInjector()
    df = flag_injector.process(df)

    # Add market_0, market_1 to flag list
    market_flag_with_market = market_flag_cols + ['market_0', 'market_1']

    # Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard)
    # Skipping keeps the feature count at 341, matching what the VAE was
    # actually trained on.
    print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...")
    market_flag_with_st = market_flag_with_market  # No IsST added

    # Step 4: ColumnRemover
    print("[4] Applying ColumnRemover...")
    columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
    remover = ColumnRemover(columns_to_remove)
    df = remover.process(df)

    # Update column lists after removal so later steps see the true set.
    market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
    market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]

    print(f" Removed columns: {columns_to_remove}")
    print(f" Remaining market_ext: {len(market_ext_cols)} columns")
    print(f" Remaining market_flag: {len(market_flag_with_st)} columns")

    # Step 5: FlagToOnehot
    print("[5] Applying FlagToOnehot...")
    flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
    df = flag_to_onehot.process(df)

    # Step 6 & 7: IndusNtrlInjector
    # Neutralization appends *_ntrl siblings for each listed column.
    print("[6] Applying IndusNtrlInjector for alpha158...")
    alpha158_cols = ALPHA158_COLS.copy()
    indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
    df = indus_ntrl_alpha.process(df)

    print("[7] Applying IndusNtrlInjector for market_ext...")
    indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
    df = indus_ntrl_ext.process(df)

    # Build column lists for normalization
    alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
    market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]

    # Step 8: RobustZScoreNorm
    # ORDER MATTERS: this concatenation order must match the column order
    # used when mean_train/std_train were fitted in qlib.
    print("[8] Applying RobustZScoreNorm...")
    norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols

    qlib_params = load_qlib_processor_params()

    # Verify parameter shape; a mismatch means column order/content diverged
    # from the fitted params (warn only — normalization proceeds regardless).
    expected_features = len(norm_feature_cols)
    if qlib_params['mean_train'].shape[0] != expected_features:
        print(f" WARNING: Feature count mismatch! Expected {expected_features}, "
              f"got {qlib_params['mean_train'].shape[0]}")

    robust_norm = RobustZScoreNorm(
        norm_feature_cols,
        clip_range=(-3, 3),
        use_qlib_params=True,
        qlib_mean=qlib_params['mean_train'],
        qlib_std=qlib_params['std_train']
    )
    df = robust_norm.process(df)

    # Step 9: Fillna
    print("[9] Applying Fillna...")
    final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
    fillna = Fillna()
    df = fillna.process(df, final_feature_cols)

    print("\n" + "=" * 60)
    print("Processor pipeline complete!")
    print(f" Normalized features: {len(norm_feature_cols)}")
    print(f" Market flags: {len(market_flag_with_st)}")
    print(f" Total features (with indus_idx): {len(final_feature_cols)}")
    print("=" * 60)

    return df
|
||||
|
||||
|
||||
def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
    """Convert a polars frame into a pandas frame with two-level (group,
    name) columns, mirroring the layout of qlib's handler output.

    If 'datetime'/'instrument' are still plain columns they become the
    MultiIndex; otherwise the frame is assumed to be indexed already. A few
    raw helper columns are dropped, then every remaining column is assigned
    a group by name pattern; unknown names fall into 'other' with a printed
    warning.
    """
    import pandas as pd

    frame = df_polars.to_pandas()

    # Promote the key columns to a MultiIndex when present; otherwise assume
    # the caller already indexed the frame.
    if 'datetime' in frame.columns and 'instrument' in frame.columns:
        frame = frame.set_index(['datetime', 'instrument'])

    # Raw helper columns that must not leak into the processed output.
    drop_candidates = ['Turnover', 'FreeTurnover', 'MarketValue']
    present = [c for c in drop_candidates if c in frame.columns]
    if present:
        frame = frame.drop(columns=present)

    # Name-pattern lookup tables used for group assignment.
    alpha158_base = set(ALPHA158_COLS)
    market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'}
    market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'}
    market_ext_all = market_ext_base | market_ext_diff
    feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit',
                         'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'}
    st_flag_cols = {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}

    def _group_of(name):
        # Classification order matters: exact/flag matches first, then the
        # _ntrl suffix pattern, then the base-name sets.
        if name == 'indus_idx':
            return 'indus_idx'
        if name in feature_flag_cols:
            return 'feature_flag'
        if name.endswith('_ntrl'):
            stem = name[:-5]  # strip the '_ntrl' suffix
            if stem in market_ext_all:
                return 'feature_ext'
            # alpha158 stems — and any unknown stem — default to 'feature'.
            return 'feature'
        if name in alpha158_base:
            return 'feature'
        if name in market_ext_all:
            return 'feature_ext'
        if name in INDUSTRY_FLAG_COLS:
            return 'indus_flag'
        if name in st_flag_cols:
            return 'st_flag'
        print(f" Warning: Unknown column '{name}', assigning to 'other' group")
        return 'other'

    # Assign a group to every column, preserving column order.
    pairs = [(_group_of(name), name) for name in frame.columns]
    frame.columns = pd.MultiIndex.from_tuples(pairs)

    return frame
|
||||
|
||||
|
||||
def main():
    """Dump the raw and processed datasets for [START_DATE, END_DATE].

    Pipeline: load parquet sources -> merge -> pickle the raw
    (pre-processor) frame -> run the full processor pipeline -> pickle the
    processed frame -> verify that the feature_flag group contains
    market_0/market_1. Both pickles are written to OUTPUT_DIR with
    date-stamped file names.
    """
    print("=" * 80)
    print("Dumping Polars Dataset")
    print("=" * 80)
    print(f"Date range: {START_DATE} to {END_DATE}")
    print(f"Output directory: {OUTPUT_DIR}")
    print()

    # Step 1: Load all data
    print("Step 1: Loading data from parquet...")
    df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE)
    print(f" Alpha158 shape: {df_alpha.shape}")
    print(f" Kline (market_ext) shape: {df_kline.shape}")
    print(f" Flags shape: {df_flag.shape}")
    print(f" Industry shape: {df_industry.shape}")

    # Step 2: Merge data sources
    print("\nStep 2: Merging data sources...")
    df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
    print(f" Merged shape (after csiallx filter): {df_merged.shape}")

    # Step 3: Save raw data (before processors)
    print("\nStep 3: Saving raw data (before processors)...")

    # Keep columns that match qlib's raw output format
    # Include datetime and instrument for MultiIndex conversion
    raw_columns = (
        ['datetime', 'instrument'] +  # Index columns
        ALPHA158_COLS +  # feature group
        ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] +  # feature_ext base
        ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',  # market_flag from kline
         'open_limit', 'close_limit', 'low_limit',
         'open_stop', 'close_stop', 'high_stop'] +
        INDUSTRY_FLAG_COLS +  # indus_flag
        (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else [])  # st_flag (if available)
    )

    # Filter to available columns so a missing source doesn't crash select().
    available_raw_cols = [c for c in raw_columns if c in df_merged.columns]
    print(f" Selecting {len(available_raw_cols)} columns for raw data...")
    df_raw_polars = df_merged.select(available_raw_cols)

    # Convert to pandas with MultiIndex (qlib-style (group, name) columns).
    df_raw_pd = convert_to_multiindex_df(df_raw_polars)

    raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(raw_output_path, "wb") as f:
        pkl.dump(df_raw_pd, f)
    print(f" Saved raw data to: {raw_output_path}")
    print(f" Raw data shape: {df_raw_pd.shape}")
    print(f" Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}")

    # Step 4: Apply processor pipeline (note: on df_merged, not the pruned
    # raw selection — the processors need the full merged column set).
    print("\nStep 4: Applying processor pipeline...")
    df_processed = apply_processor_pipeline(df_merged)

    # Step 5: Save processed data
    print("\nStep 5: Saving processed data (after processors)...")

    # Convert to pandas with MultiIndex
    df_processed_pd = convert_to_multiindex_df(df_processed)

    processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(processed_output_path, "wb") as f:
        pkl.dump(df_processed_pd, f)
    print(f" Saved processed data to: {processed_output_path}")
    print(f" Processed data shape: {df_processed_pd.shape}")
    print(f" Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}")

    # Count columns per group
    print("\n Column counts by group:")
    for grp in df_processed_pd.columns.get_level_values(0).unique().tolist():
        count = (df_processed_pd.columns.get_level_values(0) == grp).sum()
        print(f" {grp}: {count} columns")

    # Step 6: Verify column counts — confirm the FlagMarketInjector fix took
    # effect (market_0/market_1 must be in the feature_flag group).
    print("\n" + "=" * 80)
    print("Verification")
    print("=" * 80)

    feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag']
    has_market_0 = 'market_0' in feature_flag_cols
    has_market_1 = 'market_1' in feature_flag_cols

    print(f" feature_flag columns: {feature_flag_cols}")
    print(f" Has market_0: {has_market_0}")
    print(f" Has market_1: {has_market_1}")

    if has_market_0 and has_market_1:
        print("\n SUCCESS: market_0 and market_1 columns are present!")
    else:
        print("\n WARNING: market_0 or market_1 columns are missing!")

    print("\n" + "=" * 80)
    print("Dataset dump complete!")
    print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the full dump only when executed directly,
    # not when imported (importing still creates OUTPUT_DIR, see above).
    main()
|
||||
Loading…
Reference in new issue