You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
257 lines
7.3 KiB
257 lines
7.3 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Stock 15m Baseline Model\n",
|
|
"\n",
|
|
"Train and evaluate a baseline XGBoost model for 15-minute return prediction.\n",
|
|
"\n",
|
"**Purpose**: Establish baseline performance with a standard configuration."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"import polars as pl\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import xgboost as xgb\n",
|
|
"from sklearn.metrics import r2_score\n",
|
|
"\n",
|
|
"from qshare.data.polars.ret15m import load_dataset\n",
|
|
"from qshare.io.polars import load_from_pq\n",
|
|
"\n",
|
|
"import sys\n",
|
|
"sys.path.insert(0, '../')\n",
|
|
"from common.plotting import setup_plot_style\n",
|
|
"from common.paths import create_experiment_dir\n",
|
|
"\n",
|
|
"setup_plot_style()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1. Configuration"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Central settings for the experiment; the cells below read everything from CONFIG.\n",
"CONFIG = {\n",
"    # Run bookkeeping (not read by any other cell visible in this notebook).\n",
"    'experiment_name': 'baseline_xgb',\n",
"    'save_results': True,\n",
"    # Parquet locations handed to load_from_pq.\n",
"    'path_a158': '/data/parquet/stock_1min_alpha158',\n",
"    'path_kline': '/data/parquet/stock_1min',\n",
"    'path_kline_daily': '/data/parquet/stock_1day',\n",
"    'path_industry': '/data/parquet/industry_idx',\n",
"    # Overall load window, followed by the train/test windows used for index slicing.\n",
"    'dt_range': ['2022-01-01', '2024-12-31'],\n",
"    'train_range': ['2022-01-01', '2023-12-31'],\n",
"    'test_range': ['2024-01-01', '2024-12-31'],\n",
"    # Forwarded unchanged to load_dataset.\n",
"    'normalization_mode': 'dual',\n",
"    'positive_factor': 1.0,\n",
"    'negative_factor': 2.0,\n",
"    # Keyword arguments for xgb.XGBRegressor.\n",
"    'model_params': {\n",
"        'objective': 'reg:squarederror',\n",
"        'eval_metric': 'rmse',\n",
"        'max_depth': 6,\n",
"        'learning_rate': 0.1,\n",
"        'n_estimators': 100,\n",
"        'subsample': 0.8,\n",
"        'colsample_bytree': 0.8,\n",
"        'random_state': 42,\n",
"    },\n",
"}\n",
"\n",
"print('Configuration:')\n",
"for key, value in CONFIG.items():\n",
"    if isinstance(value, dict):\n",
"        continue  # nested dicts are reported on their own line below\n",
"    print(f' {key}: {value}')\n",
"print(f\"Model params: {CONFIG['model_params']}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 2. Load Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"print('Loading data sources...')\n",
"\n",
"# Every source is read starting from the first date of the configured range.\n",
"start_time = CONFIG['dt_range'][0]\n",
"\n",
"# The two 1-minute tables are requested with as_struct=True.\n",
"pl_ldf_a158 = load_from_pq(\n",
"    path=CONFIG['path_a158'], table_alias='a158', start_time=start_time, as_struct=True\n",
")\n",
"pl_ldf_kline = load_from_pq(\n",
"    path=CONFIG['path_kline'], table_alias='kline_1min', start_time=start_time, as_struct=True\n",
")\n",
"\n",
"# The daily kline and industry index tables are loaded without as_struct.\n",
"pl_ldf_kline_daily = load_from_pq(\n",
"    path=CONFIG['path_kline_daily'], table_alias='kline_1day', start_time=start_time\n",
")\n",
"pl_ldf_industry = load_from_pq(\n",
"    path=CONFIG['path_industry'], table_alias='indus_idx', start_time=start_time\n",
")\n",
"\n",
"print('Loading dataset...')\n",
"dataset_kwargs = {\n",
"    'pl_ldf_a158_1min': pl_ldf_a158,\n",
"    'pl_ldf_kline_1min': pl_ldf_kline,\n",
"    'pl_ldf_kline_1day': pl_ldf_kline_daily,\n",
"    'pl_ldf_indus_idx': pl_ldf_industry,\n",
"    'dt_range': CONFIG['dt_range'],\n",
"    'normalization_mode': CONFIG['normalization_mode'],\n",
"    'negative_factor': CONFIG['negative_factor'],\n",
"    'positive_factor': CONFIG['positive_factor'],\n",
"}\n",
"pl_df = load_dataset(**dataset_kwargs)\n",
"\n",
"# Downstream cells work on a pandas frame.\n",
"df_full = pl_df.to_pandas()\n",
"print(f'Full dataset shape: {df_full.shape}')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 3. Train/Test Split and Model Training"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Column roles are inferred from the column names of the loaded frame.\n",
"feature_cols = [c for c in df_full.columns if c.startswith('alpha158_')]\n",
"target_cols = [c for c in df_full.columns if 'return' in c.lower()]\n",
"weight_cols = [c for c in df_full.columns if 'weight' in c.lower()]\n",
"\n",
"# Fail early with a readable message instead of a bare IndexError below\n",
"# or an opaque error later inside XGBoost.\n",
"if not feature_cols:\n",
"    raise ValueError('No feature columns found (expected an alpha158_ prefix).')\n",
"if not target_cols:\n",
"    raise ValueError('No target column found (expected a column name containing return).')\n",
"\n",
"# When several columns match, the first one is used; sample weights are optional.\n",
"target_col = target_cols[0]\n",
"weight_col = weight_cols[0] if weight_cols else None\n",
"\n",
"# Label-based slicing on the frame index.\n",
"# NOTE(review): assumes load_dataset yields a frame whose pandas index is\n",
"# datetime-like (or has it as the first, sorted level) -- confirm.\n",
"train_start, train_end = CONFIG['train_range']\n",
"test_start, test_end = CONFIG['test_range']\n",
"df_train = df_full.loc[train_start:train_end]\n",
"df_test = df_full.loc[test_start:test_end]\n",
"\n",
"# An empty window would otherwise only surface as a confusing failure during fit/predict.\n",
"if df_train.empty or df_test.empty:\n",
"    raise ValueError(\n",
"        f'Empty split: train={df_train.shape}, test={df_test.shape}; '\n",
"        'check train_range/test_range against the loaded data.'\n",
"    )\n",
"\n",
"print(f'Train: {df_train.shape}, Test: {df_test.shape}')\n",
"print(f'Features: {len(feature_cols)}, Target: {target_col}')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Impute missing features with medians computed once on the training window\n",
"# and reused for the test window, so no test-period statistics enter the inputs\n",
"# and the full-frame median is not recomputed.\n",
"train_medians = df_train[feature_cols].median()\n",
"\n",
"X_train = df_train[feature_cols].fillna(train_medians)\n",
"y_train = df_train[target_col]\n",
"# Sample weights are only used when a weight column was detected.\n",
"w_train = df_train[weight_col] if weight_col else None\n",
"\n",
"X_test = df_test[feature_cols].fillna(train_medians)\n",
"y_test = df_test[target_col]\n",
"\n",
"print('Training XGBoost...')\n",
"model = xgb.XGBRegressor(**CONFIG['model_params'])\n",
"model.fit(X_train, y_train, sample_weight=w_train, verbose=False)\n",
"print('Training complete!')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4. Evaluation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"y_pred_test = model.predict(X_test)\n",
"\n",
"# Pooled metrics over the whole test window.\n",
"test_r2 = r2_score(y_test, y_pred_test)\n",
"test_ic = np.corrcoef(y_test, y_pred_test)[0, 1]\n",
"\n",
"print(f'Test R2: {test_r2:.4f}')\n",
"print(f'Test IC: {test_ic:.4f}')\n",
"\n",
"# Daily IC\n",
"# Build a narrow evaluation frame instead of copying every feature column of df_test.\n",
"df_test_eval = pd.DataFrame(\n",
"    {'pred': y_pred_test, 'target': y_test.to_numpy()},\n",
"    index=df_test.index,\n",
")\n",
"# NOTE(review): groups are keyed on index level 0; if that level holds intraday\n",
"# timestamps rather than calendar dates, this is a per-timestamp IC -- confirm.\n",
"df_test_eval['datetime'] = df_test_eval.index.get_level_values(0)\n",
"\n",
"# Selecting the two columns keeps the grouping key out of the applied frame,\n",
"# which avoids the pandas grouping-column deprecation warning.\n",
"daily_ic = df_test_eval.groupby('datetime')[['target', 'pred']].apply(\n",
"    lambda x: x['target'].corr(x['pred'])\n",
")\n",
"\n",
"ic_mean = daily_ic.mean()\n",
"ic_std = daily_ic.std()\n",
"# Report NaN rather than inf (or a divide warning) when the IC series has no dispersion.\n",
"ic_ir = ic_mean / ic_std if ic_std > 0 else float('nan')\n",
"\n",
"print(f'Daily IC Mean: {ic_mean:.4f}')\n",
"print(f'Daily IC Std: {ic_std:.4f}')\n",
"print(f'IR: {ic_ir:.4f}')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Plot daily IC\n",
"fig, (ax_hist, ax_roll) = plt.subplots(1, 2, figsize=(14, 4))\n",
"\n",
"# Left panel: distribution of the IC series with its mean marked.\n",
"daily_ic.hist(bins=50, ax=ax_hist, edgecolor='black')\n",
"ax_hist.axvline(x=daily_ic.mean(), color='green', linestyle='--')\n",
"ax_hist.set_title('Daily IC Distribution')\n",
"\n",
"# Right panel: 20-observation rolling mean (needs at least 5 points) with a zero line.\n",
"rolling_ic = daily_ic.rolling(20, min_periods=5).mean()\n",
"rolling_ic.plot(ax=ax_roll)\n",
"ax_roll.axhline(y=0, color='red', linestyle='--')\n",
"ax_roll.set_title('Rolling IC (20-day)')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python",
|
|
"version": "3.8.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
} |