# %% Imports -- every dependency of the notebook, gathered in one cell so that
# "Restart Kernel -> Run All" works on a fresh kernel.
from pprint import pprint

import numpy as np
import pandas as pd

import qlib
from qlib.constant import REG_CN
from qlib.data import D
from qlib.data.base import Expression
from qlib.data.data import Cal
from qlib.data.ops import ElemOperator, PairOperator
from qlib.data.pit import P, PRef
from qlib.log import get_module_logger

# %% Configuration
# Data provider location handed to qlib.init() (relative to the working directory).
PROVIDER_URI = "_data/test/target/"


# %% PRelRef -- forward a relative period offset to the wrapped feature
class PRelRef(P):
    """Point-in-time operator that passes a period offset through to the wrapped feature.

    Args:
        feature: expression wrapped by this operator.
        rel_period: value passed as the last argument of ``feature.load``.
        unit: optional tag stored on the instance and rendered by ``__str__``;
            it is not used when loading data.
    """

    def __init__(self, feature, rel_period, unit=None):
        super().__init__(feature)
        self.rel_period = rel_period
        # The original cell read an undefined global `unit` here, which raised
        # NameError on construction; it is now an optional argument.
        self.unit = unit

    def __str__(self):
        # The braces hold a tuple, so the suffix renders as "[(rel_period, unit)]".
        return f"{super().__str__()}[{self.rel_period, self.unit}]"

    def _load_feature(self, instrucument, start_index, end_index, cur_time):
        """Load the wrapped feature, appending ``rel_period`` to the load arguments."""
        return self.feature.load(instrucument, start_index, end_index, cur_time, self.rel_period)


# %% PDiff -- change between the two most recent periods
def is_of_quarter(period: int, quarter: int) -> bool:
    """Return True when ``period`` and ``quarter`` agree modulo 100.

    NOTE(review): assumes periods are encoded as ``year * 100 + quarter``
    (e.g. ``202003``) -- confirm against the PIT data.
    """
    return (period - quarter) % 100 == 0


class PDiff(P):
    """Point-in-time operator giving, per calendar date, the latest period value
    minus the one before it.

    Still inherits from ``P`` rather than ``ElemOperator`` to keep things simple.

    Keyword Args:
        rel_period: number of extra periods added to the look-back window
            (default ``1``).
        skip_q1: when true, a latest period that is a first quarter is returned
            as-is instead of being differenced (default ``False``).

    NOTE(review): the difference always uses the two most recent rows, whatever
    ``rel_period`` is; ``rel_period`` only widens the loaded window -- confirm
    this is intended for ``rel_period > 1``.
    """

    def __init__(self, feature, **kwargs):
        super().__init__(feature)
        self.rel_period = kwargs.get("rel_period", 1)
        self.skip_q1 = kwargs.get("skip_q1", False)

    def _latest_value(self, s):
        """Reduce the series loaded for one date to the single value reported for it."""
        # Check emptiness first: reading s.index[-1] on an empty series raises IndexError.
        if len(s) == 0:
            return np.nan
        # No diff is needed only when Q1 must be skipped AND the latest referenced
        # period really is a first quarter. (The original used `or`, which disabled
        # the diff entirely whenever skip_q1 was set.)
        if self.skip_q1 and is_of_quarter(s.index[-1], 1):
            return s.iloc[-1]
        if len(s) < 2:
            return np.nan
        return s.iloc[-1] - s.iloc[-2]

    def _load_internal(self, instrument, start_index, end_index, freq):
        """Build the per-date series for ``instrument`` over ``[start_index, end_index]``.

        Returns:
            A float32 ``pd.Series`` indexed by calendar position, or an empty
            float32 series when the period data file is missing.

        Raises:
            ValueError: if the expression refers to future periods.
        """
        _calendar = Cal.calendar(freq=freq)
        resample_data = np.empty(end_index - start_index + 1, dtype="float32")

        # To load the expression accurately, more historical data are required.
        # The window does not depend on the date, so it is computed once.
        start_ws, end_ws = self.get_extended_window_size()
        if end_ws > 0:
            raise ValueError(
                "PIT database does not support referring to future period (e.g. expressions like `Ref('$$roewa_q', -1)` are not supported"
            )

        # Loop over the date range one date at a time. Models using PIT data are
        # typically at most daily frequency, so a single instrument's series is
        # at most on the order of thousands of points.
        for cur_index in range(start_index, end_index + 1):
            cur_time = _calendar[cur_index]
            # The calculated value is always the last element, so the end offset is zero.
            try:
                s = self._load_feature(instrument, -start_ws, 0, cur_time)
            except FileNotFoundError:
                get_module_logger("base").warning(f"WARN: period data not found for {str(self)}")
                return pd.Series(dtype="float32", name=str(self))
            resample_data[cur_index - start_index] = self._latest_value(s)

        return pd.Series(
            resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype="float32", name=str(self)
        )

    def get_longest_back_rolling(self):
        return self.feature.get_longest_back_rolling() + self.rel_period

    def get_extended_window_size(self):
        # The wrapped feature's window size must be taken into account,
        # not just this operator's own.
        lft_etd, rght_etd = self.feature.get_extended_window_size()
        return lft_etd + self.rel_period, rght_etd


