You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

317 lines
11 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "f15b133a-12f3-4803-830d-820a471add0e",
"metadata": {},
"outputs": [],
"source": [
"import qlib\n",
"from qlib.data import D\n",
"from qlib.constant import REG_CN\n",
"\n",
"from qlib.data.pit import P, PRef"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b6e1089e-e51d-42ac-aee8-53d1a886a355",
"metadata": {},
"outputs": [],
"source": [
"from pprint import pprint \n",
"\n",
"class PRelRef(P):\n",
" \n",
" def __init__(self, feature, rel_period):\n",
" super().__init__(feature)\n",
" self.rel_period = rel_period\n",
" self.unit = unit\n",
" \n",
" def __str__(self):\n",
" return f\"{super().__str__()}[{self.rel_period, self.unit}]\"\n",
" \n",
" def _load_feature(self, instrucument, start_index, end_index, cur_time):\n",
" #pprint(f\"{start_index}, {end_index}\")\n",
" #pprint(f\"{self.feature.get_longest_back_rolling()}, {self.feature.get_extended_window_size()}\")\n",
" #pprint(f\"{cur_time}, {self.rel_period}, {self.unit}\")\n",
" return self.feature.load(instrucument, start_index, end_index, cur_time, self.rel_period)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8730a9fb-9356-4847-b33d-370a3095df04",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from qlib.data.ops import ElemOperator\n",
"from qlib.data.data import Cal\n",
"\n",
"def is_of_quarter(period:int, quarter:int) -> bool:\n",
" return (period - quarter) % 100 == 0\n",
"\n",
"\n",
"class PDiff(P):\n",
" \"\"\"\n",
" 还是继承P而不是EleOperator以减少麻烦。\n",
" \"\"\"\n",
" \n",
" def __init__(self, feature, **kwargs):\n",
" super().__init__(feature)\n",
" self.rel_period = 1 if 'rel_period' not in kwargs else kwargs['rel_period']\n",
" self.skip_q1 = False if 'skip_q1' not in kwargs else kwargs['skip_q1']\n",
" \n",
" def _load_internal(self, instrument, start_index, end_index, freq):\n",
" _calendar = Cal.calendar(freq=freq)\n",
" resample_data = np.empty(end_index - start_index + 1, dtype=\"float32\")\n",
" \n",
" # 对日期区间逐一循环考虑到使用PIT数据的模型一般最多到日频单个股票序列长度最多到千级\n",
" for cur_index in range(start_index, end_index + 1):\n",
" cur_time = _calendar[cur_index]\n",
" # To load expression accurately, more historical data are required\n",
" start_ws, end_ws = self.get_extended_window_size()\n",
" if end_ws > 0:\n",
" raise ValueError(\n",
" \"PIT database does not support referring to future period (e.g. expressions like `Ref('$$roewa_q', -1)` are not supported\"\n",
" )\n",
"\n",
" # The calculated value will always the last element, so the end_offset is zero.\n",
" try:\n",
" s = self._load_feature(instrument, -start_ws, 0, cur_time)\n",
" pprint(s)\n",
" # 满足不需要做diff的条件在需要跳过一季度的前提下当前引用的财报期确实为一季度\n",
" if self.skip_q1 or is_of_quarter(s.index[-1], 1):\n",
" resample_data[cur_index - start_index] = s.iloc[-1] if len(s) > 0 else np.nan\n",
" else:\n",
" resample_data[cur_index - start_index] = (s.iloc[-1] - s.iloc[-2]) if len(s) > 1 else np.nan\n",
" except FileNotFoundError:\n",
" get_module_logger(\"base\").warning(f\"WARN: period data not found for {str(self)}\")\n",
" return pd.Series(dtype=\"float32\", name=str(self))\n",
"\n",
" resample_series = pd.Series(\n",
" resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype=\"float32\", name=str(self)\n",
" )\n",
" return resample_series\n",
"\n",
" def get_longest_back_rolling(self):\n",
" return self.feature.get_longest_back_rolling() + self.rel_period\n",
" \n",
" def get_extended_window_size(self):\n",
" # 这里需要考虑的是feature的windows size而不仅仅是自身的windows size\n",
" lft_etd, rght_etd = self.feature.get_extended_window_size()\n",
" return lft_etd + self.rel_period, rght_etd"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "abe001d9-ccb4-48a8-b5d0-91ef8862b14f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"class PPairDiff(PairOperator):\n",
" \n",
" def __init__(self, feature_left, feature_right, **kwargs):\n",
" super().__init__(feature_left, feature_right)\n",
" self.rel_period = 1 if 'rel_period' not in kwargs else kwargs['rel_period']\n",
" \n",
"\n",
" def _load_internal(self, instrument, start_index, end_index, *args):\n",
" assert any(\n",
" [isinstance(self.feature_left, Expression), self.feature_right, Expression]\n",
" ), \"at least one of two inputs is Expression instance\"\n",
"\n",
" if isinstance(self.feature_left, Expression):\n",
" series_left = self.feature_left.load(instrument, start_index, end_index, *args)\n",
" else:\n",
" series_left = self.feature_left # numeric value\n",
" if isinstance(self.feature_right, Expression):\n",
" series_right = self.feature_right.load(instrument, start_index, end_index, *args)\n",
" else:\n",
" series_right = self.feature_right\n",
"\n",
" if self.N == 0:\n",
" series = getattr(series_left.expanding(min_periods=1), self.func)(series_right)\n",
" else:\n",
" series = getattr(series_left.rolling(self.N, min_periods=1), self.func)(series_right)\n",
" return series\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c7272be7-8df1-47b0-b18e-b631aca6e3cd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[41422:MainThread](2022-07-14 15:01:12,046) INFO - qlib.Initialization - [config.py:413] - default_conf: client.\n",
"[41422:MainThread](2022-07-14 15:01:12,053) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.\n",
"[41422:MainThread](2022-07-14 15:01:12,057) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/home/guofu/Workspaces/guofu/TslDataFeed/_data/test/target')}\n"
]
}
],
"source": [
"qlib.init(provider_uri='_data/test/target/', region=REG_CN, custom_ops=[PDiff, PPairDiff])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "3939574b-80f4-48db-bd16-1636f92b2e02",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>PDiff($$净利润_q, skip_q1=True)</th>\n",
" </tr>\n",
" <tr>\n",
" <th>instrument</th>\n",
" <th>datetime</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"6\" valign=\"top\">sh600000</th>\n",
" <th>2021-03-26</th>\n",
" <td>1.593600e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-03-29</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-03-30</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-03-31</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-04-01</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2021-04-02</th>\n",
" <td>1.380300e+10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PDiff($$净利润_q, skip_q1=True)\n",
"instrument datetime \n",
"sh600000 2021-03-26 1.593600e+10\n",
" 2021-03-29 1.380300e+10\n",
" 2021-03-30 1.380300e+10\n",
" 2021-03-31 1.380300e+10\n",
" 2021-04-01 1.380300e+10\n",
" 2021-04-02 1.380300e+10"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"D.features(['sh600000'], ['PDiff($$净利润_q, skip_q1=True)'], start_time='2021-03-26', end_time='2021-04-02', freq=\"day\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "679d0bcd-6975-43f3-b394-5e2e775a9b8f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"id": "c1ef5dbb-930f-4d2d-ac3a-5287ddac6d6c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(202003 - 2) % 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c01d2ab-eafa-4492-bbe5-6e3c8382bd3c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}