You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

449 lines
16 KiB

"""
CTA 1-day return prediction dataset loader.
Uses the new qshare.data framework with Dataset class and processors.
"""
from dataclasses import dataclass
from datetime import date
from typing import List, Optional
import numpy as np
import pandas as pd
import polars as pl
from qshare.data import pl_Dataset, pl_pipe, pl_clip, pl_cs_zscore
from qshare.data.universal import DataSpec
from qshare.io.ddb import get_ddb_sess, reset_index_from_ddb
from qshare.config.research.cta.features import HFFACTOR_COLS
from .labels import get_blend_weights
@dataclass
class CTA1DLoader:
    """
    CTA 1-day return prediction dataset loader.

    Loads features (alpha158, hffactor), labels, and calculates sample
    weights for CTA futures daily return prediction tasks.

    Attributes:
        return_type: Friendly name mapped to a DolphinDB indicator in
            ``_load_labels`` (e.g. 'o2c_twap1min'); unknown names are passed
            through as raw indicator names.
        normalization: Label normalization scheme: 'zscore', 'cs_zscore',
            'rolling_20', 'rolling_60', or 'dual' (weighted blend of all four).
        feature_sets: Feature groups to load; defaults to
            ['alpha158', 'hffactor'] (set in __post_init__).
        weight_factors: Multipliers for large positive/negative labels when
            building sample weights; defaults to {'positive': 1.0, 'negative': 2.0}.
        blend_weights: Resolved via ``get_blend_weights`` into the four blend
            coefficients used by the 'dual' scheme.
        ddb_host: DolphinDB server host.
        label_cap_upper: Upper clip bound for the normalized label.
        label_cap_lower: Lower clip bound for the normalized label.

    Example:
        >>> loader = CTA1DLoader(
        ...     return_type='o2c_twap1min',
        ...     normalization='dual',
        ...     feature_sets=['alpha158', 'hffactor']
        ... )
        >>> dataset = loader.load(dt_range=['2020-01-01', '2023-12-31'])
        >>> dataset = dataset.with_segments({
        ...     'train': ('2020-01-01', '2022-12-31'),
        ...     'test': ('2023-01-01', '2023-12-31')
        ... })
        >>> X_train, y_train, w_train = dataset.split('train').to_numpy()
    """
    return_type: str = 'o2c_twap1min'
    normalization: str = 'dual'
    # FIX: annotations now reflect the None defaults (mutable values must not
    # be dataclass field defaults — they would be shared across instances, so
    # real defaults are assigned per-instance in __post_init__).
    feature_sets: Optional[List[str]] = None
    weight_factors: Optional[dict] = None
    blend_weights: str | List[float] | None = None
    ddb_host: str = '192.168.1.146'
    label_cap_upper: float = 0.5
    label_cap_lower: float = -0.5

    def __post_init__(self):
        """Fill in per-instance defaults for the mutable config fields."""
        if self.feature_sets is None:
            self.feature_sets = ['alpha158', 'hffactor']
        if self.weight_factors is None:
            self.weight_factors = {'positive': 1.0, 'negative': 2.0}

    def load(
        self,
        dt_range: List[str],
        fit_range: Optional[List[str]] = None
    ) -> pl_Dataset:
        """
        Load and prepare the CTA 1-day training dataset.

        Args:
            dt_range: Date range [start_date, end_date] ('YYYY-MM-DD' strings).
            fit_range: Date range [start, end] for fitting normalization
                parameters. If None, the first 60% of dt_range is used.

        Returns:
            pl_Dataset with feature columns, 'label', and (if weighting is
            enabled) a 'weight' column.

        Raises:
            ValueError: If dt_range is empty/inverted, no supported feature
                set is configured, or the normalization scheme is unknown.
        """
        start_date, end_date = dt_range
        if fit_range is None:
            # Default: fit normalization parameters on the first 60% of dates.
            all_dates = pd.date_range(start_date, end_date)
            if len(all_dates) == 0:
                # FIX: previously an opaque IndexError on all_dates[0].
                raise ValueError(f"Empty or inverted date range: {dt_range}")
            split_idx = int(len(all_dates) * 0.6)
            fit_range = [
                all_dates[0].strftime('%Y-%m-%d'),
                all_dates[split_idx].strftime('%Y-%m-%d')
            ]
        # Load extended history so rolling normalization has warm-up data
        # before start_date.
        load_start = (
            pd.to_datetime(start_date) - pd.Timedelta(days=120)
        ).strftime('%Y-%m-%d')
        df_features = self._load_features(load_start, end_date)
        df_label = self._load_labels(load_start, end_date, fit_range)
        df = df_features.join(df_label, on=['datetime', 'instrument'], how='inner')
        # FIX: compare against real dates when the column is pl.Date —
        # comparing a Date column to raw strings is unreliable across polars
        # versions. _normalize_zscore already special-cases this dtype; the
        # original filter here did not.
        if df['datetime'].dtype == pl.Date:
            lo = pd.to_datetime(start_date).date()
            hi = pd.to_datetime(end_date).date()
        else:
            lo, hi = start_date, end_date
        df = df.filter((pl.col('datetime') >= lo) & (pl.col('datetime') <= hi))
        df = self._calculate_weights(df)
        df = self._clean_data(df)
        feature_cols = [
            c for c in df.columns if c.startswith(('f_a158_', 'f_hf_'))
        ]
        return pl_Dataset(
            data=df,
            features=feature_cols,
            label='label',
            weight='weight' if self.weight_factors else None
        )

    def _load_features(self, start_date: str, end_date: str) -> pl.DataFrame:
        """Load all configured feature sets and inner-join them on
        (datetime, instrument)."""
        sess = get_ddb_sess(host=self.ddb_host)
        try:
            feature_dfs = []
            if 'alpha158' in self.feature_sets:
                feature_dfs.append(self._load_alpha158(sess, start_date, end_date))
            if 'hffactor' in self.feature_sets:
                feature_dfs.append(self._load_hffactor(sess, start_date, end_date))
            if not feature_dfs:
                # FIX: previously an opaque IndexError on feature_dfs[0].
                raise ValueError(
                    f"No supported feature set in {self.feature_sets!r}; "
                    "expected 'alpha158' and/or 'hffactor'"
                )
            result = feature_dfs[0]
            for other in feature_dfs[1:]:
                result = result.join(other, on=['datetime', 'instrument'], how='inner')
            return result
        finally:
            sess.close()

    def _load_alpha158(self, sess, start_date: str, end_date: str) -> pl.DataFrame:
        """Load alpha158 features from DolphinDB; columns prefixed 'f_a158_'."""
        since_ddb = pd.to_datetime(start_date).strftime('%Y.%m.%d')
        until_ddb = pd.to_datetime(end_date).strftime('%Y.%m.%d')
        # FIX: the original ignored end_date and loaded the whole table tail;
        # bound the query on both sides (load() filters again afterwards).
        df = sess.run(f"""
            select code, m_nDate, *
            from loadTable('dfs://daily_stock_run', 'stg_1day_tinysoft_cta_alpha159_0_7_beta')
            where m_nDate >= {since_ddb} and m_nDate <= {until_ddb}
        """)
        df = reset_index_from_ddb(df)
        # Drop the non-numeric helper column carried over from the table.
        if 'code_init' in df.columns:
            df = df.drop(columns=['code_init'])
        pl_df = pl.from_pandas(df.reset_index())
        return pl_df.rename({
            c: f'f_a158_{c}'
            for c in pl_df.columns
            if c not in ['datetime', 'instrument']
        })

    def _load_hffactor(self, sess, start_date: str, end_date: str) -> pl.DataFrame:
        """Load hffactor features (long format, pivoted wide); columns
        prefixed 'f_hf_'."""
        since_ddb = pd.to_datetime(start_date).strftime('%Y.%m.%d')
        until_ddb = pd.to_datetime(end_date).strftime('%Y.%m.%d')
        factor_list = ','.join(f"'{c}'" for c in HFFACTOR_COLS)
        # FIX: bound the query by end_date as well (was start-only).
        df = sess.run(f"""
            select code, m_nDate, factor_name, value
            from loadTable('dfs://daily_stock_run', 'stg_1day_tinysoft_cta_hffactor')
            where m_nDate >= {since_ddb} and m_nDate <= {until_ddb}
            and factor_name in [{factor_list}]
        """)
        # Pivot long (code, date, factor_name, value) to wide format.
        df = df.pivot_table(
            index=['code', 'm_nDate'],
            columns='factor_name',
            values='value'
        ).reset_index()
        df = reset_index_from_ddb(df)
        pl_df = pl.from_pandas(df.reset_index())
        return pl_df.rename({
            c: f'f_hf_{c}'
            for c in pl_df.columns
            if c not in ['datetime', 'instrument']
        })

    def _load_labels(
        self,
        start_date: str,
        end_date: str,
        fit_range: List[str]
    ) -> pl.DataFrame:
        """Load dominant-contract returns from DolphinDB and normalize them
        into a 'label' column."""
        sess = get_ddb_sess(host=self.ddb_host)
        try:
            # Map friendly return-type names to DolphinDB indicator names;
            # unknown names pass through unchanged.
            indicator_map = {
                'o2c_twap1min': 'twap_open1m@1_twap_close1m@1',
                'o2o_twap1min': 'twap_open1m@1_twap_open1m@2',
            }
            indicator = indicator_map.get(self.return_type, self.return_type)
            since_ddb = pd.to_datetime(start_date).strftime('%Y.%m.%d')
            until_ddb = pd.to_datetime(end_date).strftime('%Y.%m.%d')
            # Dominant contract mapping (one contract per product per day).
            # FIX: both queries now also bound by end_date (was start-only).
            df_contract = sess.run(f"""
                select first(code) as code, m_nDate, code_init
                from loadTable('dfs://daily_stock_run', 'dwm_1day_cta_dom')
                where m_nDate >= {since_ddb} and m_nDate <= {until_ddb} and version='vp_csmax_roll2_cummax'
                group by m_nDate, code_init
            """)
            df_return = sess.run(f"""
                select code, m_nDate, value as ret
                from loadTable('dfs://daily_stock_run', 'stg_1day_tinysoft_cta_hfvalue')
                where indicator='{indicator}' and m_nDate >= {since_ddb} and m_nDate <= {until_ddb}
            """)
            # Keep only the dominant contract's return for each product/day.
            df_return = pd.merge(
                left=df_return[['code', 'm_nDate', 'ret']],
                right=df_contract,
                on=['code', 'm_nDate'],
                how='inner'
            )
            # Re-key rows by the product index code ('<product>Ind').
            df_return['code'] = df_return['code_init'] + 'Ind'
            df_return = df_return[['code', 'm_nDate', 'ret']]
            df_return = reset_index_from_ddb(df_return)
            pl_df = pl.from_pandas(df_return.reset_index())
            return self._normalize_label(pl_df, fit_range)
        finally:
            sess.close()

    def _normalize_label(self, pl_df: pl.DataFrame, fit_range: List[str]) -> pl.DataFrame:
        """
        Normalize the 'ret' column into a capped 'label' column.

        Dispatches on self.normalization; every scheme clips the result to
        [label_cap_lower, label_cap_upper].

        Raises:
            ValueError: If self.normalization is not a known scheme.
        """
        def _cap(df: pl.DataFrame, col: str) -> pl.DataFrame:
            # Clip a raw normalized column and expose it as 'label'.
            return df.with_columns(
                pl.col(col).clip(
                    self.label_cap_lower, self.label_cap_upper
                ).alias('label')
            ).select(['datetime', 'instrument', 'label'])

        if self.normalization == 'zscore':
            # FIX: delegate to _normalize_zscore instead of duplicating the
            # fit-range/date-dtype logic inline (the original had two copies).
            return _cap(self._normalize_zscore(pl_df, fit_range), 'label_zscore')
        if self.normalization == 'cs_zscore':
            return _cap(self._normalize_cs_zscore(pl_df), 'label_cszscore')
        if self.normalization == 'rolling_20':
            return self._apply_rolling_norm(pl_df, window=20, fit_range=fit_range)
        if self.normalization == 'rolling_60':
            return self._apply_rolling_norm(pl_df, window=60, fit_range=fit_range)
        if self.normalization == 'dual':
            # Blend all four normalization variants with configured weights.
            blended = self._normalize_zscore(pl_df, fit_range)
            for part in (
                self._normalize_cs_zscore(pl_df),
                self._normalize_rolling(pl_df, window=20, fit_range=fit_range),
                self._normalize_rolling(pl_df, window=60, fit_range=fit_range),
            ):
                blended = blended.join(part, on=['datetime', 'instrument'])
            w = get_blend_weights(self.blend_weights)
            blended = blended.with_columns(
                (w[0] * pl.col('label_zscore') +
                 w[1] * pl.col('label_cszscore') +
                 w[2] * pl.col('label_roll20') +
                 w[3] * pl.col('label_roll60')).alias('_blend')
            )
            return _cap(blended, '_blend')
        raise ValueError(f"Unknown normalization: {self.normalization}")

    def _normalize_zscore(self, pl_df: pl.DataFrame, fit_range: List[str]) -> pl.DataFrame:
        """Z-score 'ret' with mean/std fitted on fit_range → 'label_zscore'.

        Raises:
            ValueError: If the fit window is empty or has zero variance.
        """
        fit_start, fit_end = fit_range
        # Compare as ISO strings when the column is a real Date; otherwise it
        # is assumed to already be string-comparable with fit_range.
        if pl_df['datetime'].dtype == pl.Date:
            date_expr = pl.col('datetime').dt.strftime('%Y-%m-%d')
        else:
            date_expr = pl.col('datetime')
        fit_data = pl_df.filter((date_expr >= fit_start) & (date_expr <= fit_end))
        mean = fit_data['ret'].mean()
        std = fit_data['ret'].std()
        if std is None or std == 0:
            # FIX: an empty fit window (std None) or constant returns (std 0)
            # would silently null-out every label, which _clean_data then
            # drops wholesale. Fail loudly instead.
            raise ValueError(f"Degenerate z-score fit range {fit_range}: std={std!r}")
        return pl_df.with_columns(
            ((pl.col('ret') - mean) / std).alias('label_zscore')
        ).select(['datetime', 'instrument', 'label_zscore'])

    def _normalize_cs_zscore(self, pl_df: pl.DataFrame) -> pl.DataFrame:
        """Cross-sectional z-score of 'ret' per datetime → 'label_cszscore'."""
        return pl_df.with_columns(
            ((pl.col('ret') - pl.col('ret').mean().over('datetime')) /
             pl.col('ret').std().over('datetime')).alias('label_cszscore')
        ).select(['datetime', 'instrument', 'label_cszscore'])

    def _normalize_rolling(
        self,
        pl_df: pl.DataFrame,
        window: int,
        fit_range: Optional[List[str]] = None
    ) -> pl.DataFrame:
        """
        Normalize 'ret' by its per-instrument trailing rolling mean/std.

        Args:
            pl_df: Frame with datetime/instrument/ret columns.
            window: Rolling window length in rows (trading days); stats need
                at least window//2 observations, so early rows come out null.
            fit_range: Unused; kept (now optional) for signature parity with
                the other normalizers.

        Returns:
            Frame with a 'label_roll{window}' column.
        """
        # Rolling stats are computed in pandas on a wide matrix with one
        # column per instrument.
        df_wide = (
            pl_df.to_pandas()
            .set_index(['datetime', 'instrument'])['ret']
            .unstack('instrument')
        )
        roll = df_wide.rolling(window=window, min_periods=window // 2)
        normalized = (df_wide - roll.mean()) / roll.std()
        # Back to long format.
        out = normalized.stack().reset_index()
        out.columns = ['datetime', 'instrument', f'label_roll{window}']
        return pl.from_pandas(out)

    def _apply_rolling_norm(
        self,
        pl_df: pl.DataFrame,
        window: int,
        fit_range: Optional[List[str]] = None
    ) -> pl.DataFrame:
        """Rolling-normalize 'ret', then clip the result into 'label'."""
        result = self._normalize_rolling(pl_df, window, fit_range)
        return result.with_columns(
            pl.col(f'label_roll{window}').clip(
                self.label_cap_lower, self.label_cap_upper
            ).alias('label')
        ).select(['datetime', 'instrument', 'label'])

    def _calculate_weights(self, pl_df: pl.DataFrame) -> pl.DataFrame:
        """
        Attach a 'weight' column derived from label magnitude.

        Base weight is tiered on |label| (0 for |label| <= 0.2, up to 2.5 for
        |label| > 1.5); labels beyond +/-0.5 are further scaled by the
        configured positive/negative multipliers.
        """
        abs_label = pl.col('label').abs()
        pl_df = pl_df.with_columns(
            pl.when(abs_label > 1.5).then(pl.lit(2.5))
            .when(abs_label > 1.0).then(pl.lit(2.0))
            .when(abs_label > 0.5).then(pl.lit(1.5))
            .when(abs_label > 0.2).then(pl.lit(1.0))
            .otherwise(pl.lit(0.0)).alias('weight')
        )
        # A falsy multiplier (0 / None / missing) disables that side.
        neg = self.weight_factors.get('negative')
        if neg:
            pl_df = pl_df.with_columns(
                pl.when(pl.col('label') < -0.5)
                .then(pl.col('weight') * neg)
                .otherwise(pl.col('weight'))
                .alias('weight')
            )
        pos = self.weight_factors.get('positive')
        if pos:
            pl_df = pl_df.with_columns(
                pl.when(pl.col('label') > 0.5)
                .then(pl.col('weight') * pos)
                .otherwise(pl.col('weight'))
                .alias('weight')
            )
        return pl_df

    def _clean_data(self, pl_df: pl.DataFrame) -> pl.DataFrame:
        """Replace +/-inf with null in float columns, then drop all rows
        containing any null."""
        # FIX: is_infinite() is only defined for float dtypes; the original
        # also applied it to Int32/Int64 columns, which raises on recent
        # polars versions (integers can never be infinite anyway).
        float_cols = [
            c for c in pl_df.columns
            if pl_df[c].dtype in (pl.Float32, pl.Float64)
        ]
        pl_df = pl_df.with_columns([
            pl.when(pl.col(c).is_infinite())
            .then(None)
            .otherwise(pl.col(c))
            .alias(c)
            for c in float_cols
        ])
        return pl_df.drop_nulls()