# %% PPairDiff -- pairwise operator (work in progress)
class PPairDiff(PairOperator):
    """Pairwise operator applying a rolling/expanding pandas statistic to two inputs.

    Keyword Args:
        rel_period: stored on the instance (default ``1``).
        N: rolling window length; ``0`` selects an expanding window (default ``0``).
        func: name of the pandas rolling/expanding method to call with the right
            operand (default ``None``, which makes loading raise).

    NOTE(review): ``rel_period`` is never used when loading, and the original
    cell read ``self.N`` / ``self.func`` without ever setting them. The loading
    logic looks adapted from a rolling pair operator -- confirm what this
    operator is meant to compute.
    """

    def __init__(self, feature_left, feature_right, **kwargs):
        super().__init__(feature_left, feature_right)
        self.rel_period = kwargs.get("rel_period", 1)
        self.N = kwargs.get("N", 0)
        self.func = kwargs.get("func")

    def _load_internal(self, instrument, start_index, end_index, *args):
        # The original passed the bare `Expression` class to any(), which is
        # always truthy; test both operands with isinstance instead.
        assert any(
            [isinstance(self.feature_left, Expression), isinstance(self.feature_right, Expression)]
        ), "at least one of two inputs is Expression instance"
        if self.func is None:
            # Previously surfaced as an AttributeError on the missing attribute.
            raise NotImplementedError("PPairDiff needs a `func` keyword naming the pairwise statistic to apply")

        if isinstance(self.feature_left, Expression):
            series_left = self.feature_left.load(instrument, start_index, end_index, *args)
        else:
            series_left = self.feature_left  # numeric value
        if isinstance(self.feature_right, Expression):
            series_right = self.feature_right.load(instrument, start_index, end_index, *args)
        else:
            series_right = self.feature_right

        if self.N == 0:
            window = series_left.expanding(min_periods=1)
        else:
            window = series_left.rolling(self.N, min_periods=1)
        return getattr(window, self.func)(series_right)


# %% Initialise qlib and register the custom operators
qlib.init(provider_uri=PROVIDER_URI, region=REG_CN, custom_ops=[PDiff, PPairDiff])
# %% Query the custom operator for one instrument over a short daily window.
# Output recorded in the original notebook run (re-run the cell to refresh it):
#
#                        PDiff($$净利润_q, skip_q1=True)
# instrument datetime
# sh600000   2021-03-26                  1.593600e+10
#            2021-03-29                  1.380300e+10
#            2021-03-30                  1.380300e+10
#            2021-03-31                  1.380300e+10
#            2021-04-01                  1.380300e+10
#            2021-04-02                  1.380300e+10
D.features(
    ["sh600000"],
    ["PDiff($$净利润_q, skip_q1=True)"],
    start_time="2021-03-26",
    end_time="2021-04-02",
    freq="day",
)

# %% Scratch check of the modulo arithmetic applied to period codes; evaluates to 1.
(202003 - 2) % 100