{ "cells": [
 {
  "cell_type": "code",
  "execution_count": 3,
  "id": "f15b133a-12f3-4803-830d-820a471add0e",
  "metadata": {},
  "outputs": [],
  "source": [
   "import qlib\n",
   "from qlib.data import D\n",
   "from qlib.constant import REG_CN\n",
   "\n",
   "from qlib.data.pit import P, PRef\n",
   "from qlib.data.base import Expression\n",
   "from qlib.data.ops import PairOperator\n",
   "from qlib.log import get_module_logger"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 4,
  "id": "b6e1089e-e51d-42ac-aee8-53d1a886a355",
  "metadata": {},
  "outputs": [],
  "source": [
   "from pprint import pprint\n",
   "\n",
   "class PRelRef(P):\n",
   "    \"\"\"Point-in-time reference to a report period selected by ``rel_period``.\n",
   "\n",
   "    NOTE(review): ``rel_period`` is forwarded as the ``period`` argument of the wrapped\n",
   "    feature's ``load``, the same way ``PRef`` forwards its period; confirm the loader treats\n",
   "    it as a relative offset rather than an absolute period id.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, feature, rel_period, unit=None):\n",
   "        super().__init__(feature)\n",
   "        self.rel_period = rel_period\n",
   "        # Only used in __str__ so far; presumably the period unit (e.g. quarter/year) -- TODO confirm.\n",
   "        self.unit = unit\n",
   "\n",
   "    def __str__(self):\n",
   "        return f\"{super().__str__()}[{self.rel_period, self.unit}]\"\n",
   "\n",
   "    def _load_feature(self, instrument, start_index, end_index, cur_time):\n",
   "        return self.feature.load(instrument, start_index, end_index, cur_time, self.rel_period)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 15,
  "id": "8730a9fb-9356-4847-b33d-370a3095df04",
  "metadata": {
   "tags": []
  },
  "outputs": [],
  "source": [
   "import numpy as np\n",
   "import pandas as pd\n",
   "\n",
   "from qlib.data.ops import ElemOperator\n",
   "from qlib.data.data import Cal\n",
   "\n",
   "def is_of_quarter(period: int, quarter: int) -> bool:\n",
   "    \"\"\"Return True if report ``period`` falls in ``quarter`` (for quarter 1-4: ``period % 100 == quarter``).\n",
   "\n",
   "    Assumes qlib's quarterly period encoding ``year * 100 + quarter`` (e.g. 202101 for 2021Q1).\n",
   "    \"\"\"\n",
   "    return (period - quarter) % 100 == 0\n",
   "\n",
   "\n",
   "class PDiff(P):\n",
   "    \"\"\"Difference of a point-in-time (PIT) feature against an earlier report period.\n",
   "\n",
   "    Subclasses ``P`` rather than ``ElemOperator`` to keep things simple: the result is\n",
   "    collapsed onto the trading calendar like any other ``P`` expression.\n",
   "\n",
   "    Keyword args:\n",
   "        rel_period (int, default 1): how many loaded periods back to difference against.\n",
   "        skip_q1 (bool, default False): if True, a Q1 report is returned as-is instead of\n",
   "            being differenced (presumably for year-to-date cumulative figures -- TODO confirm).\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, feature, **kwargs):\n",
   "        super().__init__(feature)\n",
   "        self.rel_period = kwargs.get(\"rel_period\", 1)\n",
   "        self.skip_q1 = kwargs.get(\"skip_q1\", False)\n",
   "\n",
   "    def __str__(self):\n",
   "        # qlib's Expression.load keys its in-memory cache on str(self), so the parameters must be\n",
   "        # part of it; otherwise PDiff(x) and PDiff(x, skip_q1=True) would share cached results.\n",
   "        return f\"{type(self).__name__}({self.feature}, rel_period={self.rel_period}, skip_q1={self.skip_q1})\"\n",
   "\n",
   "    def _diff_last(self, s):\n",
   "        \"\"\"Value for the latest period in ``s`` (NaN if ``s`` has too few periods).\n",
   "\n",
   "        Returns the raw latest value for a Q1 report when ``skip_q1`` is set; otherwise the\n",
   "        difference between the latest value and the one ``rel_period`` positions earlier.\n",
   "        \"\"\"\n",
   "        if len(s) == 0:\n",
   "            return np.nan\n",
   "        # No diff needed: Q1 skipping is requested AND the latest referenced report is a Q1 report.\n",
   "        if self.skip_q1 and is_of_quarter(s.index[-1], 1):\n",
   "            return s.iloc[-1]\n",
   "        if len(s) <= self.rel_period:\n",
   "            return np.nan\n",
   "        return s.iloc[-1] - s.iloc[-1 - self.rel_period]\n",
   "\n",
   "    def _load_internal(self, instrument, start_index, end_index, freq):\n",
   "        _calendar = Cal.calendar(freq=freq)\n",
   "        resample_data = np.empty(end_index - start_index + 1, dtype=\"float32\")\n",
   "\n",
   "        # To load the expression accurately, more historical periods are required.\n",
   "        # The window does not depend on the day, so validate it once up front.\n",
   "        start_ws, end_ws = self.get_extended_window_size()\n",
   "        if end_ws > 0:\n",
   "            raise ValueError(\n",
   "                \"PIT database does not support referring to future period (e.g. expressions like `Ref('$$roewa_q', -1)` are not supported\"\n",
   "            )\n",
   "\n",
   "        # Loop over the dates one by one; models using PIT data are usually at most daily,\n",
   "        # so a single instrument's series is at most a few thousand points long.\n",
   "        for cur_index in range(start_index, end_index + 1):\n",
   "            cur_time = _calendar[cur_index]\n",
   "            try:\n",
   "                # The calculated value is always the last element, so the end offset is zero.\n",
   "                s = self._load_feature(instrument, -start_ws, 0, cur_time)\n",
   "            except FileNotFoundError:\n",
   "                get_module_logger(\"base\").warning(f\"WARN: period data not found for {str(self)}\")\n",
   "                return pd.Series(dtype=\"float32\", name=str(self))\n",
   "            resample_data[cur_index - start_index] = self._diff_last(s)\n",
   "\n",
   "        return pd.Series(\n",
   "            resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype=\"float32\", name=str(self)\n",
   "        )\n",
   "\n",
   "    def get_longest_back_rolling(self):\n",
   "        return self.feature.get_longest_back_rolling() + self.rel_period\n",
   "\n",
   "    def get_extended_window_size(self):\n",
   "        # Must account for the wrapped feature's window, not only this operator's own offset.\n",
   "        lft_etd, rght_etd = self.feature.get_extended_window_size()\n",
   "        return lft_etd + self.rel_period, rght_etd"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 16,
  "id": "abe001d9-ccb4-48a8-b5d0-91ef8862b14f",
  "metadata": {},
  "outputs": [],
  "source": [
   "class PPairDiff(PairOperator):\n",
   "    \"\"\"Pairwise operator over two features (work in progress).\n",
   "\n",
   "    TODO: ``_load_internal`` reads ``self.N`` and ``self.func``, which ``__init__`` never sets\n",
   "    (the body appears adapted from qlib's ``PairRolling``), so loading currently fails with\n",
   "    AttributeError until the intended pairwise-diff semantics are implemented.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, feature_left, feature_right, **kwargs):\n",
   "        super().__init__(feature_left, feature_right)\n",
   "        self.rel_period = kwargs.get(\"rel_period\", 1)\n",
   "\n",
   "    def _load_internal(self, instrument, start_index, end_index, *args):\n",
   "        assert any(\n",
   "            [isinstance(self.feature_left, Expression), isinstance(self.feature_right, Expression)]\n",
   "        ), \"at least one of two inputs is Expression instance\"\n",
   "\n",
   "        # Either side may be a plain number instead of an Expression.\n",
   "        if isinstance(self.feature_left, Expression):\n",
   "            series_left = self.feature_left.load(instrument, start_index, end_index, *args)\n",
   "        else:\n",
   "            series_left = self.feature_left\n",
   "        if isinstance(self.feature_right, Expression):\n",
   "            series_right = self.feature_right.load(instrument, start_index, end_index, *args)\n",
   "        else:\n",
   "            series_right = self.feature_right\n",
   "\n",
   "        if self.N == 0:\n",
   "            series = getattr(series_left.expanding(min_periods=1), self.func)(series_right)\n",
   "        else:\n",
   "            series = getattr(series_left.rolling(self.N, min_periods=1), self.func)(series_right)\n",
   "        return series"
  ]
 },
 { "cell_type": "code", "execution_count": 17, "id": "c7272be7-8df1-47b0-b18e-b631aca6e3cd", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[41422:MainThread](2022-07-14 15:01:12,046) INFO - qlib.Initialization - [config.py:413] - default_conf: client.\n", "[41422:MainThread](2022-07-14 15:01:12,053) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.\n", "[41422:MainThread](2022-07-14 15:01:12,057) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/home/guofu/Workspaces/guofu/TslDataFeed/_data/test/target')}\n" ] } ], "source": [ "qlib.init(provider_uri='_data/test/target/', region=REG_CN, custom_ops=[PDiff, PPairDiff])" ] }, { "cell_type": "code", "execution_count": 20, "id": "3939574b-80f4-48db-bd16-1636f92b2e02", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | \n", " | PDiff($$净利润_q, skip_q1=True) | \n", "
---|---|---|
instrument | \n", "datetime | \n", "\n", " |
sh600000 | \n", "2021-03-26 | \n", "1.593600e+10 | \n", "
2021-03-29 | \n", "1.380300e+10 | \n", "|
2021-03-30 | \n", "1.380300e+10 | \n", "|
2021-03-31 | \n", "1.380300e+10 | \n", "|
2021-04-01 | \n", "1.380300e+10 | \n", "|
2021-04-02 | \n", "1.380300e+10 | \n", "