You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

257 lines
7.3 KiB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stock 15m Baseline Model\n",
"\n",
"Train and evaluate a baseline XGBoost model for 15-minute return prediction.\n",
"\n",
"**Purpose**: Establish baseline performance with standard configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import polars as pl\n",
"import matplotlib.pyplot as plt\n",
"import xgboost as xgb\n",
"from sklearn.metrics import r2_score\n",
"\n",
"from qshare.data.polars.ret15m import load_dataset\n",
"from qshare.io.polars import load_from_pq\n",
"\n",
"import sys\n",
"sys.path.insert(0, '../')\n",
"from common.plotting import setup_plot_style\n",
"from common.paths import create_experiment_dir\n",
"\n",
"setup_plot_style()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Single place for data paths, date windows, label-normalisation options and model hyper-parameters.\n",
"# NOTE(review): 'save_results' is not read by any cell in this notebook - wire it up or drop it.\n",
"CONFIG = {\n",
"    'experiment_name': 'baseline_xgb',\n",
"    'save_results': True,\n",
"    'path_a158': '/data/parquet/stock_1min_alpha158',\n",
"    'path_kline': '/data/parquet/stock_1min',\n",
"    'path_kline_daily': '/data/parquet/stock_1day',\n",
"    'path_industry': '/data/parquet/industry_idx',\n",
"    'dt_range': ['2022-01-01', '2024-12-31'],\n",
"    'train_range': ['2022-01-01', '2023-12-31'],\n",
"    'test_range': ['2024-01-01', '2024-12-31'],\n",
"    'normalization_mode': 'dual',\n",
"    'positive_factor': 1.0,\n",
"    'negative_factor': 2.0,\n",
"    # Passed verbatim to xgb.XGBRegressor(**...) in the training cell.\n",
"    'model_params': {\n",
"        'objective': 'reg:squarederror',\n",
"        'eval_metric': 'rmse',\n",
"        'max_depth': 6,\n",
"        'learning_rate': 0.1,\n",
"        'n_estimators': 100,\n",
"        'subsample': 0.8,\n",
"        'colsample_bytree': 0.8,\n",
"        'random_state': 42,\n",
"    },\n",
"}\n",
"\n",
"print('Configuration:')\n",
"for key, value in CONFIG.items():\n",
"    # Nested dicts (the model params) are printed on their own line below.\n",
"    if not isinstance(value, dict):\n",
"        print(f' {key}: {value}')\n",
"print(f\"Model params: {CONFIG['model_params']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Load Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('Loading data sources...')\n",
"\n",
"# Every loader call uses the first entry of CONFIG['dt_range'] as start_time.\n",
"# The two 1-minute tables are requested with as_struct=True; the daily and\n",
"# industry tables rely on the loader's default for that argument.\n",
"pl_ldf_a158 = load_from_pq(\n",
"    path=CONFIG['path_a158'], table_alias='a158',\n",
"    start_time=CONFIG['dt_range'][0], as_struct=True,\n",
")\n",
"pl_ldf_kline = load_from_pq(\n",
"    path=CONFIG['path_kline'], table_alias='kline_1min',\n",
"    start_time=CONFIG['dt_range'][0], as_struct=True,\n",
")\n",
"pl_ldf_kline_daily = load_from_pq(\n",
"    path=CONFIG['path_kline_daily'], table_alias='kline_1day',\n",
"    start_time=CONFIG['dt_range'][0],\n",
")\n",
"pl_ldf_industry = load_from_pq(\n",
"    path=CONFIG['path_industry'], table_alias='indus_idx',\n",
"    start_time=CONFIG['dt_range'][0],\n",
")\n",
"\n",
"# Combine the four sources into one modelling table over the configured date range.\n",
"print('Loading dataset...')\n",
"pl_df = load_dataset(\n",
"    pl_ldf_a158_1min=pl_ldf_a158,\n",
"    pl_ldf_kline_1min=pl_ldf_kline,\n",
"    pl_ldf_kline_1day=pl_ldf_kline_daily,\n",
"    pl_ldf_indus_idx=pl_ldf_industry,\n",
"    dt_range=CONFIG['dt_range'],\n",
"    normalization_mode=CONFIG['normalization_mode'],\n",
"    negative_factor=CONFIG['negative_factor'],\n",
"    positive_factor=CONFIG['positive_factor'],\n",
")\n",
"\n",
"# Downstream cells work on a pandas frame.\n",
"df_full = pl_df.to_pandas()\n",
"print(f'Full dataset shape: {df_full.shape}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Train/Test Split and Model Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Columns are discovered by naming convention rather than listed explicitly.\n",
"feature_cols = [c for c in df_full.columns if c.startswith('alpha158_')]\n",
"target_cols = [c for c in df_full.columns if 'return' in c.lower()]\n",
"weight_cols = [c for c in df_full.columns if 'weight' in c.lower()]\n",
"\n",
"# Fail here with a readable message instead of a bare IndexError below\n",
"# or an opaque error inside XGBoost later.\n",
"if not feature_cols:\n",
"    raise ValueError(\"No feature columns starting with 'alpha158_' found in dataset\")\n",
"if not target_cols:\n",
"    raise ValueError(\"No target column containing 'return' found in dataset\")\n",
"\n",
"# When several columns match, the first one in column order is used.\n",
"target_col = target_cols[0]\n",
"weight_col = weight_cols[0] if weight_cols else None\n",
"\n",
"# NOTE(review): label-slicing by date strings assumes df_full carries a sorted\n",
"# datetime-like index (level 0 if it is a MultiIndex) - confirm that the frame\n",
"# produced by load_dataset(...).to_pandas() actually has one.\n",
"df_train = df_full.loc[CONFIG['train_range'][0]:CONFIG['train_range'][1]]\n",
"df_test = df_full.loc[CONFIG['test_range'][0]:CONFIG['test_range'][1]]\n",
"\n",
"if df_train.empty or df_test.empty:\n",
"    raise ValueError(f'Empty split - train: {df_train.shape}, test: {df_test.shape}')\n",
"\n",
"print(f'Train: {df_train.shape}, Test: {df_test.shape}')\n",
"print(f'Features: {len(feature_cols)}, Target: {target_col}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Per-feature medians are taken from the training window only and reused for\n",
"# the test window, so no test-set statistics enter the imputation. Computing\n",
"# them once avoids a second full pass over the training feature matrix.\n",
"train_medians = df_train[feature_cols].median()\n",
"\n",
"X_train = df_train[feature_cols].fillna(train_medians)\n",
"y_train = df_train[target_col]\n",
"# Sample weights are optional: None when no weight column was found.\n",
"w_train = df_train[weight_col] if weight_col else None\n",
"\n",
"X_test = df_test[feature_cols].fillna(train_medians)\n",
"y_test = df_test[target_col]\n",
"\n",
"print('Training XGBoost...')\n",
"model = xgb.XGBRegressor(**CONFIG['model_params'])\n",
"model.fit(X_train, y_train, sample_weight=w_train, verbose=False)\n",
"print('Training complete!')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred_test = model.predict(X_test)\n",
"\n",
"# Pooled metrics over the whole test window.\n",
"test_r2 = r2_score(y_test, y_pred_test)\n",
"test_ic = np.corrcoef(y_test, y_pred_test)[0, 1]\n",
"\n",
"print(f'Test R2: {test_r2:.4f}')\n",
"print(f'Test IC: {test_ic:.4f}')\n",
"\n",
"# Daily IC\n",
"# Only predictions and targets are needed, so build a two-column frame on the\n",
"# test index instead of copying every feature column of df_test.\n",
"df_test_eval = pd.DataFrame(\n",
"    {'pred': y_pred_test, 'target': y_test.to_numpy()},\n",
"    index=df_test.index,\n",
")\n",
"\n",
"# Group on index level 0 directly. Adding a 'datetime' column and grouping by\n",
"# that label raises an ambiguity ValueError when the index level is itself\n",
"# named 'datetime'.\n",
"# NOTE(review): if level 0 holds intraday timestamps this is a per-bar IC, not\n",
"# a per-day IC - confirm the index granularity.\n",
"daily_ic = (\n",
"    df_test_eval.groupby(level=0)\n",
"    .apply(lambda x: x['target'].corr(x['pred']))\n",
"    .rename_axis('datetime')\n",
")\n",
"\n",
"ic_mean = daily_ic.mean()\n",
"ic_std = daily_ic.std()\n",
"\n",
"print(f'Daily IC Mean: {ic_mean:.4f}')\n",
"print(f'Daily IC Std: {ic_std:.4f}')\n",
"print(f'IR: {ic_mean / ic_std:.4f}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot daily IC\n",
"fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 4))\n",
"ax_dist, ax_roll = axes\n",
"\n",
"# Left panel: histogram of the IC series with a dashed line at its mean.\n",
"ic_mean_value = daily_ic.mean()\n",
"daily_ic.hist(ax=ax_dist, bins=50, edgecolor='black')\n",
"ax_dist.axvline(x=ic_mean_value, linestyle='--', color='green')\n",
"ax_dist.set_title('Daily IC Distribution')\n",
"\n",
"# Right panel: 20-observation rolling mean (at least 5 observations), with a zero reference line.\n",
"rolling_ic = daily_ic.rolling(window=20, min_periods=5).mean()\n",
"rolling_ic.plot(ax=ax_roll)\n",
"ax_roll.axhline(y=0, linestyle='--', color='red')\n",
"ax_roll.set_title('Rolling IC (20-day)')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}