|
|
|
|
"""
|
|
|
|
|
CTA 1-day return prediction dataset loader.
|
|
|
|
|
|
|
|
|
|
Uses the new qshare.data framework with Dataset class and processors.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from datetime import date
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import polars as pl
|
|
|
|
|
|
|
|
|
|
from qshare.data import pl_Dataset, pl_pipe, pl_clip, pl_cs_zscore
|
|
|
|
|
from qshare.data.universal import DataSpec
|
|
|
|
|
from qshare.io.ddb import get_ddb_sess, reset_index_from_ddb
|
|
|
|
|
|
|
|
|
|
from .labels import get_blend_weights
|
|
|
|
|
|
|
|
|
|
# HFFACTOR columns (defined inline - qshare.config.research not available)
|
|
|
|
|
# HFFACTOR columns (defined inline - qshare.config.research not available)
# NOTE(review): names suggest daily aggregates of 1-minute-bar intraday
# statistics (volatility, skew, up/down volume ratios, trend strength,
# price-volume correlation, flow-in ratio) — confirm against the producer
# of the 'stg_1day_tinysoft_cta_hffactor' table.
HFFACTOR_COLS = [
    'vol_1min',
    'skew_1min',
    'volp_1min',
    'volp_ratio_1min',
    'voln_ratio_1min',
    'trend_strength_1min',
    'pv_corr_1min',
    'flowin_ratio_1min',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CTA1DLoader:
    """
    CTA 1-day return prediction dataset loader.

    Loads features (alpha158, hffactor), labels, and calculates weights
    for CTA futures daily return prediction tasks.

    Example:
        >>> loader = CTA1DLoader(
        ...     return_type='o2c_twap1min',
        ...     normalization='dual',
        ...     feature_sets=['alpha158', 'hffactor']
        ... )
        >>> dataset = loader.load(dt_range=['2020-01-01', '2023-12-31'])
        >>> dataset = dataset.with_segments({
        ...     'train': ('2020-01-01', '2022-12-31'),
        ...     'test': ('2023-01-01', '2023-12-31')
        ... })
        >>> X_train, y_train, w_train = dataset.split('train').to_numpy()
    """

    # Return definition used as the raw label; _load_labels maps this to a
    # DolphinDB indicator name (unknown values are passed through as-is).
    return_type: str = 'o2c_twap1min'
    # Label normalization scheme: 'zscore', 'cs_zscore', 'rolling_20',
    # 'rolling_60', or 'dual' (weighted blend of all four variants).
    normalization: str = 'dual'
    # Feature sets to load; None -> ['alpha158', 'hffactor'] (set in __post_init__).
    feature_sets: Optional[List[str]] = None
    # Sample-weight multipliers keyed by 'positive'/'negative';
    # None -> {'positive': 1.0, 'negative': 2.0} (set in __post_init__).
    weight_factors: Optional[dict] = None
    # Blend-weight spec forwarded to get_blend_weights for 'dual' mode.
    blend_weights: str | List[float] | None = None
    # DolphinDB server address.
    ddb_host: str = '192.168.1.146'
    # Caps applied to the final normalized label.
    label_cap_upper: float = 0.5
    label_cap_lower: float = -0.5
|
|
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
if self.feature_sets is None:
|
|
|
|
|
self.feature_sets = ['alpha158', 'hffactor']
|
|
|
|
|
if self.weight_factors is None:
|
|
|
|
|
self.weight_factors = {'positive': 1.0, 'negative': 2.0}
|
|
|
|
|
|
|
|
|
|
def load(
|
|
|
|
|
self,
|
|
|
|
|
dt_range: List[str],
|
|
|
|
|
fit_range: Optional[List[str]] = None
|
|
|
|
|
) -> pl_Dataset:
|
|
|
|
|
"""
|
|
|
|
|
Load and prepare CTA 1-day training dataset.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
dt_range: Date range [start_date, end_date] for dataset
|
|
|
|
|
fit_range: Date range [start, end] for fitting normalization params.
|
|
|
|
|
If None, uses first 60% of dt_range.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
pl_Dataset with features, label, and weight columns
|
|
|
|
|
"""
|
|
|
|
|
start_date, end_date = dt_range
|
|
|
|
|
|
|
|
|
|
if fit_range is None:
|
|
|
|
|
# Default: use first 60% for fit
|
|
|
|
|
all_dates = pd.date_range(start_date, end_date)
|
|
|
|
|
split_idx = int(len(all_dates) * 0.6)
|
|
|
|
|
fit_range = [
|
|
|
|
|
all_dates[0].strftime('%Y-%m-%d'),
|
|
|
|
|
all_dates[split_idx].strftime('%Y-%m-%d')
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# Load extended history for rolling normalization
|
|
|
|
|
load_start = (pd.to_datetime(start_date) - pd.Timedelta(days=120)).strftime('%Y-%m-%d')
|
|
|
|
|
|
|
|
|
|
# Load features
|
|
|
|
|
df_features = self._load_features(load_start, end_date)
|
|
|
|
|
|
|
|
|
|
# Load and normalize labels
|
|
|
|
|
df_label = self._load_labels(load_start, end_date, fit_range)
|
|
|
|
|
|
|
|
|
|
# Combine
|
|
|
|
|
df = df_features.join(df_label, on=['datetime', 'instrument'], how='inner')
|
|
|
|
|
|
|
|
|
|
# Filter to requested date range
|
|
|
|
|
# Convert string dates to date objects for proper comparison
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
start_dt = datetime.strptime(start_date, '%Y-%m-%d').date()
|
|
|
|
|
end_dt = datetime.strptime(end_date, '%Y-%m-%d').date()
|
|
|
|
|
df = df.filter(
|
|
|
|
|
(pl.col('datetime') >= start_dt) &
|
|
|
|
|
(pl.col('datetime') <= end_dt)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Calculate weights
|
|
|
|
|
df = self._calculate_weights(df)
|
|
|
|
|
|
|
|
|
|
# Clean data
|
|
|
|
|
df = self._clean_data(df)
|
|
|
|
|
|
|
|
|
|
# Get feature columns
|
|
|
|
|
feature_cols = [c for c in df.columns
|
|
|
|
|
if any(c.startswith(prefix) for prefix in ['f_a158_', 'f_hf_'])]
|
|
|
|
|
|
|
|
|
|
return pl_Dataset(
|
|
|
|
|
data=df,
|
|
|
|
|
features=feature_cols,
|
|
|
|
|
label='label',
|
|
|
|
|
weight='weight' if self.weight_factors else None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _load_features(self, start_date: str, end_date: str) -> pl.DataFrame:
|
|
|
|
|
"""Load feature data from DolphinDB."""
|
|
|
|
|
sess = get_ddb_sess(host=self.ddb_host)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
feature_dfs = []
|
|
|
|
|
|
|
|
|
|
if 'alpha158' in self.feature_sets:
|
|
|
|
|
df_alpha = self._load_alpha158(sess, start_date, end_date)
|
|
|
|
|
feature_dfs.append(df_alpha)
|
|
|
|
|
|
|
|
|
|
if 'hffactor' in self.feature_sets:
|
|
|
|
|
df_hf = self._load_hffactor(sess, start_date, end_date)
|
|
|
|
|
feature_dfs.append(df_hf)
|
|
|
|
|
|
|
|
|
|
# Join all feature sets
|
|
|
|
|
result = feature_dfs[0]
|
|
|
|
|
for df in feature_dfs[1:]:
|
|
|
|
|
result = result.join(df, on=['datetime', 'instrument'], how='inner')
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
finally:
|
|
|
|
|
sess.close()
|
|
|
|
|
|
|
|
|
|
def _load_alpha158(self, sess, start_date: str, end_date: str) -> pl.DataFrame:
|
|
|
|
|
"""Load alpha158 features from DolphinDB."""
|
|
|
|
|
since_ddb = pd.to_datetime(start_date).strftime('%Y.%m.%d')
|
|
|
|
|
|
|
|
|
|
df = sess.run(f"""
|
|
|
|
|
select code, m_nDate, *
|
|
|
|
|
from loadTable('dfs://daily_stock_run', 'stg_1day_tinysoft_cta_alpha159_0_7_beta')
|
|
|
|
|
where m_nDate >= {since_ddb}
|
|
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
df = reset_index_from_ddb(df)
|
|
|
|
|
|
|
|
|
|
# Drop non-numeric columns
|
|
|
|
|
if 'code_init' in df.columns:
|
|
|
|
|
df = df.drop(columns=['code_init'])
|
|
|
|
|
|
|
|
|
|
# Convert to polars and add prefix
|
|
|
|
|
pl_df = pl.from_pandas(df.reset_index())
|
|
|
|
|
pl_df = pl_df.rename({
|
|
|
|
|
c: f'f_a158_{c}'
|
|
|
|
|
for c in pl_df.columns
|
|
|
|
|
if c not in ['datetime', 'instrument']
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return pl_df
|
|
|
|
|
|
|
|
|
|
def _load_hffactor(self, sess, start_date: str, end_date: str) -> pl.DataFrame:
|
|
|
|
|
"""Load hffactor features from DolphinDB."""
|
|
|
|
|
since_ddb = pd.to_datetime(start_date).strftime('%Y.%m.%d')
|
|
|
|
|
|
|
|
|
|
# Load from factor table
|
|
|
|
|
factor_list = ','.join([f"'{c}'" for c in HFFACTOR_COLS])
|
|
|
|
|
query = f"""select code, m_nDate, factor_name, value
|
|
|
|
|
from loadTable('dfs://daily_stock_run', 'stg_1day_tinysoft_cta_hffactor')
|
|
|
|
|
where m_nDate >= {since_ddb} and factor_name in [{factor_list}]"""
|
|
|
|
|
df = sess.run(query)
|
|
|
|
|
|
|
|
|
|
# Pivot to wide format
|
|
|
|
|
df = df.pivot_table(
|
|
|
|
|
index=['code', 'm_nDate'],
|
|
|
|
|
columns='factor_name',
|
|
|
|
|
values='value'
|
|
|
|
|
).reset_index()
|
|
|
|
|
|
|
|
|
|
df = reset_index_from_ddb(df)
|
|
|
|
|
|
|
|
|
|
# Convert to polars and add prefix
|
|
|
|
|
pl_df = pl.from_pandas(df.reset_index())
|
|
|
|
|
pl_df = pl_df.rename({
|
|
|
|
|
c: f'f_hf_{c}'
|
|
|
|
|
for c in pl_df.columns
|
|
|
|
|
if c not in ['datetime', 'instrument']
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return pl_df
|
|
|
|
|
|
|
|
|
|
def _load_labels(
|
|
|
|
|
self,
|
|
|
|
|
start_date: str,
|
|
|
|
|
end_date: str,
|
|
|
|
|
fit_range: List[str]
|
|
|
|
|
) -> pl.DataFrame:
|
|
|
|
|
"""Load and normalize labels."""
|
|
|
|
|
sess = get_ddb_sess(host=self.ddb_host)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# Map return type to indicator name
|
|
|
|
|
indicator_map = {
|
|
|
|
|
'o2c_twap1min': 'twap_open1m@1_twap_close1m@1',
|
|
|
|
|
'o2o_twap1min': 'twap_open1m@1_twap_open1m@2',
|
|
|
|
|
}
|
|
|
|
|
indicator = indicator_map.get(self.return_type, self.return_type)
|
|
|
|
|
|
|
|
|
|
since_ddb = pd.to_datetime(start_date).strftime('%Y.%m.%d')
|
|
|
|
|
|
|
|
|
|
# Load dominant contract mapping
|
|
|
|
|
df_contract = sess.run(f"""
|
|
|
|
|
select first(code) as code, m_nDate, code_init
|
|
|
|
|
from loadTable('dfs://daily_stock_run', 'dwm_1day_cta_dom')
|
|
|
|
|
where m_nDate >= {since_ddb} and version='vp_csmax_roll2_cummax'
|
|
|
|
|
group by m_nDate, code_init
|
|
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
# Load returns
|
|
|
|
|
df_return = sess.run(f"""
|
|
|
|
|
select code, m_nDate, value as ret
|
|
|
|
|
from loadTable('dfs://daily_stock_run', 'stg_1day_tinysoft_cta_hfvalue')
|
|
|
|
|
where indicator='{indicator}' and m_nDate >= {since_ddb}
|
|
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
# Merge with dominant contract mapping
|
|
|
|
|
df_return = pd.merge(
|
|
|
|
|
left=df_return[['code', 'm_nDate', 'ret']],
|
|
|
|
|
right=df_contract,
|
|
|
|
|
on=['code', 'm_nDate'],
|
|
|
|
|
how='inner'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Convert to index format
|
|
|
|
|
df_return['code'] = df_return['code_init'] + 'Ind'
|
|
|
|
|
df_return = df_return[['code', 'm_nDate', 'ret']]
|
|
|
|
|
df_return = reset_index_from_ddb(df_return)
|
|
|
|
|
|
|
|
|
|
# Convert to polars
|
|
|
|
|
pl_df = pl.from_pandas(df_return.reset_index())
|
|
|
|
|
|
|
|
|
|
# Apply normalization
|
|
|
|
|
pl_df = self._normalize_label(pl_df, fit_range)
|
|
|
|
|
|
|
|
|
|
return pl_df
|
|
|
|
|
|
|
|
|
|
finally:
|
|
|
|
|
sess.close()
|
|
|
|
|
|
|
|
|
|
def _normalize_label(self, pl_df: pl.DataFrame, fit_range: List[str]) -> pl.DataFrame:
|
|
|
|
|
"""Apply specified normalization to label."""
|
|
|
|
|
fit_start, fit_end = fit_range
|
|
|
|
|
|
|
|
|
|
# Convert fit_range strings to date objects for comparison
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
fit_start_date = datetime.strptime(fit_start, '%Y-%m-%d').date()
|
|
|
|
|
fit_end_date = datetime.strptime(fit_end, '%Y-%m-%d').date()
|
|
|
|
|
|
|
|
|
|
if self.normalization == 'zscore':
|
|
|
|
|
# Calculate mean/std on fit range
|
|
|
|
|
fit_data = pl_df.filter(
|
|
|
|
|
(pl.col('datetime') >= fit_start_date) &
|
|
|
|
|
(pl.col('datetime') <= fit_end_date)
|
|
|
|
|
)
|
|
|
|
|
mean = fit_data['ret'].mean()
|
|
|
|
|
std = fit_data['ret'].std()
|
|
|
|
|
|
|
|
|
|
result = pl_df.with_columns(
|
|
|
|
|
((pl.col('ret') - mean) / std).clip(
|
|
|
|
|
self.label_cap_lower, self.label_cap_upper
|
|
|
|
|
).alias('label')
|
|
|
|
|
).select(['datetime', 'instrument', 'label'])
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
elif self.normalization == 'cs_zscore':
|
|
|
|
|
# Cross-sectional z-score per datetime
|
|
|
|
|
return pl_df.with_columns(
|
|
|
|
|
((pl.col('ret') - pl.col('ret').mean().over('datetime')) /
|
|
|
|
|
pl.col('ret').std().over('datetime')).clip(
|
|
|
|
|
self.label_cap_lower, self.label_cap_upper
|
|
|
|
|
).alias('label')
|
|
|
|
|
).select(['datetime', 'instrument', 'label'])
|
|
|
|
|
|
|
|
|
|
elif self.normalization == 'rolling_20':
|
|
|
|
|
return self._apply_rolling_norm(pl_df, window=20, fit_range=fit_range)
|
|
|
|
|
|
|
|
|
|
elif self.normalization == 'rolling_60':
|
|
|
|
|
return self._apply_rolling_norm(pl_df, window=60, fit_range=fit_range)
|
|
|
|
|
|
|
|
|
|
elif self.normalization == 'dual':
|
|
|
|
|
# Create all normalization variants
|
|
|
|
|
label_zscore = self._normalize_zscore(pl_df, fit_range)
|
|
|
|
|
label_cszscore = self._normalize_cs_zscore(pl_df)
|
|
|
|
|
label_roll20 = self._normalize_rolling(pl_df, window=20, fit_range=fit_range)
|
|
|
|
|
label_roll60 = self._normalize_rolling(pl_df, window=60, fit_range=fit_range)
|
|
|
|
|
|
|
|
|
|
# Get blend weights
|
|
|
|
|
weights = get_blend_weights(self.blend_weights)
|
|
|
|
|
|
|
|
|
|
# Join and blend
|
|
|
|
|
pl_df = label_zscore.join(label_cszscore, on=['datetime', 'instrument'])
|
|
|
|
|
pl_df = pl_df.join(label_roll20, on=['datetime', 'instrument'])
|
|
|
|
|
pl_df = pl_df.join(label_roll60, on=['datetime', 'instrument'])
|
|
|
|
|
|
|
|
|
|
return pl_df.with_columns(
|
|
|
|
|
(weights[0] * pl.col('label_zscore') +
|
|
|
|
|
weights[1] * pl.col('label_cszscore') +
|
|
|
|
|
weights[2] * pl.col('label_roll20') +
|
|
|
|
|
weights[3] * pl.col('label_roll60')).clip(
|
|
|
|
|
self.label_cap_lower, self.label_cap_upper
|
|
|
|
|
).alias('label')
|
|
|
|
|
).select(['datetime', 'instrument', 'label'])
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unknown normalization: {self.normalization}")
|
|
|
|
|
|
|
|
|
|
def _normalize_zscore(self, pl_df: pl.DataFrame, fit_range: List[str]) -> pl.DataFrame:
|
|
|
|
|
"""Create z-score normalized label."""
|
|
|
|
|
fit_start, fit_end = fit_range
|
|
|
|
|
|
|
|
|
|
# Convert fit_range strings to date objects for comparison
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
fit_start_date = datetime.strptime(fit_start, '%Y-%m-%d').date()
|
|
|
|
|
fit_end_date = datetime.strptime(fit_end, '%Y-%m-%d').date()
|
|
|
|
|
|
|
|
|
|
fit_data = pl_df.filter(
|
|
|
|
|
(pl.col('datetime') >= fit_start_date) &
|
|
|
|
|
(pl.col('datetime') <= fit_end_date)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
mean = fit_data['ret'].mean()
|
|
|
|
|
std = fit_data['ret'].std()
|
|
|
|
|
|
|
|
|
|
return pl_df.with_columns(
|
|
|
|
|
((pl.col('ret') - mean) / std).alias('label_zscore')
|
|
|
|
|
).select(['datetime', 'instrument', 'label_zscore'])
|
|
|
|
|
|
|
|
|
|
def _normalize_cs_zscore(self, pl_df: pl.DataFrame) -> pl.DataFrame:
|
|
|
|
|
"""Create cross-sectional z-score normalized label."""
|
|
|
|
|
return pl_df.with_columns(
|
|
|
|
|
((pl.col('ret') - pl.col('ret').mean().over('datetime')) /
|
|
|
|
|
pl.col('ret').std().over('datetime')).alias('label_cszscore')
|
|
|
|
|
).select(['datetime', 'instrument', 'label_cszscore'])
|
|
|
|
|
|
|
|
|
|
def _normalize_rolling(
|
|
|
|
|
self,
|
|
|
|
|
pl_df: pl.DataFrame,
|
|
|
|
|
window: int,
|
|
|
|
|
fit_range: List[str]
|
|
|
|
|
) -> pl.DataFrame:
|
|
|
|
|
"""Create rolling window normalized label."""
|
|
|
|
|
# Convert to pandas for rolling calculation
|
|
|
|
|
pd_df = pl_df.to_pandas().set_index(['datetime', 'instrument'])
|
|
|
|
|
|
|
|
|
|
# Unstack to wide format
|
|
|
|
|
df_wide = pd_df['ret'].unstack('instrument')
|
|
|
|
|
|
|
|
|
|
# Calculate rolling mean and std
|
|
|
|
|
rolling_mean = df_wide.rolling(window=window, min_periods=window//2).mean()
|
|
|
|
|
rolling_std = df_wide.rolling(window=window, min_periods=window//2).std()
|
|
|
|
|
|
|
|
|
|
# Normalize
|
|
|
|
|
df_normalized = (df_wide - rolling_mean) / rolling_std
|
|
|
|
|
|
|
|
|
|
# Restack
|
|
|
|
|
rolling_label = df_normalized.stack().reset_index()
|
|
|
|
|
rolling_label.columns = ['datetime', 'instrument', f'label_roll{window}']
|
|
|
|
|
|
|
|
|
|
return pl.from_pandas(rolling_label)
|
|
|
|
|
|
|
|
|
|
def _apply_rolling_norm(
|
|
|
|
|
self,
|
|
|
|
|
pl_df: pl.DataFrame,
|
|
|
|
|
window: int,
|
|
|
|
|
fit_range: List[str]
|
|
|
|
|
) -> pl.DataFrame:
|
|
|
|
|
"""Apply rolling normalization and cap."""
|
|
|
|
|
result = self._normalize_rolling(pl_df, window, fit_range)
|
|
|
|
|
return result.with_columns(
|
|
|
|
|
pl.col(f'label_roll{window}').clip(
|
|
|
|
|
self.label_cap_lower, self.label_cap_upper
|
|
|
|
|
).alias('label')
|
|
|
|
|
).select(['datetime', 'instrument', 'label'])
|
|
|
|
|
|
|
|
|
|
def _calculate_weights(self, pl_df: pl.DataFrame) -> pl.DataFrame:
|
|
|
|
|
"""Calculate sample weights based on return magnitude."""
|
|
|
|
|
# Base weights by return magnitude tiers
|
|
|
|
|
pl_df = pl_df.with_columns(
|
|
|
|
|
pl.when(pl.col('label').abs() > 1.5).then(pl.lit(2.5))
|
|
|
|
|
.when(pl.col('label').abs() > 1.0).then(pl.lit(2.0))
|
|
|
|
|
.when(pl.col('label').abs() > 0.5).then(pl.lit(1.5))
|
|
|
|
|
.when(pl.col('label').abs() > 0.2).then(pl.lit(1.0))
|
|
|
|
|
.otherwise(0.0).alias('weight')
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Apply negative return multiplier
|
|
|
|
|
if self.weight_factors.get('negative'):
|
|
|
|
|
pl_df = pl_df.with_columns(
|
|
|
|
|
pl.when(pl.col('label') < -0.5)
|
|
|
|
|
.then(pl.col('weight') * self.weight_factors['negative'])
|
|
|
|
|
.otherwise(pl.col('weight'))
|
|
|
|
|
.alias('weight')
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Apply positive return multiplier
|
|
|
|
|
if self.weight_factors.get('positive'):
|
|
|
|
|
pl_df = pl_df.with_columns(
|
|
|
|
|
pl.when(pl.col('label') > 0.5)
|
|
|
|
|
.then(pl.col('weight') * self.weight_factors['positive'])
|
|
|
|
|
.otherwise(pl.col('weight'))
|
|
|
|
|
.alias('weight')
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return pl_df
|
|
|
|
|
|
|
|
|
|
def _clean_data(self, pl_df: pl.DataFrame) -> pl.DataFrame:
|
|
|
|
|
"""Clean data: remove inf/nan values."""
|
|
|
|
|
# Get numeric columns only
|
|
|
|
|
numeric_cols = [
|
|
|
|
|
c for c in pl_df.columns
|
|
|
|
|
if pl_df[c].dtype in [pl.Float32, pl.Float64, pl.Int32, pl.Int64]
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# Replace inf with null, then drop nulls
|
|
|
|
|
pl_df = pl_df.with_columns([
|
|
|
|
|
pl.when(pl.col(c).is_infinite()).then(None).otherwise(pl.col(c)).alias(c)
|
|
|
|
|
for c in numeric_cols
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
return pl_df.drop_nulls()
|