You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

317 lines
11 KiB

2 years ago
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "f15b133a-12f3-4803-830d-820a471add0e",
"metadata": {},
"outputs": [],
"source": [
"import qlib\n",
"from qlib.data import D\n",
"from qlib.constant import REG_CN\n",
"\n",
"from qlib.data.pit import P, PRef"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b6e1089e-e51d-42ac-aee8-53d1a886a355",
"metadata": {},
"outputs": [],
"source": [
"from pprint import pprint \n",
"\n",
"class PRelRef(P):\n",
" \n",
" def __init__(self, feature, rel_period):\n",
" super().__init__(feature)\n",
" self.rel_period = rel_period\n",
" self.unit = unit\n",
" \n",
" def __str__(self):\n",
" return f\"{super().__str__()}[{self.rel_period, self.unit}]\"\n",
" \n",
" def _load_feature(self, instrucument, start_index, end_index, cur_time):\n",
" #pprint(f\"{start_index}, {end_index}\")\n",
" #pprint(f\"{self.feature.get_longest_back_rolling()}, {self.feature.get_extended_window_size()}\")\n",
" #pprint(f\"{cur_time}, {self.rel_period}, {self.unit}\")\n",
" return self.feature.load(instrucument, start_index, end_index, cur_time, self.rel_period)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8730a9fb-9356-4847-b33d-370a3095df04",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from qlib.data.ops import ElemOperator\n",
"from qlib.data.data import Cal\n",
"\n",
"def is_of_quarter(period:int, quarter:int) -> bool:\n",
" return (period - quarter) % 100 == 0\n",
"\n",
"\n",
"class PDiff(P):\n",
" \"\"\"\n",
" 还是继承P而不是EleOperator以减少麻烦。\n",
" \"\"\"\n",
" \n",
" def __init__(self, feature, **kwargs):\n",
" super().__init__(feature)\n",
" self.rel_period = 1 if 'rel_period' not in kwargs else kwargs['rel_period']\n",
" self.skip_q1 = False if 'skip_q1' not in kwargs else kwargs['skip_q1']\n",
" \n",
" def _load_internal(self, instrument, start_index, end_index, freq):\n",
" _calendar = Cal.calendar(freq=freq)\n",
" resample_data = np.empty(end_index - start_index + 1, dtype=\"float32\")\n",
" \n",
" # 对日期区间逐一循环考虑到使用PIT数据的模型一般最多到日频单个股票序列长度最多到千级\n",
" for cur_index in range(start_index, end_index + 1):\n",
" cur_time = _calendar[cur_index]\n",
" # To load expression accurately, more historical data are required\n",
" start_ws, end_ws = self.get_extended_window_size()\n",
" if end_ws > 0:\n",
" raise ValueError(\n",
" \"PIT database does not support referring to future period (e.g. expressions like `Ref('$$roewa_q', -1)` are not supported\"\n",
" )\n",
"\n",
" # The calculated value will always the last element, so the end_offset is zero.\n",
" try:\n",
" s = self._load_feature(instrument, -start_ws, 0, cur_time)\n",
" pprint(s)\n",
" # 满足不需要做diff的条件在需要跳过一季度的前提下当前引用的财报期确实为一季度\n",
" if self.skip_q1 or is_of_quarter(s.index[-1], 1):\n",
" resample_data[cur_index - start_index] = s.iloc[-1] if len(s) > 0 else np.nan\n",
" else:\n",
" resample_data[cur_index - start_index] = (s.iloc[-1] - s.iloc[-2]) if len(s) > 1 else np.nan\n",
" except FileNotFoundError:\n",
" get_module_logger(\"base\").warning(f\"WARN: period data not found for {str(self)}\")\n",
" return pd.Series(dtype=\"float32\", name=str(self))\n",
"\n",
" resample_series = pd.Series(\n",
" resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype=\"float32\", name=str(self)\n",
" )\n",
" return resample_series\n",
"\n",
" def get_longest_back_rolling(self):\n",
" return self.feature.get_longest_back_rolling() + self.rel_period\n",
" \n",
" def get_extended_window_size(self):\n",
" # 这里需要考虑的是feature的windows size而不仅仅是自身的windows size\n",
" lft_etd, rght_etd = self.feature.get_extended_window_size()\n",
" return lft_etd + self.rel_period, rght_etd"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "abe001d9-ccb4-48a8-b5d0-91ef8862b14f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"class PPairDiff(PairOperator):\n",
" \n",
" def __init__(self, feature_left, feature_right, **kwargs):\n",
" super().__init__(feature_left, feature_right)\n",
" self.rel_period = 1 if 'rel_period' not in kwargs else kwargs['rel_period']\n",
" \n",
"\n",
" def _load_internal(self, instrument, start_index, end_index, *args):\n",
" assert any(\n",
" [isinstance(self.feature_left, Expression), self.feature_right, Expression]\n",
" ), \"at least one of two inputs is Expression instance\"\n",
"\n",
" if isinstance(self.feature_left, Expression):\n",
" series_left = self.feature_left.load(instrument, start_index, end_index, *args)\n",
" else:\n",
" series_left = self.feature_left # numeric value\n",
" if isinstance(self.feature_right, Expression):\n",
" series_right = self.feature_right.load(instrument, start_index, end_index, *args)\n",
" else:\n",
" series_right = self.feature_right\n",
"\n",
" if self.N == 0:\n",
" series = getattr(series_left.expanding(min_periods=1), self.func)(series_right)\n",
" else:\n",
" series = getattr(series_left.rolling(self.N, min_periods=1), self.func)(series_right)\n",
" return series\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c7272be7-8df1-47b0-b18e-b631aca6e3cd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[41422:MainThread](2022-07-14 15:01:12,046) INFO - qlib.Initialization - [config.py:413] - default_conf: client.\n",
"[41422:MainThread](2022-07-14 15:01:12,053) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.\n",
"[41422:MainThread](2022-07-14 15:01:12,057) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/home/guofu/Workspaces/guofu/TslDataFeed/_data/test/target')}\n"
]
}
],
"source": [
"qlib.init(provider_uri='_data/test/target/', region=REG_CN, custom_ops=[PDiff, PPairDiff])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "3939574b-80f4-48db-bd16-1636f92b2e02",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>PDiff($$净利润_q, skip_q1=True)</th>\n",
" </tr>\n",
" <tr>\n",
" <th>instrument</th>\n",
" <th>datetime</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"6\" valign=\"top\">sh600000</th>\n",
" <th>2021-03-26</th>\n",
" <td>1.593600e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-03-29</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-03-30</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-03-31</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-04-01</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-04-02</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PDiff($$净利润_q, skip_q1=True)\n",
"instrument datetime \n",
"sh600000 2021-03-26 1.593600e+10\n",
" 2021-03-29 1.380300e+10\n",
" 2021-03-30 1.380300e+10\n",
" 2021-03-31 1.380300e+10\n",
" 2021-04-01 1.380300e+10\n",
" 2021-04-02 1.380300e+10"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"D.features(['sh600000'], ['PDiff($$净利润_q, skip_q1=True)'], start_time='2021-03-26', end_time='2021-04-02', freq=\"day\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "679d0bcd-6975-43f3-b394-5e2e775a9b8f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"id": "c1ef5dbb-930f-4d2d-ac3a-5287ddac6d6c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(202003 - 2) % 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c01d2ab-eafa-4492-bbe5-6e3c8382bd3c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}