commit
0f18a40b56
@@ -0,0 +1,168 @@
_data/

*.pkl
*.csv
*.swp
*.swo

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -0,0 +1,7 @@
from collections import namedtuple

TINYSOFT_DATA_PATH = '../_data/tinysoft-data/'

DateRange = namedtuple('DateRange', ['start_date', 'end_date'])
@@ -0,0 +1,180 @@
#!/usr/bin/env python
# coding: utf-8

# # Import into SQL Server

# In[77]:


import pandas as pd
import os
import numpy as np
from sqlalchemy import create_engine
from sqlalchemy.types import CHAR, INT, VARCHAR, FLOAT, TEXT
import time

filelist = os.listdir('D:/数据/天软基本面数据/46.合并利润分配表')
engine = create_engine('mssql+pymssql://sa:admin@192.168.1.128/test')
conn = engine.connect()
for i in range(2451, 2452):
    file = filelist[i]
    data = pd.read_csv('D:/数据/天软基本面数据/46.合并利润分配表/' + file)
    data['备注'] = data['备注'].astype(str)  # force the remark (备注) column to str
    # data['主营业务利润'] = data['主营业务利润'].astype(np.float64)
    # data['预警详情'] = data['预警详情'].astype(str)
    title_list = pd.read_excel('D:/数据/天软基本面数据/wind-天软字段匹配/天软中英文字段对照表.xlsx',
                               '合并利润分配表')
    eng_name = title_list['英文名']
    dtype = title_list['数据类型']
    data.columns = eng_name  # replace the column names with the English field names
    dtype_dict = {'INT': INT(),
                  'CHAR(8)': CHAR(8),
                  'VARCHAR(20)': VARCHAR(20),
                  'TEXT': TEXT(),
                  'FLOAT': FLOAT()}  # map source-table dtypes to SQL Server dtypes
    dtype_list = dtype.map(dtype_dict)
    dtype_list = pd.concat([eng_name, dtype_list], axis=1)
    dtype_list = dtype_list.set_index('英文名').to_dict()['数据类型']
    data.to_sql('CONSOLIDATED_INCOMESTATEMENT',
                conn,
                index=False,
                if_exists='append',
                dtype=dtype_list)
    time.sleep(0.1)
conn.close()
engine.dispose()


# # Alter column types

# In[ ]:


import pymssql
conn = pymssql.connect('192.168.1.128', 'sa', 'admin', "test")  # connect to the "test" database
cursor = conn.cursor()
sqls = ['ALTER TABLE EARNINGS_PREANNOUNCEMENT ALTER COLUMN StockID CHAR(8) NOT NULL',
        'ALTER TABLE EARNINGS_PREANNOUNCEMENT ALTER COLUMN S_PROFITNOTICE_PERIOD INT NOT NULL',
        # 'ALTER TABLE PRELIMINARY_EARNING_EST ALTER COLUMN REPORT_DATE INT NOT NULL',
        'ALTER TABLE EARNINGS_PREANNOUNCEMENT ALTER COLUMN S_PROFITNOTICE_DATE INT NOT NULL',
        '''ALTER TABLE EARNINGS_PREANNOUNCEMENT ADD CONSTRAINT STOCK_EP_ID PRIMARY KEY
        (StockID,S_PROFITNOTICE_PERIOD,S_PROFITNOTICE_DATE)''',
        'CREATE NONCLUSTERED INDEX STOCK_INTID ON EARNINGS_PREANNOUNCEMENT (StockID_INT)']
for sql in sqls:
    cursor.execute(sql)  # apply each DDL statement
conn.commit()
conn.close()


# In[60]:


conn = pymssql.connect('192.168.1.128', 'sa', 'admin', "test")
cursor = conn.cursor()
sql1 = '''
sp_rename 'DIVIDEND_ANNOUNCEMENT.REPORT_PERIOD', 'ReportPeriod', 'column'
'''
# sql2 = '''
# sp_rename 'DIVIDEND_ANNOUNCEMENT.S_PROFITNOTICE_DATE', 'AppearAtDate', 'column'
# '''
sql3 = '''
drop index STOCK_INTID on DIVIDEND_ANNOUNCEMENT
'''
sql4 = '''
alter table DIVIDEND_ANNOUNCEMENT drop column StockID_INT
'''
cursor.execute(sql1)
# cursor.execute(sql2)
cursor.execute(sql3)
cursor.execute(sql4)
conn.commit()
conn.close()


# # Create views

# In[184]:


conn = pymssql.connect('192.168.1.128', 'sa', 'admin', "test")
cursor = conn.cursor()
# sql = '''
# CREATE VIEW LAYER1 AS
# SELECT A.StockID, A.REPORT_PERIOD, A.ACTUAL_ANN_DT FROM
# (
# SELECT *, ROW_NUMBER() OVER(PARTITION BY StockID, REPORT_PERIOD order by StockID, REPORT_PERIOD)
# AS RowNumber FROM FINCOMP_CASHFLOWSTATEMENT
# )A
# WHERE A.RowNumber=1
# '''
sql = '''
CREATE VIEW DIVMeta AS
SELECT StockID,
RIGHT(StockID,6) + '.' + LEFT(StockID,2) AS WIND_CODE,
CONVERT(INT, RIGHT(StockID, 6)) AS IntCode,
ReportPeriod FROM DIVIDEND_ANNOUNCEMENT
'''
cursor.execute(sql)
conn.commit()
conn.close()


# In[185]:


import re
wind_index = pd.read_excel('D:/数据/天软基本面数据/wind-天软字段匹配/已匹配索引总表.xlsx',
                           '分红送股')
view_column = list(wind_index['字段名'])  # Wind English field names (used after AS)
view_joinindex = list(wind_index['天软字段'])  # TinySoft fields (Chinese names or computed expressions) matching Wind
ts_index = pd.read_excel('D:/数据/天软基本面数据/wind-天软字段匹配/天软中英文字段对照表.xlsx',
                         '分红送股')
sql_matching = []
for i in range(len(view_column)):
    if 'exp' in view_joinindex[i]:
        expression = view_joinindex[i]
        ts_columns = re.split(r'\$| ', expression)
        ts_columns = [col for col in ts_columns if u'\u4e00' <= col <= u'\u9fff']  # keep only the Chinese field names
        for col in ts_columns:
            expression = expression.replace(col, 'B.' + ts_index.loc[ts_index['天软字段'] == col, '英文名'].values[0])
        expression = expression.replace('exp:', '')
        expression = expression.replace('$', '')
        sql_matching.append(expression)
        sql_matching.append(' AS ' + view_column[i] + ',')
    else:
        sql_matching.append('B.' + ts_index.loc[ts_index['天软字段'] == view_joinindex[i], '英文名'].values[0] +
                            ' AS ' + view_column[i] + ',')
sql_matching[-1] = sql_matching[-1].replace(',', '')
sql_matching = ''.join(sql_matching)


# In[186]:


conn = pymssql.connect('192.168.1.128', 'sa', 'admin', "test")
cursor = conn.cursor()
sql = 'CREATE VIEW DIVWind AS ' + 'SELECT A.WIND_CODE, A.IntCode, A.ReportPeriod, ' + sql_matching + ' FROM DIVMeta A LEFT JOIN DIVIDEND_ANNOUNCEMENT B ON ' + 'A.StockID = B.StockID AND A.ReportPeriod = B.ReportPeriod'
cursor.execute(sql)
conn.commit()
conn.close()


# # Drop views

# In[134]:


import pymssql
conn = pymssql.connect('192.168.1.128', 'sa', 'admin', "test")
cursor = conn.cursor()
# sql = '''
# SELECT RIGHT(StockID,6) + '.' + LEFT(StockID,2) FROM CONSOLIDATED_BALANCESHEET
# '''
sql = '''
IF EXISTS(SELECT * FROM SYS.VIEWS WHERE NAME='CBSBeforeAdj')
DROP VIEW CBSBeforeAdj
'''
cursor.execute(sql)
conn.commit()
conn.close()
@@ -0,0 +1,135 @@
import os, time
from tqdm import tqdm
import itertools
from pathlib import Path

from collections import namedtuple

import numpy as np
import pandas as pd

from tsl import *
from config import *


class DLFinancial:

    def __init__(self):

        self.table_id_to_table_name = {
            18 : '分红送股',
            40 : '业绩预测',
            41 : '业绩快报',
            42 : '主要财务指标',
            44 : '合并资产负债表',
            46 : '合并利润分配表',
            48 : '合并现金流量表',
            56 : '金融公司资产负债表',
            58 : '金融公司利润分配表',
            60 : '金融公司现金流量表',
        }

        self.table_id_to_index_list = {
            18 : ['StockID', 'StockName', '截止日'],
            40 : ['StockID', 'StockName', '截止日', '公布日'],
            41 : ['StockID', 'StockName', '截止日', '公布日'],
            42 : ['StockID', 'StockName', '截止日', '公布日'],
            44 : ['StockID', 'StockName', '截止日', '数据报告期', '公布日'],
            46 : ['StockID', 'StockName', '截止日', '数据报告期', '公布日'],
            48 : ['StockID', 'StockName', '截止日', '数据报告期', '公布日'],
            56 : ['StockID', 'StockName', '截止日', '数据报告期', '公布日'],
            58 : ['StockID', 'StockName', '截止日', '数据报告期', '公布日'],
            60 : ['StockID', 'StockName', '截止日', '数据报告期', '公布日'],
        }

        self.config_name_list = [
            '每股指标',
            '盈利能力',
            '偿债能力',
            '资本结构',
            '经营能力',
            '投资收益',
            '成长能力',
            '现金流指标',
            '资产负债表结构',
            '利润分配表结构',
            '现金流量表结构',
            '估值指标',
        ]


    def do_fin_report(self):
        for table_id in self.table_id_to_table_name.keys():
            self._dump_fin_report(table_id, 20000101)


    def _dump_fin_report(self, table_id, report_start_date):

        table_name = self.table_id_to_table_name[table_id]
        dump_folder = '{}/基础报表/{}.{}/'.format(TINYSOFT_DATA_PATH, table_id, table_name)
        Path(dump_folder).mkdir(parents=True, exist_ok=True)

        def _dump_df_to_csv(dump_folder, stock_id, df):
            dump_path = '{}/{}.csv'.format(dump_folder, stock_id)

            index_cols = self.table_id_to_index_list[table_id]
            df.set_index(index_cols, inplace=True)
            df.to_csv(dump_path)

        with tsl() as ts:
            stock_list = ts.get_stock_list()
            print('正在获取数据:', table_id, table_name)

            with tqdm(stock_list) as pbar:
                for stock_id in pbar:
                    pbar.set_description(dump_folder + stock_id)

                    df = ts.get_cmp_report(
                        table_id=table_id,
                        stock_id=stock_id,
                        start_date=report_start_date
                    )

                    if df.shape[0] == 0 or df.shape[1] == 0:
                        print('{}的{}为空数据'.format(stock_id, table_name))
                        continue

                    _dump_df_to_csv(dump_folder, stock_id, df)
                    time.sleep(0.001)


    def _dump_common_info(self):
        pass


    def _dump_derived_indicators(self, report_start_year):

        def _dump_df_to_csv(config_name, stock_id, df):
            dump_path = '{}/衍生指标/{}/{}.csv'.format(
                TINYSOFT_DATA_PATH, config_name, stock_id)
            df.to_csv(dump_path)

        with tsl() as ts:
            stock_list = ts.get_stock_list()

            for config_name in self.config_name_list:
                print('正在获取数据:', config_name)

                with tqdm(stock_list) as pbar:
                    for stock_id in pbar:
                        pbar.set_description(stock_id)

                        df = ts.get_cmp_indicator(
                            stock_id=stock_id,
                            start_year=report_start_year,
                            indicator_config_fname='{}/indicator-config/{}.csv'.format(
                                TINYSOFT_DATA_PATH,
                                config_name
                            )
                        )
                        _dump_df_to_csv(config_name, stock_id, df)
                        time.sleep(0.001)


if __name__ == '__main__':
    DLFinancial().do_fin_report()
@@ -0,0 +1,61 @@
import os.path

import pandas as pd
from tqdm import tqdm

from tsl import *
from market import *
from config import *


def merge_k_daily():

    shard_dump_folder = '{}/行情数据/日K线/shards/'.format(TINYSOFT_DATA_PATH)
    merged_dump_folder = '{}/行情数据/日K线/merged/'.format(TINYSOFT_DATA_PATH)

    with tsl() as t:
        stock_list = t.get_stock_list()

        with tqdm(stock_list) as pbar:

            for stock_id in pbar:
                pbar.set_description('正在处理 ' + stock_id)

                df_list = []
                for date_range in Market.k_daily_data_shards:

                    start_date, end_date = \
                        date_range[0], date_range[1]
                    shard_name = str(start_date) + '-' + str(end_date)

                    load_path = '{}/{}/{}.csv'.format(
                        shard_dump_folder,
                        shard_name,
                        stock_id
                    )

                    pbar.set_description('正在载入 ' + load_path)
                    if not os.path.exists(load_path):
                        continue

                    _df = pd.read_csv(load_path)

                    if len(_df) > 0:
                        df_list.append(_df)

                pbar.set_description('正在拼接 ' + stock_id)
                if len(df_list) > 0:
                    df = pd.concat(df_list, axis=0)
                    df.set_index(['StockID', 'date'], inplace=True)

                    dump_path = '{}/{}.csv'.format(
                        merged_dump_folder,
                        stock_id
                    )
                    df.to_csv(dump_path)


if __name__ == '__main__':
    merge_k_daily()
@@ -0,0 +1,69 @@
- [Download Qlib Data](#Download-Qlib-Data)
  - [Download CN Data](#Download-CN-Data)
  - [Download US Data](#Download-US-Data)
  - [Download CN Simple Data](#Download-CN-Simple-Data)
  - [Help](#Help)
- [Using in Qlib](#Using-in-Qlib)
  - [US data](#US-data)
  - [CN data](#CN-data)


## Download Qlib Data


### Download CN Data

```bash
# daily data
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn

# 1min data (Optional for running non-high-frequency strategies)
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
```

### Download US Data

```bash
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
```

### Download CN Simple Data

```bash
python get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --region cn
```

### Help

```bash
python get_data.py qlib_data --help
```

## Using in Qlib
> For more information: https://qlib.readthedocs.io/en/latest/start/initialization.html


### US data

> Need to download data first: [Download US Data](#Download-US-Data)

```python
import qlib
from qlib.config import REG_US
provider_uri = "~/.qlib/qlib_data/us_data"  # target_dir
qlib.init(provider_uri=provider_uri, region=REG_US)
```

### CN data

> Need to download data first: [Download CN Data](#Download-CN-Data)

```python
import qlib
from qlib.constant import REG_CN

provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
qlib.init(provider_uri=provider_uri, region=REG_CN)
```
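
Once `qlib.init` has been pointed at a downloaded data directory, the data can be read back through `qlib.data.D`. A minimal sketch (the ticker and date range below are placeholders, not part of the original README):

```python
from qlib.data import D

# daily close/volume for one instrument from the initialized provider_uri
df = D.features(["SH600000"], ["$close", "$volume"],
                start_time="2020-01-01", end_time="2020-01-31", freq="day")
print(df.head())
```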
@@ -0,0 +1,143 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from pathlib import Path
from concurrent.futures import ProcessPoolExecutor

import qlib
from qlib.data import D

import fire
import datacompy
import pandas as pd
from tqdm import tqdm
from loguru import logger


class CheckBin:

    NOT_IN_FEATURES = "not in features"
    COMPARE_FALSE = "compare False"
    COMPARE_TRUE = "compare True"
    COMPARE_ERROR = "compare error"

    def __init__(
        self,
        qlib_dir: str,
        csv_path: str,
        check_fields: str = None,
        freq: str = "day",
        symbol_field_name: str = "symbol",
        date_field_name: str = "date",
        file_suffix: str = ".csv",
        max_workers: int = 16,
    ):
        """

        Parameters
        ----------
        qlib_dir : str
            qlib dir
        csv_path : str
            origin csv path
        check_fields : str, optional
            check fields, by default None, check qlib_dir/features/<first_dir>/*.<freq>.bin
        freq : str, optional
            freq, value from ["day", "1m"]
        symbol_field_name: str, optional
            symbol field name, by default "symbol"
        date_field_name: str, optional
            date field name, by default "date"
        file_suffix: str, optional
            csv file suffix, by default ".csv"
        max_workers: int, optional
            max workers, by default 16
        """
        self.qlib_dir = Path(qlib_dir).expanduser()
        bin_path_list = list(self.qlib_dir.joinpath("features").iterdir())
        self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list))
        qlib.init(
            provider_uri=str(self.qlib_dir.resolve()),
            mount_path=str(self.qlib_dir.resolve()),
            auto_mount=False,
            redis_port=-1,
        )
        csv_path = Path(csv_path).expanduser()
        self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path])

        if check_fields is None:
            check_fields = list(map(lambda x: x.name.split(".")[0], bin_path_list[0].glob(f"*.bin")))
        else:
            check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields
        self.check_fields = list(map(lambda x: x.strip(), check_fields))
        self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields))
        self.max_workers = max_workers
        self.symbol_field_name = symbol_field_name
        self.date_field_name = date_field_name
        self.freq = freq
        self.file_suffix = file_suffix

    def _compare(self, file_path: Path):
        symbol = file_path.name.strip(self.file_suffix)
        if symbol.lower() not in self.qlib_symbols:
            return self.NOT_IN_FEATURES
        # qlib data
        qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq)
        qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True)
        # csv data
        origin_df = pd.read_csv(file_path)
        origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name])
        if self.symbol_field_name not in origin_df.columns:
            origin_df[self.symbol_field_name] = symbol
        origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)
        origin_df.index.names = qlib_df.index.names
        origin_df = origin_df.reindex(qlib_df.index)
        try:
            compare = datacompy.Compare(
                origin_df,
                qlib_df,
                on_index=True,
                abs_tol=1e-08,  # Optional, defaults to 0
                rel_tol=1e-05,  # Optional, defaults to 0
                df1_name="Original",  # Optional, defaults to 'df1'
                df2_name="New",  # Optional, defaults to 'df2'
            )
            _r = compare.matches(ignore_extra_columns=True)
            return self.COMPARE_TRUE if _r else self.COMPARE_FALSE
        except Exception as e:
            logger.warning(f"{symbol} compare error: {e}")
            return self.COMPARE_ERROR

    def check(self):
        """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data"""
        logger.info("start check......")

        error_list = []
        not_in_features = []
        compare_false = []
        with tqdm(total=len(self.csv_files)) as p_bar:
            with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)):
                    symbol = file_path.name.strip(self.file_suffix)
                    if _check_res == self.NOT_IN_FEATURES:
                        not_in_features.append(symbol)
                    elif _check_res == self.COMPARE_ERROR:
                        error_list.append(symbol)
                    elif _check_res == self.COMPARE_FALSE:
                        compare_false.append(symbol)
                    p_bar.update()

        logger.info("end of check......")
        if error_list:
            logger.warning(f"compare error: {error_list}")
        if not_in_features:
            logger.warning(f"not in features: {not_in_features}")
        if compare_false:
            logger.warning(f"compare False: {compare_false}")
        logger.info(
            f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false"
        )


if __name__ == "__main__":
    fire.Fire(CheckBin)
@@ -0,0 +1,71 @@
import sys
import platform
import qlib
import fire
import pkg_resources
from pathlib import Path

QLIB_PATH = Path(__file__).absolute().resolve().parent.parent


class InfoCollector:
    """
    User could collect system info by following commands
    `cd scripts && python collect_info.py all`
    - NOTE: please avoid running this script in the project folder which contains `qlib`
    """

    def sys(self):
        """collect system related info"""
        for method in ["system", "machine", "platform", "version"]:
            print(getattr(platform, method)())

    def py(self):
        """collect Python related info"""
        print("Python version: {}".format(sys.version.replace("\n", " ")))

    def qlib(self):
        """collect qlib related info"""
        print("Qlib version: {}".format(qlib.__version__))
        REQUIRED = [
            "numpy",
            "pandas",
            "scipy",
            "requests",
            "sacred",
            "python-socketio",
            "redis",
            "python-redis-lock",
            "schedule",
            "cvxpy",
            "hyperopt",
            "fire",
            "statsmodels",
            "xlrd",
            "plotly",
            "matplotlib",
            "tables",
            "pyyaml",
            "mlflow",
            "tqdm",
            "loguru",
            "lightgbm",
            "tornado",
            "joblib",
            "fire",
            "ruamel.yaml",
        ]

        for package in REQUIRED:
            version = pkg_resources.get_distribution(package).version
            print(f"{package}=={version}")

    def all(self):
        """collect all info"""
        for method in ["sys", "py", "qlib"]:
            getattr(self, method)()
            print()


if __name__ == "__main__":
    fire.Fire(InfoCollector)
@@ -0,0 +1,60 @@
# Data Collector

## Introduction

Scripts for data collection

- yahoo: get *US/CN* stock data from *Yahoo Finance*
- fund: get fund data from *http://fund.eastmoney.com*
- cn_index: get *CN index* from *http://www.csindex.com.cn*, *CSI300*/*CSI100*
- us_index: get *US index* from *https://en.wikipedia.org/wiki*, *SP500*/*NASDAQ100*/*DJIA*/*SP400*
- contrib: scripts for some auxiliary functions


## Custom Data Collection

> Specific implementation reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo

1. Create a dataset code directory in the current directory
2. Add `collector.py`
   - add collector class:
     ```python
     CUR_DIR = Path(__file__).resolve().parent
     sys.path.append(str(CUR_DIR.parent.parent))
     from data_collector.base import BaseCollector, BaseNormalize, BaseRun
     class UserCollector(BaseCollector):
         ...
     ```
   - add normalize class:
     ```python
     class UserNormalize(BaseNormalize):
         ...
     ```
   - add `CLI` class:
     ```python
     class Run(BaseRun):
         ...
     ```
3. Add `README.md`
4. Add `requirements.txt`

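Putting the pieces above together, a minimal, hypothetical `collector.py` skeleton could look like the sketch below. The abstract method names come from `data_collector/base.py` in this commit; the symbols and return values are placeholders to be replaced by a real data source.

```python
import sys
from pathlib import Path
from typing import Iterable

import pandas as pd

CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun


class UserCollector(BaseCollector):
    def get_instrument_list(self):
        return ["SYMBOL1", "SYMBOL2"]  # placeholder universe

    def normalize_symbol(self, symbol: str):
        return symbol.upper()

    def get_data(self, symbol, interval, start_datetime, end_datetime) -> pd.DataFrame:
        # fetch one symbol from your data source; the result must contain "symbol" and "date" columns
        raise NotImplementedError


class UserNormalize(BaseNormalize):
    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        return df  # reindex/adjust the raw data here

    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
        return []  # benchmark calendar of the target market


class Run(BaseRun):
    @property
    def collector_class_name(self):
        return "UserCollector"

    @property
    def normalize_class_name(self):
        return "UserNormalize"

    @property
    def default_base_dir(self):
        return CUR_DIR
```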
## Description of dataset

|             | Basic data                                                                              |
|-------------|------------------------------------------------------------------------------------------|
| Features    | **Price/Volume**: <br> - $close/$open/$low/$high/$volume/$change/$factor                  |
| Calendar    | **\<freq>.txt**: <br> - day.txt<br> - 1min.txt                                            |
| Instruments | **\<market>.txt**: <br> - required: **all.txt**; <br> - csi300.txt/csi500.txt/sp500.txt   |

- `Features`: data, **digital**
  - if not **adjusted**, **factor=1**

### Data-dependent component

> To make the components run correctly, the dependent data are required

| Component      | required data                                      |
|----------------|----------------------------------------------------|
| Data retrieval | Features, Calendar, Instruments                    |
| Backtest       | **Features[Price/Volume]**, Calendar, Instruments  |
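
As a concrete illustration of how the three data types in the table are consumed (not part of the original README; the path and tickers are placeholders), the calendar, instruments, and features can all be queried through `qlib.data.D` after `qlib.init`:

```python
import qlib
from qlib.data import D

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data")  # directory containing calendars/, instruments/, features/

# Calendar: backed by calendars/day.txt
print(D.calendar(start_time="2020-01-01", end_time="2020-01-31", freq="day")[:5])

# Instruments: backed by instruments/<market>.txt (all.txt is required)
instruments = D.instruments(market="all")
print(D.list_instruments(instruments, as_list=True)[:5])

# Features: backed by features/<symbol>/<field>.day.bin
print(D.features(["SH600000"], ["$close", "$volume", "$factor"], freq="day").head())
```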
@@ -0,0 +1,427 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


import abc
import time
import datetime
import importlib
from pathlib import Path
from typing import Type, Iterable
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

import pandas as pd
from tqdm import tqdm
from loguru import logger
from joblib import Parallel, delayed
from qlib.utils import code_to_fname


class BaseCollector(abc.ABC):

    CACHE_FLAG = "CACHED"
    NORMAL_FLAG = "NORMAL"

    DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
    DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date()
    DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date()
    DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D

    INTERVAL_1min = "1min"
    INTERVAL_1d = "1d"

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="1d",
        max_workers=1,
        max_collector_count=2,
        delay=0,
        check_data_length: int = None,
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            instrument save dir
        max_workers: int
            Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [1min, 1d], default 1d
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
        limit_nums: int
            using for debug, by default None
        """
        self.save_dir = Path(save_dir).expanduser().resolve()
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.delay = delay
        self.max_workers = max_workers
        self.max_collector_count = max_collector_count
        self.mini_symbol_map = {}
        self.interval = interval
        self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0)

        self.start_datetime = self.normalize_start_datetime(start)
        self.end_datetime = self.normalize_end_datetime(end)

        self.instrument_list = sorted(set(self.get_instrument_list()))

        if limit_nums is not None:
            try:
                self.instrument_list = self.instrument_list[: int(limit_nums)]
            except Exception as e:
                logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")

    def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):
        return (
            pd.Timestamp(str(start_datetime))
            if start_datetime
            else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}")
        )

    def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):
        return (
            pd.Timestamp(str(end_datetime))
            if end_datetime
            else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
        )

    @abc.abstractmethod
    def get_instrument_list(self):
        raise NotImplementedError("rewrite get_instrument_list")

    @abc.abstractmethod
    def normalize_symbol(self, symbol: str):
        """normalize symbol"""
        raise NotImplementedError("rewrite normalize_symbol")

    @abc.abstractmethod
    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> pd.DataFrame:
        """get data with symbol

        Parameters
        ----------
        symbol: str
        interval: str
            value from [1min, 1d]
        start_datetime: pd.Timestamp
        end_datetime: pd.Timestamp

        Returns
        ---------
        pd.DataFrame, "symbol" and "date" in pd.columns

        """
        raise NotImplementedError("rewrite get_timezone")

    def sleep(self):
        time.sleep(self.delay)

    def _simple_collector(self, symbol: str):
        """

        Parameters
        ----------
        symbol: str

        """
        self.sleep()
        df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
        _result = self.NORMAL_FLAG
        if self.check_data_length > 0:
            _result = self.cache_small_data(symbol, df)
        if _result == self.NORMAL_FLAG:
            self.save_instrument(symbol, df)
        return _result

    def save_instrument(self, symbol, df: pd.DataFrame):
        """save instrument data to file

        Parameters
        ----------
        symbol: str
            instrument code
        df : pd.DataFrame
            df.columns must contain "symbol" and "datetime"
        """
        if df is None or df.empty:
            logger.warning(f"{symbol} is empty")
            return

        symbol = self.normalize_symbol(symbol)
        symbol = code_to_fname(symbol)
        instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
        df["symbol"] = symbol
        if instrument_path.exists():
            _old_df = pd.read_csv(instrument_path)
            df = pd.concat([_old_df, df], sort=False)
        df.to_csv(instrument_path, index=False)

    def cache_small_data(self, symbol, df):
        if len(df) < self.check_data_length:
            logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!")
            _temp = self.mini_symbol_map.setdefault(symbol, [])
            _temp.append(df.copy())
            return self.CACHE_FLAG
        else:
            if symbol in self.mini_symbol_map:
                self.mini_symbol_map.pop(symbol)
            return self.NORMAL_FLAG

    def _collector(self, instrument_list):

        error_symbol = []
        res = Parallel(n_jobs=self.max_workers)(
            delayed(self._simple_collector)(_inst) for _inst in tqdm(instrument_list)
        )
        for _symbol, _result in zip(instrument_list, res):
            if _result != self.NORMAL_FLAG:
                error_symbol.append(_symbol)
        print(error_symbol)
        logger.info(f"error symbol nums: {len(error_symbol)}")
        logger.info(f"current get symbol nums: {len(instrument_list)}")
        error_symbol.extend(self.mini_symbol_map.keys())
        return sorted(set(error_symbol))

    def collector_data(self):
        """collector data"""
        logger.info("start collector data......")
        instrument_list = self.instrument_list
        for i in range(self.max_collector_count):
            if not instrument_list:
                break
            logger.info(f"getting data: {i+1}")
            instrument_list = self._collector(instrument_list)
            logger.info(f"{i+1} finish.")
        for _symbol, _df_list in self.mini_symbol_map.items():
            _df = pd.concat(_df_list, sort=False)
            if not _df.empty:
                self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"]))
        if self.mini_symbol_map:
            logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}")
        logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")


class BaseNormalize(abc.ABC):
    def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
        """

        Parameters
        ----------
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol
        """
        self._date_field_name = date_field_name
        self._symbol_field_name = symbol_field_name
        self.kwargs = kwargs
        self._calendar_list = self._get_calendar_list()

    @abc.abstractmethod
    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        # normalize
        raise NotImplementedError("")

    @abc.abstractmethod
    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
        """Get benchmark calendar"""
        raise NotImplementedError("")


class Normalize:
    def __init__(
        self,
        source_dir: [str, Path],
        target_dir: [str, Path],
        normalize_class: Type[BaseNormalize],
        max_workers: int = 16,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
        **kwargs,
    ):
        """

        Parameters
        ----------
        source_dir: str or Path
            The directory where the raw data collected from the Internet is saved
        target_dir: str or Path
            Directory for normalize data
        normalize_class: Type[YahooNormalize]
            normalize class
        max_workers: int
            Concurrent number, default is 16
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol
        """
        if not (source_dir and target_dir):
            raise ValueError("source_dir and target_dir cannot be None")
        self._source_dir = Path(source_dir).expanduser()
        self._target_dir = Path(target_dir).expanduser()
        self._target_dir.mkdir(parents=True, exist_ok=True)
        self._date_field_name = date_field_name
        self._symbol_field_name = symbol_field_name
        self._end_date = kwargs.get("end_date", None)
        self._max_workers = max_workers

        self._normalize_obj = normalize_class(
            date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
        )

    def _executor(self, file_path: Path):
        file_path = Path(file_path)
        df = pd.read_csv(file_path)
        df = self._normalize_obj.normalize(df)
        if df is not None and not df.empty:
            if self._end_date is not None:
                _mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
                df = df[_mask]
            df.to_csv(self._target_dir.joinpath(file_path.name), index=False)

    def normalize(self):
        logger.info("normalize data......")

        with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
            file_list = list(self._source_dir.glob("*.csv"))
            with tqdm(total=len(file_list)) as p_bar:
                for _ in worker.map(self._executor, file_list):
                    p_bar.update()


class BaseRun(abc.ABC):
    def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
        normalize_dir: str
            Directory for normalize data, default "Path(__file__).parent/normalize"
        max_workers: int
            Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
        interval: str
            freq, value from [1min, 1d], default 1d
        """
        if source_dir is None:
            source_dir = Path(self.default_base_dir).joinpath("source")
        self.source_dir = Path(source_dir).expanduser().resolve()
        self.source_dir.mkdir(parents=True, exist_ok=True)

        if normalize_dir is None:
            normalize_dir = Path(self.default_base_dir).joinpath("normalize")
        self.normalize_dir = Path(normalize_dir).expanduser().resolve()
        self.normalize_dir.mkdir(parents=True, exist_ok=True)

        self._cur_module = importlib.import_module("collector")
        self.max_workers = max_workers
        self.interval = interval

    @property
    @abc.abstractmethod
    def collector_class_name(self):
        raise NotImplementedError("rewrite collector_class_name")

    @property
    @abc.abstractmethod
    def normalize_class_name(self):
        raise NotImplementedError("rewrite normalize_class_name")

    @property
    @abc.abstractmethod
    def default_base_dir(self) -> [Path, str]:
        raise NotImplementedError("rewrite default_base_dir")

    def download_data(
        self,
        max_collector_count=2,
        delay=0,
        start=None,
        end=None,
        check_data_length: int = None,
        limit_nums=None,
        **kwargs,
    ):
        """download data from Internet

        Parameters
        ----------
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        start: str
            start datetime, default "2000-01-01"
        end: str
            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
        limit_nums: int
            using for debug, by default None

        Examples
        ---------
            # get daily data
            $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
            # get 1m data
            $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
        """

        _class = getattr(self._cur_module, self.collector_class_name)  # type: Type[BaseCollector]
        _class(
            self.source_dir,
            max_workers=self.max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            start=start,
            end=end,
            interval=self.interval,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
            **kwargs,
        ).collector_data()

    def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
        """normalize data

        Parameters
        ----------
        date_field_name: str
            date field name, default date
        symbol_field_name: str
            symbol field name, default symbol

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
        """
        _class = getattr(self._cur_module, self.normalize_class_name)
        yc = Normalize(
            source_dir=self.source_dir,
            target_dir=self.normalize_dir,
            normalize_class=_class,
            max_workers=self.max_workers,
            date_field_name=date_field_name,
            symbol_field_name=symbol_field_name,
            **kwargs,
        )
        yc.normalize()
@@ -0,0 +1,61 @@
# iBOVESPA History Companies Collection

## Requirements

- Install the libs from the file `requirements.txt`

    ```bash
    pip install -r requirements.txt
    ```
- The `requirements.txt` file was generated using python3.8

## For the ibovespa (IBOV) index, we have:

<hr/>

### Method `get_new_companies`

#### <b>Index start date</b>

- The ibovespa index started on 2 January 1968 ([wiki](https://en.wikipedia.org/wiki/%C3%8Dndice_Bovespa)). In order to use this start date in our `bench_start_date(self)` method, two conditions must be satisfied:

    1) The APIs used to download Brazilian (B3) stocks' historical prices must keep track of such data since 2 January 1968.

    2) Some website or API must provide, from that date onward, the historic index composition, i.e. the companies used to build the index.

    As a consequence, the method `bench_start_date(self)` inside `collector.py` was implemented using `pd.Timestamp("2003-01-03")` for two reasons:

    1) The earliest ibov composition that has been found is from the first quarter of 2003. More information about this composition can be found in the sections below.

    2) Yahoo Finance, one of the libraries used to download symbols' historic prices, keeps track of prices from this date forward.

- Within the `get_new_companies` method, logic was implemented to get, for each ibovespa component stock, the start date that Yahoo Finance keeps track of.

#### <b>Code Logic</b>

The code does web scraping of B3's [website](https://sistemaswebb3-listados.b3.com.br/indexPage/day/IBOV?language=pt-br), which keeps track of the ibovespa stock composition on the current day.

Other approaches, such as `requests` and `Beautiful Soup`, could have been used. However, the website renders the stock table with some delay, since it uses an in-page script to obtain the composition.
`selenium` was therefore used to download the composition and overcome this problem.

Furthermore, the data downloaded by the selenium script was preprocessed so it could be saved in the `csv` format established by `scripts/data_collector/index.py`.
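For illustration only (this is not the repository's actual scraper): a minimal selenium sketch of the approach described above, assuming a Chrome driver is available on PATH and that the page's script has filled the table after a short wait.

```python
import time

import pandas as pd
from selenium import webdriver

B3_URL = "https://sistemaswebb3-listados.b3.com.br/indexPage/day/IBOV?language=pt-br"

driver = webdriver.Chrome()
driver.get(B3_URL)
time.sleep(5)  # give the in-page script time to render the composition table

# the composition table is rendered into the DOM, so it can be read from the page source
tables = pd.read_html(driver.page_source)
driver.quit()

ibov_df = tables[0]
ibov_df.to_csv("ibov_composition.csv", index=False)  # hypothetical output path
```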
<hr/>

### Method `get_changes`

No suitable data source that keeps track of ibovespa's historical stock composition has been found, except for this [repository](https://github.com/igor17400/IBOV-HCI), which provides such information only from the 1st quarter of 2003 to the 3rd quarter of 2021.

With that reference, the index's composition can be compared quarter by quarter and year by year in order to generate a file that keeps track of which stocks have been removed and which have been added in each quarter and year.

<hr/>

### Collector Data

```bash
# parse instruments, using in qlib/instruments.
python collector.py --index_name IBOV --qlib_dir ~/.qlib/qlib_data/br_data --method parse_instruments

# parse new companies
python collector.py --index_name IBOV --qlib_dir ~/.qlib/qlib_data/br_data --method save_new_companies
```
@@ -0,0 +1,287 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from functools import partial
import sys
from pathlib import Path
import importlib
import datetime

import fire
import pandas as pd
from tqdm import tqdm
from loguru import logger

CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))

from data_collector.index import IndexBase
from data_collector.utils import get_instruments

quarter_dict = {"1Q": "01-03", "2Q": "05-01", "3Q": "09-01"}


class IBOVIndex(IndexBase):

    ibov_index_composition = "https://raw.githubusercontent.com/igor17400/IBOV-HCI/main/historic_composition/{}.csv"
    years_4_month_periods = []

    def __init__(
        self,
        index_name: str,
        qlib_dir: [str, Path] = None,
        freq: str = "day",
        request_retry: int = 5,
        retry_sleep: int = 3,
    ):
        super(IBOVIndex, self).__init__(
            index_name=index_name, qlib_dir=qlib_dir, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep
        )

        self.today: datetime = datetime.date.today()
        self.current_4_month_period = self.get_current_4_month_period(self.today.month)
        self.year = str(self.today.year)
        self.years_4_month_periods = self.get_four_month_period()

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """
        The ibovespa index started on 2 January 1968 (wiki); however,
        no suitable data source that keeps track of ibovespa's historical
        stock composition has been found, except for the repo indicated
        in the README, which keeps track of such information starting from
        the first quarter of 2003.
        """
        return pd.Timestamp("2003-01-03")

    def get_current_4_month_period(self, current_month: int):
        """
        This function is used to calculate the current
        four-month period for the current month. For example,
        if the current month is August (8), its four-month period
        is 2Q.

        OBS: In English, Q is used to represent a *quarter*,
        which means a three-month period. However, in
        Portuguese we use Q to represent a four-month period.
        In other words,

        Jan, Feb, Mar, Apr: 1Q
        May, Jun, Jul, Aug: 2Q
        Sep, Oct, Nov, Dec: 3Q

        Parameters
        ----------
        current_month : int
            Current month (1 <= month <= 12)

        Returns
        -------
        current_4m_period: str
            Current four-month period (1Q, 2Q or 3Q)
        """
        if current_month < 5:
            return "1Q"
        if current_month < 9:
            return "2Q"
        if current_month <= 12:
            return "3Q"
        else:
            return -1

    def get_four_month_period(self):
        """
        The ibovespa index is updated every four months.
        Therefore, we represent each time period as 2003_1Q,
        which means the first four-month period of 2003 (Jan, Feb, Mar, Apr).
        """
        four_months_period = ["1Q", "2Q", "3Q"]
        init_year = 2003
        now = datetime.datetime.now()
        current_year = now.year
        current_month = now.month
        for year in [item for item in range(init_year, current_year)]:
            for el in four_months_period:
                self.years_4_month_periods.append(str(year) + "_" + el)
        # For the current year the logic must be a little different
        current_4_month_period = self.get_current_4_month_period(current_month)
        for i in range(int(current_4_month_period[0])):
            self.years_4_month_periods.append(str(current_year) + "_" + str(i + 1) + "Q")
        return self.years_4_month_periods

    def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
        """formatting the datetime in an instrument

        Parameters
        ----------
        inst_df: pd.DataFrame
            inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]

        Returns
        -------
        inst_df: pd.DataFrame

        """
        logger.info("Formatting Datetime")
        if self.freq != "day":
            inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(
                lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59)).strftime("%Y-%m-%d %H:%M:%S")
            )
        else:
            inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply(
                lambda x: (pd.Timestamp(x)).strftime("%Y-%m-%d")
            )

            inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(
                lambda x: (pd.Timestamp(x)).strftime("%Y-%m-%d")
            )
        return inst_df

    def format_quarter(self, cell: str):
        """
        Parameters
        ----------
        cell: str
            It must be in the format 2003_1Q --> years_4_month_periods

        Returns
        ----------
        date: str
            Returns the date in the format 2003-03-01
        """
        cell_split = cell.split("_")
        return cell_split[0] + "-" + quarter_dict[cell_split[1]]

    def get_changes(self):
        """
        Access the index historic composition and compare it quarter
        by quarter and year by year in order to generate a file that
        keeps track of which stocks have been removed and which have
        been added.

        The DataFrame used as reference will provide the index
        composition for each year and quarter:
        pd.DataFrame:
            symbol
            SH600000
            SH600001
            .
            .
            .

        Parameters
        ----------
        self: is used to represent the instance of the class.

        Returns
        ----------
        pd.DataFrame:
            symbol      date        type
            SH600000  2019-11-11    add
            SH600001  2020-11-10    remove
        dtypes:
            symbol: str
            date: pd.Timestamp
            type: str, value from ["add", "remove"]
        """
        logger.info("Getting companies changes in {} index ...".format(self.index_name))

        try:
            df_changes_list = []
            for i in tqdm(range(len(self.years_4_month_periods) - 1)):
                df = pd.read_csv(
                    self.ibov_index_composition.format(self.years_4_month_periods[i]), on_bad_lines="skip"
                )["symbol"]
                df_ = pd.read_csv(
                    self.ibov_index_composition.format(self.years_4_month_periods[i + 1]), on_bad_lines="skip"
                )["symbol"]

                ## Remove Dataframe
                remove_date = (
                    self.years_4_month_periods[i].split("_")[0]
                    + "-"
                    + quarter_dict[self.years_4_month_periods[i].split("_")[1]]
                )
                list_remove = list(df[~df.isin(df_)])
                df_removed = pd.DataFrame(
                    {
                        "date": len(list_remove) * [remove_date],
                        "type": len(list_remove) * ["remove"],
                        "symbol": list_remove,
                    }
                )

                ## Add Dataframe
                add_date = (
                    self.years_4_month_periods[i + 1].split("_")[0]
                    + "-"
                    + quarter_dict[self.years_4_month_periods[i + 1].split("_")[1]]
                )
                list_add = list(df_[~df_.isin(df)])
                df_added = pd.DataFrame(
                    {"date": len(list_add) * [add_date], "type": len(list_add) * ["add"], "symbol": list_add}
                )

                df_changes_list.append(pd.concat([df_added, df_removed], sort=False))
            df = pd.concat(df_changes_list).reset_index(drop=True)
            df["symbol"] = df["symbol"].astype(str) + ".SA"

            return df

        except Exception as E:
            logger.error("An error occurred while downloading the index composition - {}".format(E))

    def get_new_companies(self):
        """
        Get the latest index composition.
        The repo indicated in the README has implemented a script
        to get the latest index composition from the B3 website using
        selenium. Therefore, this method will download the file
        containing such composition.

        Parameters
        ----------
        self: is used to represent the instance of the class.

        Returns
        ----------
        pd.DataFrame:
            symbol     start_date    end_date
            RRRP3      2020-11-13    2022-03-02
            ALPA4      2008-01-02    2022-03-02
        dtypes:
            symbol: str
            start_date: pd.Timestamp
            end_date: pd.Timestamp
        """
        logger.info("Getting new companies in {} index ...".format(self.index_name))

        try:
            ## Get index composition

            df_index = pd.read_csv(
                self.ibov_index_composition.format(self.year + "_" + self.current_4_month_period), on_bad_lines="skip"
            )
            df_date_first_added = pd.read_csv(
                self.ibov_index_composition.format("date_first_added_" + self.year + "_" + self.current_4_month_period),
                on_bad_lines="skip",
            )
            df = df_index.merge(df_date_first_added, on="symbol")[["symbol", "Date First Added"]]
            df[self.START_DATE_FIELD] = df["Date First Added"].map(self.format_quarter)

            # end_date will be our current quarter + 1, since the IBOV index updates itself every quarter
            df[self.END_DATE_FIELD] = self.year + "-" + quarter_dict[self.current_4_month_period]
            df = df[["symbol", self.START_DATE_FIELD, self.END_DATE_FIELD]]
            df["symbol"] = df["symbol"].astype(str) + ".SA"

            return df

        except Exception as E:
            logger.error("An error occurred while getting new companies - {}".format(E))

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        if "Código" in df.columns:
            return df.loc[:, ["Código"]].copy()


if __name__ == "__main__":
    fire.Fire(partial(get_instruments, market_index="br_index"))
@@ -0,0 +1,34 @@
async-generator==1.10
attrs==21.4.0
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.12
cryptography==36.0.1
fire==0.4.0
h11==0.13.0
idna==3.3
loguru==0.6.0
lxml==4.8.0
multitasking==0.0.10
numpy==1.22.2
outcome==1.1.0
pandas==1.4.1
pycoingecko==2.2.0
pycparser==2.21
pyOpenSSL==22.0.0
PySocks==1.7.1
python-dateutil==2.8.2
pytz==2021.3
requests==2.27.1
requests-futures==1.0.0
six==1.16.0
sniffio==1.2.0
sortedcontainers==2.4.0
termcolor==1.1.0
tqdm==4.63.0
trio==0.20.0
trio-websocket==0.9.2
urllib3==1.26.8
wget==3.2
wsproto==1.1.0
yahooquery==2.2.15
@ -0,0 +1,22 @@
|
||||
# CSI300/CSI100/CSI500 History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, used in qlib/instruments.
|
||||
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
# index_name support: CSI300, CSI100, CSI500
|
||||
# help
|
||||
python collector.py --help
|
||||
```
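
After `parse_instruments` finishes, the generated instruments file (e.g. `csi300.txt`) can be read back through qlib. A minimal sketch, assuming qlib data has already been prepared under `~/.qlib/qlib_data/cn_data`:

```python
import qlib
from qlib.data import D

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data")
# load the instrument pool parsed into instruments/csi300.txt
instruments = D.instruments(market="csi300")
print(D.list_instruments(instruments=instruments, as_list=True)[:10])
```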
|
||||
|
@ -0,0 +1,468 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import abc
|
||||
import sys
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
from typing import List, Iterable
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
|
||||
from data_collector.utils import get_instruments
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "https://csi-web-dev.oss-cn-shanghai-finance-1-pub.aliyuncs.com/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
|
||||
|
||||
INDEX_CHANGES_URL = "https://www.csindex.com.cn/csindex-home/search/search-content?lang=cn&searchInput=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC&pageNum={page_num}&pageSize={page_size}&sortField=date&dateRange=all&contentType=announcement"
|
||||
|
||||
REQ_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48"
|
||||
}
|
||||
|
||||
|
||||
@deco_retry
|
||||
def retry_request(url: str, method: str = "get", exclude_status: List = None):
|
||||
if exclude_status is None:
|
||||
exclude_status = []
|
||||
method_func = getattr(requests, method)
|
||||
_resp = method_func(url, headers=REQ_HEADERS)
|
||||
_status = _resp.status_code
|
||||
if _status not in exclude_status and _status != 200:
|
||||
raise ValueError(f"response status: {_status}, url={url}")
|
||||
return _resp
|
||||
|
||||
|
||||
class CSIIndex(IndexBase):
|
||||
@property
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
_calendar = getattr(self, "_calendar_list", None)
|
||||
if not _calendar:
|
||||
_calendar = get_calendar_list(bench_code=self.index_name.upper())
|
||||
setattr(self, "_calendar_list", _calendar)
|
||||
return _calendar
|
||||
|
||||
@property
|
||||
def new_companies_url(self) -> str:
|
||||
return NEW_COMPANIES_URL.format(index_code=self.index_code)
|
||||
|
||||
@property
|
||||
def changes_url(self) -> str:
|
||||
return INDEX_CHANGES_URL
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def index_code(self) -> str:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index code
|
||||
"""
|
||||
raise NotImplementedError("rewrite index_code")
|
||||
|
||||
@property
|
||||
def html_table_index(self) -> int:
|
||||
"""Which table of changes in html
|
||||
|
||||
CSI300: 0
|
||||
CSI100: 1
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError("rewrite html_table_index")
|
||||
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""formatting the datetime in an instrument
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst_df: pd.DataFrame
|
||||
inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
if self.freq != "day":
|
||||
inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply(
|
||||
lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=9, minutes=30)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(
|
||||
lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=15, minutes=0)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
return inst_df
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
logger.info("get companies changes......")
|
||||
res = []
|
||||
for _url in self._get_change_notices_url():
|
||||
_df = self._read_change_from_url(_url)
|
||||
if not _df.empty:
|
||||
res.append(_df)
|
||||
logger.info("get companies changes finish")
|
||||
return pd.concat(res, sort=False)
|
||||
|
||||
@staticmethod
|
||||
def normalize_symbol(symbol: str) -> str:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
|
||||
Returns
|
||||
-------
|
||||
symbol
|
||||
"""
|
||||
symbol = f"{int(symbol):06}"
|
||||
return f"SH{symbol}" if symbol.startswith("60") or symbol.startswith("688") else f"SZ{symbol}"
|
||||
|
||||
def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:
|
||||
content = retry_request(excel_url, exclude_status=[404]).content
|
||||
_io = BytesIO(content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
fp.write(content)
|
||||
tmp = []
|
||||
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df["type"] = _type
|
||||
_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_df)
|
||||
df = pd.concat(tmp)
|
||||
return df
|
||||
|
||||
def _parse_table(self, content: str, add_date: pd.DataFrame, remove_date: pd.DataFrame) -> pd.DataFrame:
|
||||
df = pd.DataFrame()
|
||||
_tmp_count = 0
|
||||
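# scan every HTML table in the announcement body; skip tables that do not have 4 columns
# or whose first cell is null, then parse only the table at position html_table_index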
for _df in pd.read_html(content):
|
||||
if _df.shape[-1] != 4 or _df.isnull().loc(0)[0][0]:
|
||||
continue
|
||||
_tmp_count += 1
|
||||
if self.html_table_index + 1 > _tmp_count:
|
||||
continue
|
||||
tmp = []
|
||||
for _s, _type, _date in [
|
||||
(_df.iloc[2:, 0], self.REMOVE, remove_date),
|
||||
(_df.iloc[2:, 2], self.ADD, add_date),
|
||||
]:
|
||||
_tmp_df = pd.DataFrame()
|
||||
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
|
||||
_tmp_df["type"] = _type
|
||||
_tmp_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
df.to_csv(
|
||||
str(
|
||||
self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
|
||||
).resolve()
|
||||
)
|
||||
)
|
||||
break
|
||||
return df
|
||||
|
||||
def _read_change_from_url(self, url: str) -> pd.DataFrame:
|
||||
"""read change from url
|
||||
The parameter url is from the _get_change_notices_url method.
|
||||
Determine the stock add_date/remove_date based on the title.
|
||||
The response contains three cases:
|
||||
1. Only excel_url (extract data from excel_url)
|
||||
2. Both the excel_url and the body text (try to extract data from excel_url first, and then try to extract data from body text)
|
||||
3. Only body text (extract data from body text)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
change url
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
resp = retry_request(url).json()["data"]
|
||||
title = resp["title"]
|
||||
if not title.startswith("关于"):
|
||||
return pd.DataFrame()
|
||||
if "沪深300" not in title:
|
||||
return pd.DataFrame()
|
||||
|
||||
logger.info(f"load index data from https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}")
|
||||
_text = resp["content"]
|
||||
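# derive add_date from the announcement text: prefer the first full 年/月/日 date, otherwise fall back to the 年/月 pattern
# aligned to the trading calendar (shifted to the next trading day for after-close 盘后/市后 announcements); remove_date is the trading day before add_date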
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
|
||||
if len(date_list) >= 2:
|
||||
add_date = pd.Timestamp("-".join(date_list[0]))
|
||||
else:
|
||||
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
|
||||
add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
|
||||
if "盘后" in _text or "市后" in _text:
|
||||
add_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=1)
|
||||
remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)
|
||||
|
||||
excel_url = None
|
||||
if resp.get("enclosureList", []):
|
||||
excel_url = resp["enclosureList"][0]["fileUrl"]
|
||||
else:
|
||||
excel_url_list = re.findall('.*href="(.*?xls.*?)".*', _text)
|
||||
if excel_url_list:
|
||||
excel_url = excel_url_list[0]
|
||||
if not excel_url.startswith("http"):
|
||||
excel_url = excel_url if excel_url.startswith("/") else "/" + excel_url
|
||||
excel_url = f"http://www.csindex.com.cn{excel_url}"
|
||||
if excel_url:
|
||||
try:
|
||||
logger.info(f"get {add_date} changes from the excel, title={title}, excel_url={excel_url}")
|
||||
df = self._parse_excel(excel_url, add_date, remove_date)
|
||||
except ValueError:
|
||||
logger.info(
|
||||
f"get {add_date} changes from the web page, title={title}, url=https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}"
|
||||
)
|
||||
df = self._parse_table(_text, add_date, remove_date)
|
||||
else:
|
||||
logger.info(
|
||||
f"get {add_date} changes from the web page, title={title}, url=https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}"
|
||||
)
|
||||
df = self._parse_table(_text, add_date, remove_date)
|
||||
return df
|
||||
|
||||
def _get_change_notices_url(self) -> Iterable[str]:
|
||||
"""get change notices url
|
||||
|
||||
Returns
|
||||
-------
|
||||
[url1, url2]
|
||||
"""
|
||||
page_num = 1
|
||||
page_size = 5
|
||||
data = retry_request(self.changes_url.format(page_size=page_size, page_num=page_num)).json()
|
||||
data = retry_request(self.changes_url.format(page_size=data["total"], page_num=page_num)).json()
|
||||
for item in data["data"]:
|
||||
yield f"https://www.csindex.com.cn/csindex-home/announcement/queryAnnouncementById?id={item['id']}"
|
||||
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
start_date: pd.Timestamp
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
logger.info("get new companies......")
|
||||
context = retry_request(self.new_companies_url).content
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
fp.write(context)
|
||||
_io = BytesIO(context)
|
||||
df = pd.read_excel(_io)
|
||||
df = df.iloc[:, [0, 4]]
|
||||
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
|
||||
df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
|
||||
df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
logger.info("end of get new companies.")
|
||||
return df
|
||||
|
||||
|
||||
class CSI300Index(CSIIndex):
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000300"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2005-01-01")
|
||||
|
||||
@property
|
||||
def html_table_index(self) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
class CSI100Index(CSIIndex):
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000903"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2006-05-29")
|
||||
|
||||
@property
|
||||
def html_table_index(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
class CSI500Index(CSIIndex):
|
||||
@property
|
||||
def index_code(self) -> str:
|
||||
return "000905"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2007-01-15")
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Return
|
||||
--------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
return self.get_changes_with_history_companies(self.get_history_companies())
|
||||
|
||||
def get_history_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
bs.login()
|
||||
today = pd.Timestamp.now()
|
||||
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
|
||||
ret_list = []
|
||||
for date in tqdm(date_range, desc="Download CSI500"):
|
||||
result = self.get_data_from_baostock(date)
|
||||
ret_list.append(result[["date", "symbol"]])
|
||||
bs.logout()
|
||||
return pd.concat(ret_list, sort=False)
|
||||
|
||||
def get_data_from_baostock(self, date) -> pd.DataFrame:
|
||||
"""
|
||||
Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
|
||||
Avoid a large number of parallel data acquisition,
|
||||
such as 1000 times of concurrent data acquisition, because IP will be blocked
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
date symbol code_name
|
||||
SH600039 2007-01-15 四川路桥
|
||||
SH600051 2020-01-15 宁波联合
|
||||
dtypes:
|
||||
date: pd.Timestamp
|
||||
symbol: str
|
||||
code_name: str
|
||||
"""
|
||||
col = ["date", "symbol", "code_name"]
|
||||
rs = bs.query_zz500_stocks(date=str(date))
|
||||
zz500_stocks = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
zz500_stocks.append(rs.get_row_data())
|
||||
result = pd.DataFrame(zz500_stocks, columns=col)
|
||||
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
|
||||
return result
|
||||
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
start_date: pd.Timestamp
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
logger.info("get new companies......")
|
||||
today = datetime.date.today()
|
||||
bs.login()
|
||||
result = self.get_data_from_baostock(today)
|
||||
bs.logout()
|
||||
df = result[["date", "symbol"]]
|
||||
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
|
||||
df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
logger.info("end of get new companies.")
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(get_instruments)
|
@ -0,0 +1,8 @@
|
||||
baostock
|
||||
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
tqdm
|
@ -0,0 +1,23 @@
|
||||
# Use 1d data to fill in the missing symbols relative to 1min
|
||||
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## fill 1min data
|
||||
|
||||
```bash
|
||||
python fill_1min_using_1d.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- data_1min_dir: csv data
|
||||
- qlib_data_1d_dir: qlib data directory
|
||||
- max_workers: `ThreadPoolExecutor(max_workers=max_workers)`, by default *16*
|
||||
- date_field_name: date field name, by default *date*
|
||||
- symbol_field_name: symbol field name, by default *symbol*
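
The optional parameters can also be passed explicitly; a minimal sketch using the documented defaults:

```bash
python fill_1min_using_1d.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --max_workers 16 --date_field_name date --symbol_field_name symbol
```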
|
||||
|
@ -0,0 +1,100 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from qlib.data import D
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent.parent))
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date"):
|
||||
csv_files = list(data_1min_dir.glob("*.csv"))
|
||||
min_date = None
|
||||
max_date = None
|
||||
with tqdm(total=len(csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
for _file, _result in zip(csv_files, executor.map(pd.read_csv, csv_files)):
|
||||
if not _result.empty:
|
||||
_dates = pd.to_datetime(_result[date_field_name])
|
||||
|
||||
_tmp_min = _dates.min()
|
||||
min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min
|
||||
_tmp_max = _dates.max()
|
||||
max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max
|
||||
p_bar.update()
|
||||
return min_date, max_date
|
||||
|
||||
|
||||
def get_symbols(data_1min_dir: Path):
|
||||
return list(map(lambda x: x.name[:-4].upper(), data_1min_dir.glob("*.csv")))
|
||||
|
||||
|
||||
def fill_1min_using_1d(
|
||||
data_1min_dir: [str, Path],
|
||||
qlib_data_1d_dir: [str, Path],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""Use 1d data to fill in the missing symbols relative to 1min
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_1min_dir: str
|
||||
1min data dir
|
||||
qlib_data_1d_dir: str
|
||||
1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format
|
||||
max_workers: int
|
||||
ThreadPoolExecutor(max_workers), by default 16
|
||||
date_field_name: str
|
||||
date field name, by default date
|
||||
symbol_field_name: str
|
||||
symbol field name, by default symbol
|
||||
|
||||
"""
|
||||
data_1min_dir = Path(data_1min_dir).expanduser().resolve()
|
||||
qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()
|
||||
|
||||
min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name)
|
||||
symbols_1min = get_symbols(data_1min_dir)
|
||||
|
||||
qlib.init(provider_uri=str(qlib_data_1d_dir))
|
||||
data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day")
|
||||
|
||||
miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min)
|
||||
if not miss_symbols:
|
||||
logger.warning("1min data already covers all 1d symbols, no padding required")
|
||||
return
|
||||
|
||||
logger.info(f"miss_symbols {len(miss_symbols)}: {miss_symbols}")
|
||||
tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0])
|
||||
columns = tmp_df.columns
|
||||
_si = tmp_df[symbol_field_name].first_valid_index()
|
||||
is_lower = tmp_df.loc[_si][symbol_field_name].islower()
|
||||
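# for each missing symbol, build an empty 1min frame on the minute calendar derived from its daily index,
# tag it with the symbol, and write it out as <symbol>.csv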
for symbol in tqdm(miss_symbols):
|
||||
if is_lower:
|
||||
symbol = symbol.lower()
|
||||
index_1d = data_1d.loc(axis=0)[symbol.upper()].index
|
||||
index_1min = generate_minutes_calendar_from_daily(index_1d)
|
||||
index_1min.name = date_field_name
|
||||
_df = pd.DataFrame(columns=columns, index=index_1min)
|
||||
if date_field_name in _df.columns:
|
||||
del _df[date_field_name]
|
||||
_df.reset_index(inplace=True)
|
||||
_df[symbol_field_name] = symbol
|
||||
_df["paused_num"] = 0
|
||||
_df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(fill_1min_using_1d)
|
@ -0,0 +1,5 @@
|
||||
fire
|
||||
pandas
|
||||
loguru
|
||||
tqdm
|
||||
pyqlib
|
@ -0,0 +1,24 @@
|
||||
# Get future trading days
|
||||
|
||||
> `D.calendar(future=True)` will be used
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, used in qlib/instruments.
|
||||
python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day
|
||||
```
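
Once the future calendar has been written, it can be loaded back through qlib. A minimal sketch, assuming the same data directory is used for `qlib.init`:

```python
import qlib
from qlib.data import D

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data")
# future=True reads calendars/day_future.txt instead of calendars/day.txt
future_calendar = D.calendar(freq="day", future=True)
print(future_calendar[-5:])
```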
|
||||
|
||||
## Parameters
|
||||
|
||||
- qlib_dir: qlib data directory
|
||||
- freq: value from [`day`, `1min`], default `day`
|
||||
|
||||
|
||||
|
@ -0,0 +1,88 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
# get data from baostock
|
||||
import baostock as bs
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent.parent))
|
||||
|
||||
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame:
|
||||
calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt")
|
||||
if not calendar_path.exists():
|
||||
return pd.DataFrame()
|
||||
return pd.read_csv(calendar_path, header=None)
|
||||
|
||||
|
||||
def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"):
|
||||
calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt"))
|
||||
|
||||
np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8")
|
||||
logger.info(f"write future calendars success: {calendar_path}")
|
||||
|
||||
|
||||
def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]:
|
||||
print(freq)
|
||||
if freq == "day":
|
||||
return date_list
|
||||
elif freq == "1min":
|
||||
date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist()
|
||||
return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list))
|
||||
else:
|
||||
raise ValueError(f"Unsupported freq: {freq}")
|
||||
|
||||
|
||||
def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
|
||||
"""get future calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str or Path
|
||||
qlib data directory
|
||||
freq: str
|
||||
value from ["day", "1min"], by default day
|
||||
"""
|
||||
qlib_dir = Path(qlib_dir).expanduser().resolve()
|
||||
if not qlib_dir.exists():
|
||||
raise FileNotFoundError(str(qlib_dir))
|
||||
|
||||
lg = bs.login()
|
||||
if lg.error_code != "0":
|
||||
logger.error(f"login error: {lg.error_msg}")
|
||||
return
|
||||
# read daily calendar
|
||||
daily_calendar = read_calendar_from_qlib(qlib_dir)
|
||||
end_year = pd.Timestamp.now().year
|
||||
if daily_calendar.empty:
|
||||
start_year = pd.Timestamp.now().year
|
||||
else:
|
||||
start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year
|
||||
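# query trading dates from baostock, from the last year present in the local daily calendar
# (or the current year if it is empty) through the end of the current year, keeping only actual trading days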
rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31")
|
||||
data_list = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
_row_data = rs.get_row_data()
|
||||
if int(_row_data[1]) == 1:
|
||||
data_list.append(_row_data[0])
|
||||
data_list = sorted(data_list)
|
||||
date_list = generate_qlib_calendar(data_list, freq=freq)
|
||||
date_list = sorted(set(daily_calendar.loc[:, 0].values.tolist() + date_list))
|
||||
write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
|
||||
bs.logout()
|
||||
logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(future_calendar_collector)
|
@ -0,0 +1,5 @@
|
||||
baostock
|
||||
fire
|
||||
numpy
|
||||
pandas
|
||||
loguru
|
@ -0,0 +1,54 @@
|
||||
# Collect Crypto Data
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [Coingecko](https://www.coingecko.com/en/api) and the data might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage of the dataset
|
||||
> *The Crypto dataset only supports data retrieval; it does not support backtesting, due to the lack of OHLC data.*
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
### Crypto Data
|
||||
|
||||
#### 1d from Coingecko
|
||||
|
||||
```bash
|
||||
|
||||
# download from https://api.coingecko.com/api/v3/
|
||||
python collector.py download_data --source_dir ~/.qlib/crypto_data/source/1d --start 2015-01-01 --end 2021-11-30 --delay 1 --interval 1d
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/crypto_data/source/1d --normalize_dir ~/.qlib/crypto_data/source/1d_nor --interval 1d --date_field_name date
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/crypto_data/source/1d_nor --qlib_dir ~/.qlib/qlib_data/crypto_data --freq day --date_field_name date --include_fields prices,total_volumes,market_caps
|
||||
|
||||
```
|
||||
|
||||
### Using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/crypto_data")
|
||||
df = D.features(D.instruments(market="all"), ["$prices", "$total_volumes","$market_caps"], freq="day")
|
||||
```
|
||||
|
||||
|
||||
### Help
|
||||
```bash
|
||||
python collector.py collector_data --help
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- interval: 1d
|
||||
- delay: 1
|
@ -0,0 +1,311 @@
|
||||
import abc
|
||||
import sys
|
||||
import datetime
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import deco_retry
|
||||
|
||||
from pycoingecko import CoinGeckoAPI
|
||||
from time import mktime
|
||||
from datetime import datetime as dt
|
||||
import time
|
||||
|
||||
|
||||
_CG_CRYPTO_SYMBOLS = None
|
||||
|
||||
|
||||
def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get crypto symbols in coingecko
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of crypto symbols available on CoinGecko
|
||||
"""
|
||||
global _CG_CRYPTO_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_coingecko():
|
||||
try:
|
||||
cg = CoinGeckoAPI()
|
||||
resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
|
||||
except:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
_symbols = resp["id"].to_list()
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
return _symbols
|
||||
|
||||
if _CG_CRYPTO_SYMBOLS is None:
|
||||
_all_symbols = _get_coingecko()
|
||||
|
||||
_CG_CRYPTO_SYMBOLS = sorted(set(_all_symbols))
|
||||
|
||||
return _CG_CRYPTO_SYMBOLS
|
||||
|
||||
|
||||
class CryptoCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=1,
|
||||
max_collector_count=2,
|
||||
delay=1,  # delay needs to be 1
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
crypto save dir
|
||||
max_workers: int
|
||||
workers, default 1
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 1
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
super(CryptoCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
self.init_datetime()
|
||||
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
try:
|
||||
cg = CoinGeckoAPI()
|
||||
data = cg.get_coin_market_chart_by_id(id=symbol, vs_currency="usd", days="max")
|
||||
_resp = pd.DataFrame(columns=["date"] + list(data.keys()))
|
||||
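# CoinGecko returns [timestamp_ms, value] pairs; convert the millisecond epoch timestamps to local dates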
_resp["date"] = [dt.fromtimestamp(mktime(time.localtime(x[0] / 1000))) for x in data["prices"]]
|
||||
for key in data.keys():
|
||||
_resp[key] = [x[1] for x in data[key]]
|
||||
_resp["date"] = pd.to_datetime(_resp["date"])
|
||||
_resp["date"] = [x.date() for x in _resp["date"]]
|
||||
_resp = _resp[(_resp["date"] < pd.to_datetime(end).date()) & (_resp["date"] > pd.to_datetime(start).date())]
|
||||
if _resp.shape[0] != 0:
|
||||
_resp = _resp.reset_index()
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
raise ValueError(f"cannot support {interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class CryptoCollector1d(CryptoCollector, ABC):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get coingecko crypto symbols......")
|
||||
symbols = get_cg_crypto_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "Asia/Shanghai"
|
||||
|
||||
|
||||
class CryptoNormalize(BaseNormalize):
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
@staticmethod
|
||||
def normalize_crypto(
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
pd.DataFrame(index=calendar_list)
|
||||
.loc[
|
||||
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
|
||||
+ pd.Timedelta(hours=23, minutes=59)
|
||||
]
|
||||
.index
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.index.names = [date_field_name]
|
||||
return df.reset_index()
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = self.normalize_crypto(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
||||
return df
|
||||
|
||||
|
||||
class CryptoNormalize1d(CryptoNormalize):
|
||||
def _get_calendar_list(self):
|
||||
return None
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalized data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 1
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
"""
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"CryptoCollector{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"CryptoNormalize{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d, currently only supports 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: int # if this param useful?
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/crypto_data/source/1d --start 2015-01-01 --end 2021-11-30 --delay 1 --interval 1d
|
||||
"""
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/crypto_data/source/1d --normalize_dir ~/.qlib/crypto_data/source/1d_nor --interval 1d --date_field_name date
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
@ -0,0 +1,8 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
pycoingecko
|
@ -0,0 +1,51 @@
|
||||
# Collect Fund Data
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
### CN Data
|
||||
|
||||
#### 1d from East Money
|
||||
|
||||
```bash
|
||||
|
||||
# download from eastmoney.com
|
||||
python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_data --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_data --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
|
||||
|
||||
```
|
||||
|
||||
### Using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data")
|
||||
df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day")
|
||||
```
|
||||
|
||||
|
||||
### Help
|
||||
```bash
|
||||
python collector.py collector_data --help
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- interval: 1d
|
||||
- region: CN
|
@ -0,0 +1,304 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import datetime
|
||||
import json
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.constant import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_en_fund_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
|
||||
|
||||
|
||||
class FundCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
fund save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
super(FundCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
self.init_datetime()
|
||||
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
try:
|
||||
# TODO: numberOfHistoricalDaysToCrawl should be bigger enough
|
||||
url = INDEX_BENCH_URL.format(
|
||||
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
|
||||
)
|
||||
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
data = json.loads(resp.text.split("(")[-1].split(")")[0])
|
||||
|
||||
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
|
||||
SYType = data["Data"]["SYType"]
|
||||
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
|
||||
raise Exception("The fund contains 每*份收益")
|
||||
|
||||
# TODO: should we sort the value by datetime?
|
||||
_resp = pd.DataFrame(data["Data"]["LSJZList"])
|
||||
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
raise ValueError(f"cannot support {interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class FundCollectorCN(FundCollector, ABC):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get cn fund symbols......")
|
||||
symbols = get_en_fund_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "Asia/Shanghai"
|
||||
|
||||
|
||||
class FundCollectorCN1d(FundCollectorCN):
|
||||
pass
|
||||
|
||||
|
||||
class FundNormalize(BaseNormalize):
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
@staticmethod
|
||||
def normalize_fund(
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
pd.DataFrame(index=calendar_list)
|
||||
.loc[
|
||||
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
|
||||
+ pd.Timedelta(hours=23, minutes=59)
|
||||
]
|
||||
.index
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.index.names = [date_field_name]
|
||||
return df.reset_index()
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
||||
return df
|
||||
|
||||
|
||||
class FundNormalize1d(FundNormalize):
|
||||
pass
|
||||
|
||||
|
||||
class FundNormalizeCN:
|
||||
def _get_calendar_list(self):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalized data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
region, value from ["CN"], default "CN"
|
||||
"""
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"FundCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"FundNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: int # if this param useful?
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_data --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
"""
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_data --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
@ -0,0 +1,10 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
|
||||
yahooquery
|
||||
|
@ -0,0 +1,121 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Union, Iterable, List
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# pip install baostock
|
||||
import baostock as bs
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CollectorFutureCalendar:
|
||||
calendar_format = "%Y-%m-%d"
|
||||
|
||||
def __init__(self, qlib_dir: Union[str, Path], start_date: str = None, end_date: str = None):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir:
|
||||
qlib data directory
|
||||
start_date
|
||||
start date
|
||||
end_date
|
||||
end date
|
||||
"""
|
||||
self.qlib_dir = Path(qlib_dir).expanduser().absolute()
|
||||
self.calendar_path = self.qlib_dir.joinpath("calendars/day.txt")
|
||||
self.future_path = self.qlib_dir.joinpath("calendars/day_future.txt")
|
||||
self._calendar_list = self.calendar_list
|
||||
_latest_date = self._calendar_list[-1]
|
||||
self.start_date = _latest_date if start_date is None else pd.Timestamp(start_date)
|
||||
self.end_date = _latest_date + pd.Timedelta(days=365 * 2) if end_date is None else pd.Timestamp(end_date)
|
||||
|
||||
@property
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
# load old calendar
|
||||
if not self.calendar_path.exists():
|
||||
raise ValueError(f"calendar does not exist: {self.calendar_path}")
|
||||
calendar_df = pd.read_csv(self.calendar_path, header=None)
|
||||
calendar_df.columns = ["date"]
|
||||
calendar_df["date"] = pd.to_datetime(calendar_df["date"])
|
||||
return calendar_df["date"].to_list()
|
||||
|
||||
def _format_datetime(self, datetime_d: [str, pd.Timestamp]):
|
||||
datetime_d = pd.Timestamp(datetime_d)
|
||||
return datetime_d.strftime(self.calendar_format)
|
||||
|
||||
def write_calendar(self, calendar: Iterable):
|
||||
calendars_list = list(map(lambda x: self._format_datetime(x), sorted(set(self.calendar_list + calendar))))
|
||||
np.savetxt(self.future_path, calendars_list, fmt="%s", encoding="utf-8")
|
||||
|
||||
@abc.abstractmethod
|
||||
def collector(self) -> Iterable[pd.Timestamp]:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `collector` method")
|
||||
|
||||
|
||||
class CollectorFutureCalendarCN(CollectorFutureCalendar):
|
||||
def collector(self) -> Iterable[pd.Timestamp]:
|
||||
lg = bs.login()
|
||||
if lg.error_code != "0":
|
||||
raise ValueError(f"login respond error_msg: {lg.error_msg}")
|
||||
rs = bs.query_trade_dates(
|
||||
start_date=self._format_datetime(self.start_date), end_date=self._format_datetime(self.end_date)
|
||||
)
|
||||
if rs.error_code != "0":
|
||||
raise ValueError(f"query_trade_dates respond error_msg: {rs.error_msg}")
|
||||
data_list = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
data_list.append(rs.get_row_data())
|
||||
calendar = pd.DataFrame(data_list, columns=rs.fields)
|
||||
calendar["is_trading_day"] = calendar["is_trading_day"].astype(int)
|
||||
return pd.to_datetime(calendar[calendar["is_trading_day"] == 1]["calendar_date"]).to_list()
|
||||
|
||||
|
||||
class CollectorFutureCalendarUS(CollectorFutureCalendar):
|
||||
def collector(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: US future calendar
|
||||
raise ValueError("US calendar is not supported")
|
||||
|
||||
|
||||
def run(qlib_dir: Union[str, Path], region: str = "cn", start_date: str = None, end_date: str = None):
|
||||
"""Collect future calendar(day)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir:
|
||||
qlib data directory
|
||||
region:
|
||||
cn/CN or us/US
|
||||
start_date
|
||||
start date
|
||||
end_date
|
||||
end date
|
||||
|
||||
Examples
|
||||
-------
|
||||
# get cn future calendar
|
||||
$ python future_calendar_collector.py --qlib_dir <user data dir> --region cn
|
||||
"""
|
||||
logger.info(f"collector future calendar: region={region}")
|
||||
_cur_module = importlib.import_module("future_calendar_collector")
|
||||
_class = getattr(_cur_module, f"CollectorFutureCalendar{region.upper()}")
|
||||
collector = _class(qlib_dir=qlib_dir, start_date=start_date, end_date=end_date)
|
||||
collector.write_calendar(collector.collector())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(run)
|
@ -0,0 +1,238 @@
|
||||
import sys
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent))
|
||||
|
||||
|
||||
from data_collector.utils import get_trading_date_by_shift
|
||||
|
||||
|
||||
class IndexBase:
|
||||
DEFAULT_END_DATE = pd.Timestamp("2099-12-31")
|
||||
SYMBOL_FIELD_NAME = "symbol"
|
||||
DATE_FIELD_NAME = "date"
|
||||
START_DATE_FIELD = "start_date"
|
||||
END_DATE_FIELD = "end_date"
|
||||
CHANGE_TYPE_FIELD = "type"
|
||||
INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD]
|
||||
REMOVE = "remove"
|
||||
ADD = "add"
|
||||
INST_PREFIX = ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
qlib_dir: [str, Path] = None,
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index_name: str
|
||||
index name
|
||||
qlib_dir: str
|
||||
qlib directory, by default Path(__file__).resolve().parent.joinpath("qlib_data")
|
||||
freq: str
|
||||
freq, value from ["day", "1min"]
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
"""
|
||||
self.index_name = index_name
|
||||
if qlib_dir is None:
|
||||
qlib_dir = Path(__file__).resolve().parent.joinpath("qlib_data")
|
||||
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
|
||||
self.instruments_dir.mkdir(exist_ok=True, parents=True)
|
||||
self.cache_dir = Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve()
|
||||
self.cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
self._request_retry = request_retry
|
||||
self._retry_sleep = retry_sleep
|
||||
self.freq = freq
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
raise NotImplementedError("rewrite calendar_list")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
start_date: pd.Timestamp
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_new_companies")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_changes")
|
||||
|
||||
@abc.abstractmethod
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""formatting the datetime in an instrument
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst_df: pd.DataFrame
|
||||
inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
raise NotImplementedError("rewrite format_datetime")
|
||||
|
||||
def save_new_companies(self):
|
||||
"""save new companies
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
df = self.get_new_companies()
|
||||
if df is None or df.empty:
|
||||
raise ValueError(f"get new companies error: {self.index_name}")
|
||||
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
|
||||
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
|
||||
def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame:
|
||||
"""get changes with history companies
|
||||
|
||||
Parameters
|
||||
----------
|
||||
history_companies : pd.DataFrame
|
||||
symbol date
|
||||
SH600000 2020-11-11
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
|
||||
Return
|
||||
--------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
|
||||
"""
|
||||
logger.info("parse changes from history companies......")
|
||||
last_code = []
|
||||
result_df_list = []
|
||||
_columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD]
|
||||
for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)):
|
||||
_current_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][
|
||||
self.SYMBOL_FIELD_NAME
|
||||
].tolist()
|
||||
if last_code:
|
||||
add_code = list(set(last_code) - set(_current_code))
|
||||
removed_code = list(set(_current_code) - set(last_code))
|
||||
for _code in add_code:
|
||||
result_df_list.append(
|
||||
pd.DataFrame(
|
||||
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]],
|
||||
columns=_columns,
|
||||
)
|
||||
)
|
||||
for _code in removed_code:
|
||||
result_df_list.append(
|
||||
pd.DataFrame(
|
||||
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]],
|
||||
columns=_columns,
|
||||
)
|
||||
)
|
||||
last_code = _current_code
|
||||
df = pd.concat(result_df_list)
|
||||
logger.info("end of parse changes from history companies.")
|
||||
return df
|
||||
|
||||
def parse_instruments(self):
|
||||
"""parse instruments, eg: csi300.txt
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
logger.info(f"start parse {self.index_name.lower()} companies.....")
|
||||
instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
changes_df = self.get_changes()
|
||||
new_df = self.get_new_companies()
|
||||
if new_df is None or new_df.empty:
|
||||
raise ValueError(f"get new companies error: {self.index_name}")
|
||||
new_df = new_df.copy()
|
||||
logger.info("parse history companies by changes......")
|
||||
for _row in tqdm(changes_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
|
||||
if _row.type == self.ADD:
|
||||
min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min()
|
||||
new_df.loc[
|
||||
(new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol),
|
||||
self.START_DATE_FIELD,
|
||||
] = _row.date
|
||||
else:
|
||||
_tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)
|
||||
new_df = pd.concat([new_df, _tmp_df], sort=False)
|
||||
|
||||
inst_df = new_df.loc[:, instruments_columns]
|
||||
_inst_prefix = self.INST_PREFIX.strip()
|
||||
if _inst_prefix:
|
||||
inst_df["save_inst"] = inst_df[self.SYMBOL_FIELD_NAME].apply(lambda x: f"{_inst_prefix}{x}")
|
||||
inst_df = self.format_datetime(inst_df)
|
||||
inst_df.to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
logger.info(f"parse {self.index_name.lower()} companies finished.")
|
@ -0,0 +1,40 @@
|
||||
# Collect Point-in-Time Data
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [baostock](http://baostock.com) and the data might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
### Download Quarterly CN Data
|
||||
|
||||
```bash
|
||||
cd qlib/scripts/data_collector/pit/
|
||||
# download from baostock.com
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly
|
||||
```
|
||||
|
||||
Downloading data for all stocks is very time-consuming. If you just want to run a quick test on a few stocks, you can run the command below:
|
||||
```bash
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex "^(600519|000725).*"
|
||||
```
|
||||
|
||||
|
||||
### Normalize Data
|
||||
```bash
|
||||
python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized
|
||||
```
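Before dumping, it can be worth eyeballing one normalized file. A small sketch, assuming the paths above and a hypothetical `sh600519.csv`; the expected columns (`date`, `period`, `value`, `field`) come from the collector code in this directory:

```python
from pathlib import Path

import pandas as pd

# hypothetical file; adjust to your --normalize_dir and to a symbol you actually downloaded
csv_path = Path("~/.qlib/stock_data/source/pit_normalized/sh600519.csv").expanduser()
df = pd.read_csv(csv_path)

print(df.columns.tolist())   # expected to include: date, period, value, field
print(df["field"].unique())  # e.g. ['roeWa', 'YOYNI'] for this baostock-based collector
print(df.head())
```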
|
||||
|
||||
|
||||
|
||||
### Dump Data into PIT Format
|
||||
|
||||
```bash
|
||||
cd qlib/scripts
|
||||
python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly
|
||||
```
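After the dump finishes, the point-in-time fields should be queryable through qlib's `P()` operator; a minimal sketch following the qlib PIT documentation (field names such as `$$roewa_q` depend on what was collected and dumped):

```python
import qlib
from qlib.data import D

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region="cn")
# roeWa collected above is exposed as the quarterly PIT field $$roewa_q
df = D.features(["sh600519"], ["P($$roewa_q)"], start_time="2019-01-01", end_time="2019-07-19", freq="day")
print(df.head())
```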
|
@ -0,0 +1,262 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Iterable, Optional, Union
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from loguru import logger
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(BASE_DIR.parent.parent))
|
||||
|
||||
from data_collector.base import BaseCollector, BaseRun, BaseNormalize
|
||||
from data_collector.utils import get_hs_stock_symbols, get_calendar_list
|
||||
|
||||
|
||||
class PitCollector(BaseCollector):
|
||||
DEFAULT_START_DATETIME_QUARTERLY = pd.Timestamp("2000-01-01")
|
||||
DEFAULT_START_DATETIME_ANNUAL = pd.Timestamp("2000-01-01")
|
||||
DEFAULT_END_DATETIME_QUARTERLY = pd.Timestamp(datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_END_DATETIME_ANNUAL = pd.Timestamp(datetime.now() + pd.Timedelta(days=1))
|
||||
|
||||
INTERVAL_QUARTERLY = "quarterly"
|
||||
INTERVAL_ANNUAL = "annual"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: Union[str, Path],
|
||||
start: Optional[str] = None,
|
||||
end: Optional[str] = None,
|
||||
interval: str = "quarterly",
|
||||
max_workers: int = 1,
|
||||
max_collector_count: int = 1,
|
||||
delay: int = 0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: Optional[int] = None,
|
||||
symbol_regex: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
instrument save dir
|
||||
max_workers: int
|
||||
concurrent workers, by default 1; when collecting data it is recommended to keep max_workers at 1
|
||||
max_collector_count: int
|
||||
default 1
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from ["quarterly", "annual"], default "quarterly"
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
symbol_regex: str
|
||||
symbol regular expression, by default None.
|
||||
"""
|
||||
self.symbol_regex = symbol_regex
|
||||
super().__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
def get_instrument_list(self) -> List[str]:
|
||||
logger.info("get cn stock symbols......")
|
||||
symbols = get_hs_stock_symbols()
|
||||
if self.symbol_regex is not None:
|
||||
regex_compile = re.compile(self.symbol_regex)
|
||||
symbols = [symbol for symbol in symbols if regex_compile.match(symbol)]
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def normalize_symbol(self, symbol: str) -> str:
|
||||
symbol, exchange = symbol.split(".")
|
||||
exchange = "sh" if exchange == "ss" else "sz"
|
||||
return f"{exchange}{symbol}"
|
||||
|
||||
@staticmethod
|
||||
def get_performance_express_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:
|
||||
column_mapping = {
|
||||
"performanceExpPubDate": "date",
|
||||
"performanceExpStatDate": "period",
|
||||
"performanceExpressROEWa": "value",
|
||||
}
|
||||
|
||||
resp = bs.query_performance_express_report(code=code, start_date=start_date, end_date=end_date)
|
||||
report_list = []
|
||||
while (resp.error_code == "0") and resp.next():
|
||||
report_list.append(resp.get_row_data())
|
||||
report_df = pd.DataFrame(report_list, columns=resp.fields)
|
||||
try:
|
||||
report_df = report_df[list(column_mapping.keys())]
|
||||
except KeyError:
|
||||
return pd.DataFrame()
|
||||
report_df.rename(columns=column_mapping, inplace=True)
|
||||
report_df["field"] = "roeWa"
|
||||
report_df["value"] = pd.to_numeric(report_df["value"], errors="ignore")
|
||||
report_df["value"] = report_df["value"].apply(lambda x: x / 100.0)
|
||||
return report_df
|
||||
|
||||
@staticmethod
|
||||
def get_profit_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:
|
||||
column_mapping = {"pubDate": "date", "statDate": "period", "roeAvg": "value"}
|
||||
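# NOTE: a throwaway query against a fixed, well-known symbol ("sh.600519") is issued here only to
# obtain the list of field names that baostock returns for profit data.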
fields = bs.query_profit_data(code="sh.600519", year=2020, quarter=1).fields
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
args = [(year, quarter) for quarter in range(1, 5) for year in range(start_date.year - 1, end_date.year + 1)]
|
||||
profit_list = []
|
||||
for year, quarter in args:
|
||||
resp = bs.query_profit_data(code=code, year=year, quarter=quarter)
|
||||
while (resp.error_code == "0") and resp.next():
|
||||
if "pubDate" not in resp.fields:
|
||||
continue
|
||||
row_data = resp.get_row_data()
|
||||
pub_date = pd.Timestamp(row_data[resp.fields.index("pubDate")])
|
||||
if start_date <= pub_date <= end_date and row_data:
|
||||
profit_list.append(row_data)
|
||||
profit_df = pd.DataFrame(profit_list, columns=fields)
|
||||
try:
|
||||
profit_df = profit_df[list(column_mapping.keys())]
|
||||
except KeyError:
|
||||
return pd.DataFrame()
|
||||
profit_df.rename(columns=column_mapping, inplace=True)
|
||||
profit_df["field"] = "roeWa"
|
||||
profit_df["value"] = pd.to_numeric(profit_df["value"], errors="ignore")
|
||||
return profit_df
|
||||
|
||||
@staticmethod
|
||||
def get_forecast_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:
|
||||
column_mapping = {
|
||||
"profitForcastExpPubDate": "date",
|
||||
"profitForcastExpStatDate": "period",
|
||||
"value": "value",
|
||||
}
|
||||
resp = bs.query_forecast_report(code=code, start_date=start_date, end_date=end_date)
|
||||
forecast_list = []
|
||||
while (resp.error_code == "0") and resp.next():
|
||||
forecast_list.append(resp.get_row_data())
|
||||
forecast_df = pd.DataFrame(forecast_list, columns=resp.fields)
|
||||
numeric_fields = ["profitForcastChgPctUp", "profitForcastChgPctDwn"]
|
||||
try:
|
||||
forecast_df[numeric_fields] = forecast_df[numeric_fields].apply(pd.to_numeric, errors="ignore")
|
||||
except KeyError:
|
||||
return pd.DataFrame()
|
||||
forecast_df["value"] = (forecast_df["profitForcastChgPctUp"] + forecast_df["profitForcastChgPctDwn"]) / 200
|
||||
forecast_df = forecast_df[list(column_mapping.keys())]
|
||||
forecast_df.rename(columns=column_mapping, inplace=True)
|
||||
forecast_df["field"] = "YOYNI"
|
||||
return forecast_df
|
||||
|
||||
@staticmethod
|
||||
def get_growth_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:
|
||||
column_mapping = {"pubDate": "date", "statDate": "period", "YOYNI": "value"}
|
||||
fields = bs.query_growth_data(code="sh.600519", year=2020, quarter=1).fields
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
args = [(year, quarter) for quarter in range(1, 5) for year in range(start_date.year - 1, end_date.year + 1)]
|
||||
growth_list = []
|
||||
for year, quarter in args:
|
||||
resp = bs.query_growth_data(code=code, year=year, quarter=quarter)
|
||||
while (resp.error_code == "0") and resp.next():
|
||||
if "pubDate" not in resp.fields:
|
||||
continue
|
||||
row_data = resp.get_row_data()
|
||||
pub_date = pd.Timestamp(row_data[resp.fields.index("pubDate")])
|
||||
if start_date <= pub_date <= end_date and row_data:
|
||||
growth_list.append(row_data)
|
||||
growth_df = pd.DataFrame(growth_list, columns=fields)
|
||||
try:
|
||||
growth_df = growth_df[list(column_mapping.keys())]
|
||||
except KeyError:
|
||||
return pd.DataFrame()
|
||||
growth_df.rename(columns=column_mapping, inplace=True)
|
||||
growth_df["field"] = "YOYNI"
|
||||
growth_df["value"] = pd.to_numeric(growth_df["value"], errors="ignore")
|
||||
return growth_df
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
symbol: str,
|
||||
interval: str,
|
||||
start_datetime: pd.Timestamp,
|
||||
end_datetime: pd.Timestamp,
|
||||
) -> pd.DataFrame:
|
||||
if interval != self.INTERVAL_QUARTERLY:
|
||||
raise ValueError(f"cannot support {interval}")
|
||||
symbol, exchange = symbol.split(".")
|
||||
exchange = "sh" if exchange == "ss" else "sz"
|
||||
code = f"{exchange}.{symbol}"
|
||||
start_date = start_datetime.strftime("%Y-%m-%d")
|
||||
end_date = end_datetime.strftime("%Y-%m-%d")
|
||||
|
||||
performance_express_report_df = self.get_performance_express_report_df(code, start_date, end_date)
|
||||
profit_df = self.get_profit_df(code, start_date, end_date)
|
||||
forecast_report_df = self.get_forecast_report_df(code, start_date, end_date)
|
||||
growth_df = self.get_growth_df(code, start_date, end_date)
|
||||
|
||||
df = pd.concat(
|
||||
[performance_express_report_df, profit_df, forecast_report_df, growth_df],
|
||||
axis=0,
|
||||
)
|
||||
return df
|
||||
|
||||
|
||||
class PitNormalize(BaseNormalize):
|
||||
def __init__(self, interval: str = "quarterly", *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.interval = interval
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
dt = df["period"].apply(
|
||||
lambda x: (
|
||||
pd.to_datetime(x) + pd.DateOffset(days=(45 if self.interval == PitCollector.INTERVAL_QUARTERLY else 90))
|
||||
).date()
|
||||
)
|
||||
df["date"] = df["date"].fillna(dt.astype(str))
|
||||
|
||||
df["period"] = pd.to_datetime(df["period"])
|
||||
df["period"] = df["period"].apply(
|
||||
lambda x: x.year if self.interval == PitCollector.INTERVAL_ANNUAL else x.year * 100 + (x.month - 1) // 3 + 1
|
||||
)
|
||||
return df
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return get_calendar_list()
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
@property
|
||||
def collector_class_name(self) -> str:
|
||||
return f"PitCollector"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self) -> str:
|
||||
return f"PitNormalize"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> Union[Path, str]:
|
||||
return BASE_DIR
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bs.login()
|
||||
fire.Fire(Run)
|
||||
bs.logout()
|
@ -0,0 +1,10 @@
|
||||
loguru
|
||||
fire
|
||||
tqdm
|
||||
requests
|
||||
pandas
|
||||
lxml
|
||||
baostock
|
||||
yahooquery
|
||||
beautifulsoup4
|
@ -0,0 +1,22 @@
|
||||
# NASDAQ100/SP500/SP400/DJIA History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, used in qlib/instruments.
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method save_new_companies
|
||||
|
||||
# index_name support: SP500, NASDAQ100, DJIA, SP400
|
||||
# help
|
||||
python collector.py --help
|
||||
```
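The same operations can also be invoked from Python via `data_collector.utils.get_instruments`, which is what `fire` dispatches to in this collector; a minimal sketch, assuming it is run from the `scripts` directory so that `data_collector` is importable:

```python
from data_collector.utils import get_instruments

# equivalent to: python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments
get_instruments(
    qlib_dir="~/.qlib/qlib_data/us_data",
    index_name="SP500",
    method="parse_instruments",
    market_index="us_index",
)
```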
|
||||
|
@ -0,0 +1,275 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
from functools import partial
|
||||
import sys
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Union
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
from data_collector.index import IndexBase
|
||||
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
|
||||
from data_collector.utils import get_instruments
|
||||
|
||||
|
||||
WIKI_URL = "https://en.wikipedia.org/wiki"
|
||||
|
||||
WIKI_INDEX_NAME_MAP = {
|
||||
"NASDAQ100": "NASDAQ-100",
|
||||
"SP500": "List_of_S%26P_500_companies",
|
||||
"SP400": "List_of_S%26P_400_companies",
|
||||
"DJIA": "Dow_Jones_Industrial_Average",
|
||||
}
|
||||
|
||||
|
||||
class WIKIIndex(IndexBase):
|
||||
# NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix
|
||||
# https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
|
||||
INST_PREFIX = ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
qlib_dir: Union[str, Path] = None,
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
):
|
||||
super(WIKIIndex, self).__init__(
|
||||
index_name=index_name, qlib_dir=qlib_dir, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
|
||||
self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}"
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_changes")
|
||||
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""formatting the datetime in an instrument
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst_df: pd.DataFrame
|
||||
inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
if self.freq != "day":
|
||||
inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(
|
||||
lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
return inst_df
|
||||
|
||||
@property
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
_calendar_list = getattr(self, "_calendar_list", None)
|
||||
if _calendar_list is None:
|
||||
_calendar_list = list(filter(lambda x: x >= self.bench_start_date, get_calendar_list("US_ALL")))
|
||||
setattr(self, "_calendar_list", _calendar_list)
|
||||
return _calendar_list
|
||||
|
||||
def _request_new_companies(self) -> requests.Response:
|
||||
resp = requests.get(self._target_url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {self._target_url}")
|
||||
|
||||
return resp
|
||||
|
||||
def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
_df = df.copy()
|
||||
_df[self.SYMBOL_FIELD_NAME] = _df[self.SYMBOL_FIELD_NAME].str.strip()
|
||||
_df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
_df[self.END_DATE_FIELD] = self.DEFAULT_END_DATE
|
||||
return _df.loc[:, self.INSTRUMENTS_COLUMNS]
|
||||
|
||||
def get_new_companies(self):
|
||||
logger.info(f"get new companies {self.index_name} ......")
|
||||
_data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)()
|
||||
df_list = pd.read_html(_data.text)
|
||||
for _df in df_list:
|
||||
_df = self.filter_df(_df)
|
||||
if (_df is not None) and (not _df.empty):
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df = self.set_default_date_range(_df)
|
||||
logger.info(f"end of get new companies {self.index_name} ......")
|
||||
return _df
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
raise NotImplementedError("rewrite filter_df")
|
||||
|
||||
|
||||
class NASDAQ100Index(WIKIIndex):
|
||||
|
||||
HISTORY_COMPANIES_URL = (
|
||||
"https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD"
|
||||
)
|
||||
MAX_WORKERS = 16
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if len(df) >= 100 and "Ticker" in df.columns:
|
||||
return df.loc[:, ["Ticker"]].copy()
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2003-01-02")
|
||||
|
||||
@deco_retry
|
||||
def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame:
|
||||
trade_date = trade_date.strftime("%Y-%m-%d")
|
||||
cache_path = self.cache_dir.joinpath(f"{trade_date}_history_companies.pkl")
|
||||
if cache_path.exists() and use_cache:
|
||||
df = pd.read_pickle(cache_path)
|
||||
else:
|
||||
url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date)
|
||||
resp = requests.post(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {url}")
|
||||
df = pd.DataFrame(resp.json()["aaData"])
|
||||
df[self.DATE_FIELD_NAME] = trade_date
|
||||
df.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True)
|
||||
if not df.empty:
|
||||
df.to_pickle(cache_path)
|
||||
return df
|
||||
|
||||
def get_history_companies(self):
|
||||
logger.info(f"start get history companies......")
|
||||
all_history = []
|
||||
error_list = []
|
||||
with tqdm(total=len(self.calendar_list)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
|
||||
for _trading_date, _df in zip(
|
||||
self.calendar_list, executor.map(self._request_history_companies, self.calendar_list)
|
||||
):
|
||||
if _df.empty:
|
||||
error_list.append(_trading_date)
|
||||
else:
|
||||
all_history.append(_df)
|
||||
p_bar.update()
|
||||
|
||||
if error_list:
|
||||
logger.warning(f"get error: {error_list}")
|
||||
logger.info(f"total {len(self.calendar_list)}, error {len(error_list)}")
|
||||
logger.info(f"end of get history companies.")
|
||||
return pd.concat(all_history, sort=False)
|
||||
|
||||
def get_changes(self):
|
||||
return self.get_changes_with_history_companies(self.get_history_companies())
|
||||
|
||||
|
||||
class DJIAIndex(WIKIIndex):
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2000-01-01")
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
pass
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if "Symbol" in df.columns:
|
||||
_df = df.loc[:, ["Symbol"]].copy()
|
||||
_df["Symbol"] = _df["Symbol"].apply(lambda x: x.split(":")[-1])
|
||||
return _df
|
||||
|
||||
def parse_instruments(self):
|
||||
logger.warning(f"No suitable data source has been found!")
|
||||
|
||||
|
||||
class SP500Index(WIKIIndex):
|
||||
WIKISP500_CHANGES_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("1999-01-01")
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
logger.info(f"get sp500 history changes......")
|
||||
# NOTE: may update the index of the table
|
||||
changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1]
|
||||
changes_df = changes_df.iloc[:, [0, 1, 3]]
|
||||
changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]
|
||||
changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])
|
||||
_result = []
|
||||
for _type in [self.ADD, self.REMOVE]:
|
||||
_df = changes_df.copy()
|
||||
_df[self.CHANGE_TYPE_FIELD] = _type
|
||||
_df[self.SYMBOL_FIELD_NAME] = _df[_type]
|
||||
_df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True)
|
||||
if _type == self.ADD:
|
||||
_df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
|
||||
lambda x: get_trading_date_by_shift(self.calendar_list, x, 0)
|
||||
)
|
||||
else:
|
||||
_df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
|
||||
lambda x: get_trading_date_by_shift(self.calendar_list, x, -1)
|
||||
)
|
||||
_result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]])
|
||||
logger.info(f"end of get sp500 history changes.")
|
||||
return pd.concat(_result, sort=False)
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if "Symbol" in df.columns:
|
||||
return df.loc[:, ["Symbol"]].copy()
|
||||
|
||||
|
||||
class SP400Index(WIKIIndex):
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2000-01-01")
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
pass
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if "Ticker symbol" in df.columns:
|
||||
return df.loc[:, ["Ticker symbol"]].copy()
|
||||
|
||||
def parse_instruments(self):
|
||||
logger.warning(f"No suitable data source has been found!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(partial(get_instruments, market_index="us_index"))
|
@ -0,0 +1,6 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
lxml
|
@ -0,0 +1,609 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import importlib
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231"
|
||||
SZSE_CALENDAR_URL = "http://www.szse.cn/api/report/exchange/onepersistenthour/monthList?month={month}&random={random}"
|
||||
|
||||
CALENDAR_BENCH_URL_MAP = {
|
||||
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
|
||||
"CSI500": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
|
||||
# NOTE: Use the time series of SH600000 as the sequence of all stocks
|
||||
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
|
||||
"US_ALL": "^GSPC",
|
||||
"IN_ALL": "^NSEI",
|
||||
"BR_ALL": "^BVSP",
|
||||
}
|
||||
|
||||
_BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
_US_SYMBOLS = None
|
||||
_IN_SYMBOLS = None
|
||||
_BR_SYMBOLS = None
|
||||
_EN_FUND_SYMBOLS = None
|
||||
_CALENDAR_MAP = {}
|
||||
|
||||
# NOTE: Until 2020-10-20 20:00:00
|
||||
MINIMUM_SYMBOLS_NUM = 3900
|
||||
|
||||
|
||||
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
"""get SH/SZ history calendar list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bench_code: str
|
||||
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
def _get_calendar(url):
|
||||
_value_list = requests.get(url).json()["data"]["klines"]
|
||||
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
|
||||
|
||||
calendar = _CALENDAR_MAP.get(bench_code, None)
|
||||
if calendar is None:
|
||||
if bench_code.startswith("US_") or bench_code.startswith("IN_") or bench_code.startswith("BR_"):
|
||||
|
||||
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
|
||||
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
|
||||
else:
|
||||
if bench_code.upper() == "ALL":
|
||||
|
||||
@deco_retry
|
||||
def _get_calendar(month):
|
||||
_cal = []
|
||||
try:
|
||||
resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random())).json()
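# in the SZSE response, "jyrq" carries the calendar date and "jybz" appears to be the flag marking whether that date is a trading day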
|
||||
for _r in resp["data"]:
|
||||
if int(_r["jybz"]):
|
||||
_cal.append(pd.Timestamp(_r["jyrq"]))
|
||||
except Exception as e:
|
||||
raise ValueError(f"{month}-->{e}")
|
||||
return _cal
|
||||
|
||||
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M")
|
||||
calendar = []
|
||||
for _m in month_range:
|
||||
cal = _get_calendar(_m.strftime("%Y-%m"))
|
||||
if cal:
|
||||
calendar += cal
|
||||
calendar = list(filter(lambda x: x <= pd.Timestamp.now(), calendar))
|
||||
else:
|
||||
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
|
||||
_CALENDAR_MAP[bench_code] = calendar
|
||||
logger.info(f"end of get calendar list: {bench_code}.")
|
||||
return calendar
|
||||
|
||||
|
||||
def return_date_list(date_field_name: str, file_path: Path):
|
||||
date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list()
|
||||
return sorted(map(lambda x: pd.Timestamp(x), date_list))
|
||||
|
||||
|
||||
def get_calendar_list_by_ratio(
|
||||
source_dir: Union[str, Path],
|
||||
date_field_name: str = "date",
|
||||
threshold: float = 0.5,
|
||||
minimum_count: int = 10,
|
||||
max_workers: int = 16,
|
||||
) -> list:
|
||||
"""get calendar list by selecting the date when few funds trade in this day
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
threshold: float
|
||||
threshold to exclude some days when few funds trade in this day, default 0.5
|
||||
minimum_count: int
|
||||
minimum count of funds should trade in one day
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......")
|
||||
|
||||
source_dir = Path(source_dir).expanduser()
|
||||
file_list = list(source_dir.glob("*.csv"))
|
||||
|
||||
_number_all_funds = len(file_list)
|
||||
|
||||
logger.info(f"count how many funds trade in this day......")
|
||||
_dict_count_trade = dict() # dict{date:count}
|
||||
_fun = partial(return_date_list, date_field_name)
|
||||
all_oldest_list = []
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
for date_list in executor.map(_fun, file_list):
|
||||
if date_list:
|
||||
all_oldest_list.append(date_list[0])
|
||||
for date in date_list:
|
||||
if date not in _dict_count_trade.keys():
|
||||
_dict_count_trade[date] = 0
|
||||
|
||||
_dict_count_trade[date] += 1
|
||||
|
||||
p_bar.update()
|
||||
|
||||
logger.info(f"count how many funds have founded in this day......")
|
||||
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
for oldest_date in all_oldest_list:
|
||||
for date in _dict_count_founding.keys():
|
||||
if date < oldest_date:
|
||||
_dict_count_founding[date] -= 1
|
||||
|
||||
calendar = [
|
||||
date
|
||||
for date in _dict_count_trade
|
||||
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
|
||||
]
|
||||
|
||||
return calendar
|
||||
|
||||
|
||||
def get_hs_stock_symbols() -> list:
|
||||
"""get SH/SZ stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _HS_SYMBOLS
|
||||
|
||||
def _get_symbol():
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k))
|
||||
_res |= set(
|
||||
map(
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
|
||||
)
|
||||
)
|
||||
time.sleep(3)
|
||||
return _res
|
||||
|
||||
if _HS_SYMBOLS is None:
|
||||
symbols = set()
|
||||
_retry = 60
|
||||
# It may take multiple attempts to get the complete symbol list
|
||||
while len(symbols) < MINIMUM_SYMBOLS_NUM:
|
||||
symbols |= _get_symbol()
|
||||
time.sleep(3)
|
||||
|
||||
symbol_cache_path = Path("~/.cache/hs_symbols_cache.pkl").expanduser().resolve()
|
||||
symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if symbol_cache_path.exists():
|
||||
with symbol_cache_path.open("rb") as fp:
|
||||
cache_symbols = pickle.load(fp)
|
||||
symbols |= cache_symbols
|
||||
with symbol_cache_path.open("wb") as fp:
|
||||
pickle.dump(symbols, fp)
|
||||
|
||||
_HS_SYMBOLS = sorted(list(symbols))
|
||||
|
||||
return _HS_SYMBOLS
|
||||
|
||||
|
||||
def get_us_stock_symbols(qlib_data_path: Union[str, Path] = None) -> list:
|
||||
"""get US stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _US_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
try:
|
||||
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
|
||||
return _symbols
|
||||
|
||||
@deco_retry
|
||||
def _get_nasdaq():
|
||||
_res_symbols = []
|
||||
for _name in ["otherlisted", "nasdaqtraded"]:
|
||||
url = f"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt"
|
||||
df = pd.read_csv(url, sep="|")
|
||||
df = df.rename(columns={"ACT Symbol": "Symbol"})
|
||||
_symbols = df["Symbol"].dropna()
|
||||
_symbols = _symbols.str.replace("$", "-P", regex=False)
|
||||
_symbols = _symbols.str.replace(".W", "-WT", regex=False)
|
||||
_symbols = _symbols.str.replace(".U", "-UN", regex=False)
|
||||
_symbols = _symbols.str.replace(".R", "-RI", regex=False)
|
||||
_symbols = _symbols.str.replace(".", "-", regex=False)
|
||||
_res_symbols += _symbols.unique().tolist()
|
||||
return _res_symbols
|
||||
|
||||
@deco_retry
|
||||
def _get_nyse():
|
||||
url = "https://www.nyse.com/api/quotes/filter"
|
||||
_parms = {
|
||||
"instrumentType": "EQUITY",
|
||||
"pageNumber": 1,
|
||||
"sortColumn": "NORMALIZED_TICKER",
|
||||
"sortOrder": "ASC",
|
||||
"maxResultsPerPage": 10000,
|
||||
"filterToken": "",
|
||||
}
|
||||
resp = requests.post(url, json=_parms)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
try:
|
||||
_symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
_symbols = []
|
||||
return _symbols
|
||||
|
||||
if _US_SYMBOLS is None:
|
||||
_all_symbols = _get_eastmoney() + _get_nasdaq() + _get_nyse()
|
||||
if qlib_data_path is not None:
|
||||
for _index in ["nasdaq100", "sp500"]:
|
||||
ins_df = pd.read_csv(
|
||||
Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
|
||||
sep="\t",
|
||||
names=["symbol", "start_date", "end_date"],
|
||||
)
|
||||
_all_symbols += ins_df["symbol"].unique().tolist()
|
||||
|
||||
def _format(s_):
|
||||
s_ = s_.replace(".", "-")
|
||||
s_ = s_.strip("$")
|
||||
s_ = s_.strip("*")
|
||||
return s_
|
||||
|
||||
_US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))))
|
||||
|
||||
return _US_SYMBOLS
|
||||
|
||||
|
||||
def get_in_stock_symbols(qlib_data_path: Union[str, Path] = None) -> list:
|
||||
"""get IN stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _IN_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_nifty():
|
||||
url = f"https://www1.nseindia.com/content/equities/EQUITY_L.csv"
|
||||
df = pd.read_csv(url)
|
||||
df = df.rename(columns={"SYMBOL": "Symbol"})
|
||||
df["Symbol"] = df["Symbol"] + ".NS"
|
||||
_symbols = df["Symbol"].dropna()
|
||||
_symbols = _symbols.unique().tolist()
|
||||
return _symbols
|
||||
|
||||
if _IN_SYMBOLS is None:
|
||||
_all_symbols = _get_nifty()
|
||||
if qlib_data_path is not None:
|
||||
for _index in ["nifty"]:
|
||||
ins_df = pd.read_csv(
|
||||
Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
|
||||
sep="\t",
|
||||
names=["symbol", "start_date", "end_date"],
|
||||
)
|
||||
_all_symbols += ins_df["symbol"].unique().tolist()
|
||||
|
||||
def _format(s_):
|
||||
s_ = s_.replace(".", "-")
|
||||
s_ = s_.strip("$")
|
||||
s_ = s_.strip("*")
|
||||
return s_
|
||||
|
||||
_IN_SYMBOLS = sorted(set(_all_symbols))
|
||||
|
||||
return _IN_SYMBOLS
|
||||
|
||||
|
||||
def get_br_stock_symbols(qlib_data_path: Union[str, Path] = None) -> list:
|
||||
"""get Brazil(B3) stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
B3 stock symbols
|
||||
"""
|
||||
global _BR_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_ibovespa():
|
||||
_symbols = []
|
||||
url = "https://www.fundamentus.com.br/detalhes.php?papel="
|
||||
|
||||
# Request
|
||||
agent = {"User-Agent": "Mozilla/5.0"}
|
||||
page = requests.get(url, headers=agent)
|
||||
|
||||
# BeautifulSoup
|
||||
soup = BeautifulSoup(page.content, "html.parser")
|
||||
tbody = soup.find("tbody")
|
||||
|
||||
children = tbody.findChildren("a", recursive=True)
|
||||
for child in children:
|
||||
_symbols.append(str(child).split('"')[-1].split(">")[1].split("<")[0])
|
||||
|
||||
return _symbols
|
||||
|
||||
if _BR_SYMBOLS is None:
|
||||
_all_symbols = _get_ibovespa()
|
||||
if qlib_data_path is not None:
|
||||
for _index in ["ibov"]:
|
||||
ins_df = pd.read_csv(
|
||||
Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
|
||||
sep="\t",
|
||||
names=["symbol", "start_date", "end_date"],
|
||||
)
|
||||
_all_symbols += ins_df["symbol"].unique().tolist()
|
||||
|
||||
def _format(s_):
|
||||
s_ = s_.strip()
|
||||
s_ = s_.strip("$")
|
||||
s_ = s_.strip("*")
|
||||
s_ = s_ + ".SA"
|
||||
return s_
|
||||
|
||||
_BR_SYMBOLS = sorted(set(map(_format, _all_symbols)))
|
||||
|
||||
return _BR_SYMBOLS
|
||||
|
||||
|
||||
def get_en_fund_symbols(qlib_data_path: Union[str, Path] = None) -> list:
|
||||
"""get en fund symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
fund symbols in China
|
||||
"""
|
||||
global _EN_FUND_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://fund.eastmoney.com/js/fundcode_search.js"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
_symbols = []
|
||||
for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")):
|
||||
data = sub_data.replace('"', "").replace("'", "")
|
||||
# TODO: do we need other information, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']
|
||||
_symbols.append(data.split(",")[0])
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
return _symbols
|
||||
|
||||
if _EN_FUND_SYMBOLS is None:
|
||||
_all_symbols = _get_eastmoney()
|
||||
|
||||
_EN_FUND_SYMBOLS = sorted(set(_all_symbols))
|
||||
|
||||
return _EN_FUND_SYMBOLS
|
||||
|
||||
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
code, exchange = symbol.split(".")
|
||||
if exchange.lower() in ["sh", "ss"]:
|
||||
res = f"sh{code}"
|
||||
else:
|
||||
res = f"{exchange}{code}"
|
||||
return res.upper() if capital else res.lower()
|
||||
|
||||
|
||||
def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol prefix to sufix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
res = f"{symbol[:-2]}.{symbol[-2:]}"
|
||||
return res.upper() if capital else res.lower()
|
||||
|
||||
|
||||
def deco_retry(retry: int = 5, retry_sleep: int = 3):
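# deco_retry can be used either bare (@deco_retry, which falls back to 5 retries) or parametrized
# (@deco_retry(retry=10, retry_sleep=1)); the callable(retry) checks below handle both forms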
|
||||
def deco_func(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
_retry = 5 if callable(retry) else retry
|
||||
_result = None
|
||||
for _i in range(1, _retry + 1):
|
||||
try:
|
||||
_result = func(*args, **kwargs)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{func.__name__}: {_i} :{e}")
|
||||
if _i == _retry:
|
||||
raise
|
||||
|
||||
time.sleep(retry_sleep)
|
||||
return _result
|
||||
|
||||
return wrapper
|
||||
|
||||
return deco_func(retry) if callable(retry) else deco_func
|
||||
|
||||
|
||||
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
|
||||
"""get trading date by shift
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trading_list: list
|
||||
trading calendar list
|
||||
shift : int
|
||||
shift, default is 1
|
||||
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
the trading date in `trading_list` shifted by `shift` positions from `trading_date`; if the shifted index is out of range, `trading_date` itself is returned
|
||||
|
||||
"""
|
||||
trading_date = pd.Timestamp(trading_date)
|
||||
left_index = bisect.bisect_left(trading_list, trading_date)
|
||||
try:
|
||||
res = trading_list[left_index + shift]
|
||||
except IndexError:
|
||||
res = trading_date
|
||||
return res
|
||||
|
||||
|
||||
def generate_minutes_calendar_from_daily(
|
||||
calendars: Iterable,
|
||||
freq: str = "1min",
|
||||
am_range: Tuple[str, str] = ("09:30:00", "11:29:00"),
|
||||
pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"),
|
||||
) -> pd.Index:
|
||||
"""generate minutes calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
calendars: Iterable
|
||||
daily calendar
|
||||
freq: str
|
||||
by default 1min
|
||||
am_range: Tuple[str, str]
|
||||
AM Time Range, by default China-Stock: ("09:30:00", "11:29:00")
|
||||
pm_range: Tuple[str, str]
|
||||
PM Time Range, by default China-Stock: ("13:00:00", "14:59:00")
|
||||
|
||||
"""
|
||||
daily_format: str = "%Y-%m-%d"
|
||||
res = []
|
||||
for _day in calendars:
|
||||
for _range in [am_range, pm_range]:
|
||||
res.append(
|
||||
pd.date_range(
|
||||
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}",
|
||||
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}",
|
||||
freq=freq,
|
||||
)
|
||||
)
|
||||
|
||||
return pd.Index(sorted(set(np.hstack(res))))
|
||||
|
||||
|
||||
def get_instruments(
|
||||
qlib_dir: str,
|
||||
index_name: str,
|
||||
method: str = "parse_instruments",
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
market_index: str = "cn_index",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str
|
||||
qlib data dir, default "Path(__file__).parent/qlib_data"
|
||||
index_name: str
|
||||
index name, value from ["csi100", "csi300"]
|
||||
method: str
|
||||
method, value from ["parse_instruments", "save_new_companies"]
|
||||
freq: str
|
||||
freq, value from ["day", "1min"]
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
market_index: str
|
||||
Where the files to obtain the index are located,
|
||||
for example data_collector.cn_index.collector
|
||||
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("data_collector.{}.collector".format(market_index))
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
getattr(obj, method)()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
@ -0,0 +1,220 @@
|
||||
|
||||
- [Collector Data](#collector-data)
|
||||
- [Get Qlib data](#get-qlib-databin-file)
|
||||
- [Collector *YahooFinance* data to qlib](#collector-yahoofinance-data-to-qlib)
|
||||
- [Automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
- [Using qlib data](#using-qlib-data)
|
||||
|
||||
|
||||
# Collect Data From Yahoo Finance
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
**NOTE**: Yahoo! Finance has blocked access from mainland China. Please change your network if you want to use the Yahoo data crawler.
|
||||
|
||||
> **Examples of abnormal data**
|
||||
|
||||
- [SH000661](https://finance.yahoo.com/quote/000661.SZ/history?period1=1558310400&period2=1590796800&interval=1d&filter=history&frequency=1d)
|
||||
- [SZ300144](https://finance.yahoo.com/quote/300144.SZ/history?period1=1557446400&period2=1589932800&interval=1d&filter=history&frequency=1d)
|
||||
|
||||
We have considered **STOCK PRICE ADJUSTMENT**, but some price series still seem very abnormal.
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
### Get Qlib data(`bin file`)
|
||||
> `qlib-data` from *YahooFinance* is data that has already been dumped and can be used directly in `qlib`.
|
||||
> This ready-made qlib-data is not updated regularly. If users want the latest data, please follow [these steps](#collector-yahoofinance-data-to-qlib) to download the latest data.
|
||||
|
||||
- get data: `python scripts/get_data.py qlib_data`
|
||||
- parameters:
|
||||
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data*
|
||||
- `version`: dataset version, value from [`v1`, `v2`], by default `v1`
|
||||
- `v2` end date is *2021-06*, `v1` end date is *2020-09*
|
||||
- user can append data to `v2`: [automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
- **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
- `region`: `cn` or `us` or `in`, by default `cn`
|
||||
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
|
||||
- `exists_skip`: if data already exists in `target_dir`, skip `get_data`, value from [`True`, `False`], by default `False`
|
||||
- examples:
|
||||
```bash
|
||||
# cn 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
# cn 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
# us 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us --interval 1d
|
||||
# us 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data_1min --region us --interval 1min
|
||||
# in 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/in_data --region in --interval 1d
|
||||
# in 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/in_data_1min --region in --interval 1min
|
||||
```
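If you prefer Python over the CLI, `scripts/get_data.py` is a thin `fire` wrapper around `qlib.tests.data.GetData`; a minimal sketch, assuming the keyword arguments map one-to-one onto the CLI options above:

```python
from qlib.tests.data import GetData

# equivalent to: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
GetData().qlib_data(target_dir="~/.qlib/qlib_data/cn_data", region="cn", interval="1d")
```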
|
||||
|
||||
### Collector *YahooFinance* data to qlib
|
||||
> collect *YahooFinance* data and *dump* it into `qlib` format.
|
||||
> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
|
||||
1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data`
|
||||
|
||||
- parameters:
|
||||
- `source_dir`: directory where the downloaded data is saved
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
> **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`**
|
||||
- `region`: `CN` or `US` or `IN` or `BR`, by default `CN`
|
||||
- `delay`: `time.sleep(delay)`, by default *0.5*
|
||||
- `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)*
|
||||
- `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)*
|
||||
- `max_workers`: number of symbols fetched concurrently; to keep the symbol data intact it is not recommended to change this parameter, by default *1*
|
||||
- `check_data_length`: check the number of rows per *symbol*, by default `None`
|
||||
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
|
||||
- `max_collector_count`: number of *"failed"* symbol retries, by default 2
|
||||
- examples:
|
||||
```bash
|
||||
# cn 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region CN
|
||||
# cn 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_data_1min --delay 1 --interval 1min --region CN
|
||||
|
||||
# us 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
|
||||
# us 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_data_1min --delay 1 --interval 1min --region US
|
||||
|
||||
# in 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region IN
|
||||
# in 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_data_1min --delay 1 --interval 1min --region IN
|
||||
|
||||
# br 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data --start 2003-01-03 --end 2022-03-01 --delay 1 --interval 1d --region BR
|
||||
# br 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data_1min --delay 1 --interval 1min --region BR
|
||||
```
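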
|
||||
2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`
|
||||
|
||||
- parameters:
|
||||
- `source_dir`: csv directory
|
||||
- `normalize_dir`: result directory
|
||||
- `max_workers`: number of concurrent workers, by default *1*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
> if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None`
|
||||
- `region`: `CN` or `US` or `IN`, by default `CN`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `end_date`: if not `None`, data is normalized up to and including this date; if `None`, this parameter is ignored; by default `None`
|
||||
- `qlib_data_1d_dir`: qlib directory(1d data)
|
||||
```
|
||||
if interval==1min, qlib_data_1d_dir cannot be None, because normalizing 1min data requires the corresponding 1d data;
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
|
||||
or:
|
||||
download 1d data from YahooFinance
|
||||
|
||||
```
|
||||
- examples:
|
||||
```bash
|
||||
# normalize 1d cn
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_data --normalize_dir ~/.qlib/stock_data/source/cn_1d_nor --region CN --interval 1d
|
||||
|
||||
# normalize 1min cn
|
||||
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/cn_data_1min --normalize_dir ~/.qlib/stock_data/source/cn_1min_nor --region CN --interval 1min
|
||||
|
||||
# normalize 1d br
|
||||
python scripts/data_collector/yahoo/collector.py normalize_data --source_dir ~/.qlib/stock_data/source/br_data --normalize_dir ~/.qlib/stock_data/source/br_1d_nor --region BR --interval 1d
|
||||
|
||||
# normalize 1min br
|
||||
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/br_data --source_dir ~/.qlib/stock_data/source/br_data_1min --normalize_dir ~/.qlib/stock_data/source/br_1min_nor --region BR --interval 1min
|
||||
```
|
||||
3. dump data: `python scripts/dump_bin.py dump_all`

- parameters:
- `csv_path`: stock data path or directory, **the normalize result (`normalize_dir`)**
- `qlib_dir`: qlib (dump) data directory
- `freq`: transaction frequency, by default `day`
> `freq_map = {1d: day, 1min: 1min}`
- `max_workers`: number of threads, by default *16*
- `include_fields`: dump fields, by default `""`
- `exclude_fields`: fields not dumped, by default `""`
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) if exclude_fields else symbol_df.columns` (see the sketch after the examples below)
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `date_field_name`: column *name* identifying time in csv files, by default `date`

- examples:

```bash
# dump 1d cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_data --freq day --exclude_fields date,symbol

# dump 1min cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min --exclude_fields date,symbol
```

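For clarity, here is a small sketch of the field-selection rule quoted above. The column set is hypothetical and the helper is only an illustration of how `include_fields`/`exclude_fields` interact, not the actual `dump_bin.py` code:

```python
# hypothetical set of columns found in one symbol's csv
symbol_df_columns = {"date", "symbol", "open", "high", "low", "close", "volume", "factor"}

def dump_fields(include_fields=(), exclude_fields=()):
    # include_fields wins outright; otherwise exclude_fields filters; otherwise dump everything
    if include_fields:
        return set(include_fields)
    if exclude_fields:
        return symbol_df_columns - set(exclude_fields)
    return symbol_df_columns

print(dump_fields(exclude_fields=("date", "symbol")))
# -> {'open', 'high', 'low', 'close', 'volume', 'factor'} (set order may vary)
```
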
### Automatic update of daily frequency data (from Yahoo Finance)

> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it up to update automatically.

* Automatic update of data to the "qlib" directory on each trading day (Linux)
* use *crontab*: `crontab -e`
* set up timed tasks:

```
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
```

* **script path**: *scripts/data_collector/yahoo/collector.py*

* Manual update of data

```
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
```

* `trading_date`: the first trading day to update
* `end_date`: the last trading day to update (not included)
* `check_data_length`: check the number of rows per *symbol*, by default `None`
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter

* `scripts/data_collector/yahoo/collector.py update_data_to_bin` parameters:
* `source_dir`: the directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
* `normalize_dir`: directory for the normalized data, default "Path(__file__).parent/normalize"
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval (the end is excluded; see the snippet below)
* `region`: region, value from ["CN", "US"], default "CN"

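As a rough illustration of the open interval, the snippet below reproduces what the documented default amounts to; it is an assumption about the intent of the default expression, not code taken verbatim from the script:

```python
import pandas as pd

trading_date = "2021-06-01"
# default end_date per the parameter list above: the day after trading_date (exclusive)
end_date = pd.Timestamp(trading_date) + pd.Timedelta(days=1)
print(trading_date, "->", end_date.date())  # 2021-06-01 -> 2021-06-02, so only 2021-06-01 is updated
```
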
## Using qlib data

```python
import qlib
from qlib.data import D

# 1d data cn
# freq defaults to "day"
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region="cn")
df = D.features(D.instruments("all"), ["$close"], freq="day")

# 1min data cn
# freq=1min
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data_1min", region="cn")
inst = D.list_instruments(D.instruments("all"), freq="1min", as_list=True)
# get 100 symbols
df = D.features(inst[:100], ["$close"], freq="1min")
# get all symbol data
# df = D.features(D.instruments("all"), ["$close"], freq="1min")

# 1d data us
qlib.init(provider_uri="~/.qlib/qlib_data/us_data", region="us")
df = D.features(D.instruments("all"), ["$close"], freq="day")

# 1min data us
qlib.init(provider_uri="~/.qlib/qlib_data/us_data_1min", region="us")
inst = D.list_instruments(D.instruments("all"), freq="1min", as_list=True)
# get 100 symbols
df = D.features(inst[:100], ["$close"], freq="1min")
# get all symbol data
# df = D.features(D.instruments("all"), ["$close"], freq="1min")
```
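If you also need the trading calendar behind the loaded data, it can be queried in the same session with `D.calendar`. A small sketch (the `provider_uri` is whichever directory you initialised above):

```python
import qlib
from qlib.data import D

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region="cn")
cal = D.calendar(start_time="2020-01-01", end_time="2020-12-31", freq="day")
print(cal[:5])  # first five trading days of 2020 present in the data
```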
File diff suppressed because it is too large
@ -0,0 +1,12 @@
loguru
fire
requests
numpy
pandas
tqdm
lxml
yahooquery
joblib
beautifulsoup4
bs4
soupsieve
@ -0,0 +1,286 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
TODO:
|
||||
- A more well-designed PIT database is required.
|
||||
- seperated insert, delete, update, query operations are required.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import shutil
|
||||
import struct
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Union
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname, get_period_offset
|
||||
from qlib.config import C
|
||||
|
||||
|
||||
class DumpPitData:
|
||||
PIT_DIR_NAME = "financial"
|
||||
PIT_CSV_SEP = ","
|
||||
DATA_FILE_SUFFIX = ".data"
|
||||
INDEX_FILE_SUFFIX = ".index"
|
||||
|
||||
INTERVAL_quarterly = "quarterly"
|
||||
INTERVAL_annual = "annual"
|
||||
|
||||
PERIOD_DTYPE = C.pit_record_type["period"]
|
||||
INDEX_DTYPE = C.pit_record_type["index"]
|
||||
DATA_DTYPE = "".join(
|
||||
[
|
||||
C.pit_record_type["date"],
|
||||
C.pit_record_type["period"],
|
||||
C.pit_record_type["value"],
|
||||
C.pit_record_type["index"],
|
||||
]
|
||||
)
|
||||
|
||||
NA_INDEX = C.pit_record_nan["index"]
|
||||
|
||||
INDEX_DTYPE_SIZE = struct.calcsize(INDEX_DTYPE)
|
||||
PERIOD_DTYPE_SIZE = struct.calcsize(PERIOD_DTYPE)
|
||||
DATA_DTYPE_SIZE = struct.calcsize(DATA_DTYPE)
|
||||
|
||||
UPDATE_MODE = "update"
|
||||
ALL_MODE = "all"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "quarterly",
|
||||
max_workers: int = 16,
|
||||
date_column_name: str = "date",
|
||||
period_column_name: str = "period",
|
||||
value_column_name: str = "value",
|
||||
field_column_name: str = "field",
|
||||
file_suffix: str = ".csv",
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
csv_path: str
|
||||
stock data path or directory
|
||||
qlib_dir: str
|
||||
qlib(dump) data director
|
||||
backup_dir: str, default None
|
||||
if backup_dir is not None, backup qlib_dir to backup_dir
|
||||
freq: str, default "quarterly"
|
||||
data frequency
|
||||
max_workers: int, default 16
|
||||
number of threads
|
||||
date_column_name: str, default "date"
|
||||
the name of the date field in the csv
|
||||
file_suffix: str, default ".csv"
|
||||
file suffix
|
||||
include_fields: tuple
|
||||
dump fields
|
||||
exclude_fields: tuple
|
||||
fields not dumped
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
if isinstance(exclude_fields, str):
|
||||
exclude_fields = exclude_fields.split(",")
|
||||
if isinstance(include_fields, str):
|
||||
include_fields = include_fields.split(",")
|
||||
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self.file_suffix = file_suffix
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
if limit_nums is not None:
|
||||
self.csv_files = self.csv_files[: int(limit_nums)]
|
||||
self.qlib_dir = Path(qlib_dir).expanduser()
|
||||
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
|
||||
if backup_dir is not None:
|
||||
self._backup_qlib_dir(Path(backup_dir).expanduser())
|
||||
|
||||
self.works = max_workers
|
||||
self.date_column_name = date_column_name
|
||||
self.period_column_name = period_column_name
|
||||
self.value_column_name = value_column_name
|
||||
self.field_column_name = field_column_name
|
||||
|
||||
self._mode = self.ALL_MODE
|
||||
|
||||
def _backup_qlib_dir(self, target_dir: Path):
|
||||
shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))
|
||||
|
||||
def get_source_data(self, file_path: Path) -> pd.DataFrame:
|
||||
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
|
||||
df[self.value_column_name] = df[self.value_column_name].astype("float32")
|
||||
df[self.date_column_name] = df[self.date_column_name].str.replace("-", "").astype("int32")
|
||||
# df.drop_duplicates([self.date_field_name], inplace=True)
|
||||
return df
|
||||
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
|
||||
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
|
||||
|
||||
def get_dump_fields(self, df: Iterable[str]) -> Iterable[str]:
|
||||
return (
|
||||
set(self._include_fields)
|
||||
if self._include_fields
|
||||
else set(df[self.field_column_name]) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else set(df[self.field_column_name])
|
||||
)
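# Field selection precedence (reading the chained conditional above):
#   1) if include_fields is non-empty -> dump exactly those fields
#   2) elif exclude_fields is non-empty -> dump every field found in the csv minus the excluded ones
#   3) else -> dump every field found in the csv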
|
||||
|
||||
def get_filenames(self, symbol, field, interval):
|
||||
dir_name = self.qlib_dir.joinpath(self.PIT_DIR_NAME, symbol)
|
||||
dir_name.mkdir(parents=True, exist_ok=True)
|
||||
return (
|
||||
dir_name.joinpath(f"{field}_{interval[0]}{self.DATA_FILE_SUFFIX}".lower()),
|
||||
dir_name.joinpath(f"{field}_{interval[0]}{self.INDEX_FILE_SUFFIX}".lower()),
|
||||
)
|
||||
|
||||
def _dump_pit(
|
||||
self,
|
||||
file_path: str,
|
||||
interval: str = "quarterly",
|
||||
overwrite: bool = False,
|
||||
):
|
||||
"""
|
||||
dump data as the following format:
|
||||
`/path/to/<field>.data`
|
||||
[date, period, value, _next]
|
||||
[date, period, value, _next]
|
||||
[...]
|
||||
`/path/to/<field>.index`
|
||||
[first_year, index, index, ...]
|
||||
|
||||
`<field>.data` contains the data in point-in-time (PIT) order: the `value` for `period`
is published at `date`, and its successive revised values can be found via `_next` (a linked list).

`<field>.index` contains the index of the value for each period (quarter or year). To save
disk space, we only store `first_year`, as the following periods can easily be inferred.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path: str
path to the csv file of a single stock
|
||||
interval: str
|
||||
data interval
|
||||
overwrite: bool
|
||||
whether overwrite existing data or update only
|
||||
"""
|
||||
symbol = self.get_symbol_from_file(file_path)
try:
df = self.get_source_data(file_path)
except Exception as err:
# NOTE: return here, otherwise `df` would be undefined below and raise a NameError
logger.warning(f"failed to load {file_path}: {err}")
return
if df.empty:
logger.warning(f"{symbol} file is empty")
return
|
||||
for field in self.get_dump_fields(df):
|
||||
df_sub = df.query(f'{self.field_column_name}=="{field}"').sort_values(self.date_column_name)
|
||||
if df_sub.empty:
|
||||
logger.warning(f"field {field} of {symbol} is empty")
|
||||
continue
|
||||
data_file, index_file = self.get_filenames(symbol, field, interval)
|
||||
|
||||
## calculate first & last period
|
||||
start_year = df_sub[self.period_column_name].min()
|
||||
end_year = df_sub[self.period_column_name].max()
|
||||
if interval == self.INTERVAL_quarterly:
|
||||
start_year //= 100
|
||||
end_year //= 100
|
||||
|
||||
# adjust `first_year` if existing data found
|
||||
if not overwrite and index_file.exists():
|
||||
with open(index_file, "rb") as fi:
|
||||
(first_year,) = struct.unpack(self.PERIOD_DTYPE, fi.read(self.PERIOD_DTYPE_SIZE))
|
||||
n_years = len(fi.read()) // self.INDEX_DTYPE_SIZE
|
||||
if interval == self.INTERVAL_quarterly:
|
||||
n_years //= 4
|
||||
start_year = first_year + n_years
|
||||
else:
|
||||
with open(index_file, "wb") as f:
|
||||
f.write(struct.pack(self.PERIOD_DTYPE, start_year))
|
||||
first_year = start_year
|
||||
|
||||
# if data already exists, continue to the next field
|
||||
if start_year > end_year:
|
||||
logger.warning(f"{symbol}-{field} data already exists, continue to the next field")
|
||||
continue
|
||||
|
||||
# dump index filled with NA
|
||||
with open(index_file, "ab") as fi:
|
||||
for year in range(start_year, end_year + 1):
|
||||
if interval == self.INTERVAL_quarterly:
|
||||
fi.write(struct.pack(self.INDEX_DTYPE * 4, *[self.NA_INDEX] * 4))
|
||||
else:
|
||||
fi.write(struct.pack(self.INDEX_DTYPE, self.NA_INDEX))
|
||||
|
||||
# if data already exists, remove overlapped data
|
||||
if not overwrite and data_file.exists():
|
||||
with open(data_file, "rb") as fd:
|
||||
fd.seek(-self.DATA_DTYPE_SIZE, 2)
|
||||
last_date, _, _, _ = struct.unpack(self.DATA_DTYPE, fd.read())
|
||||
df_sub = df_sub.query(f"{self.date_column_name}>{last_date}")
|
||||
# otherwise,
|
||||
# 1) truncate existing file or create a new file with `wb+` if overwrite,
|
||||
# 2) or append existing file or create a new file with `ab+` if not overwrite
|
||||
else:
|
||||
with open(data_file, "wb+" if overwrite else "ab+"):
|
||||
pass
|
||||
|
||||
with open(data_file, "rb+") as fd, open(index_file, "rb+") as fi:
|
||||
|
||||
# update index if needed
|
||||
for i, row in df_sub.iterrows():
|
||||
# get index
|
||||
offset = get_period_offset(first_year, row.period, interval == self.INTERVAL_quarterly)
|
||||
|
||||
fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset)
|
||||
(cur_index,) = struct.unpack(self.INDEX_DTYPE, fi.read(self.INDEX_DTYPE_SIZE))
|
||||
|
||||
# Case I: new data => update `_next` with current index
|
||||
if cur_index == self.NA_INDEX:
|
||||
fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset)
|
||||
fi.write(struct.pack(self.INDEX_DTYPE, fd.tell()))
|
||||
# Case II: previous data exists => find and update the last `_next`
|
||||
else:
|
||||
_cur_fd = fd.tell()
|
||||
prev_index = self.NA_INDEX
|
||||
while cur_index != self.NA_INDEX: # NOTE: first iter always != NA_INDEX
|
||||
fd.seek(cur_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)
|
||||
prev_index = cur_index
|
||||
(cur_index,) = struct.unpack(self.INDEX_DTYPE, fd.read(self.INDEX_DTYPE_SIZE))
|
||||
fd.seek(prev_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)
|
||||
fd.write(struct.pack(self.INDEX_DTYPE, _cur_fd)) # NOTE: add _next pointer
|
||||
fd.seek(_cur_fd)
|
||||
|
||||
# dump data
|
||||
fd.write(struct.pack(self.DATA_DTYPE, row.date, row.period, row.value, self.NA_INDEX))
|
||||
|
||||
def dump(self, interval="quarterly", overwrite=False):
|
||||
logger.info("start dump pit data......")
|
||||
_dump_func = partial(self._dump_pit, interval=interval, overwrite=overwrite)
|
||||
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
for _ in executor.map(_dump_func, self.csv_files):
|
||||
p_bar.update()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.dump()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(DumpPitData)
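# A minimal usage sketch (hypothetical paths, mirroring the class defined above):
#
#   DumpPitData(
#       csv_path="~/.qlib/stock_data/source/pit",   # csv files with date/period/value/field columns
#       qlib_dir="~/.qlib/qlib_data/cn_data",
#   ).dump(interval="quarterly")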
|
@ -0,0 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import fire
from qlib.tests.data import GetData


if __name__ == "__main__":
    fire.Fire(GetData)
|
@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
import sys

sys.path.append('C:\\Program Files\\tinysoft\\Analyse.NET')
import TSLPy3 as ts


# Log in to the TinySoft server
def login():
    if not ts.Logined():
        a = ts.ConnectServer('tsl.tinysoft.com.cn', 443)
        b = ts.LoginServer('fsfundsh', 'fsfund')
        if a != 0 or b[0] != 0:
            raise Exception("Cannot connect to tsl server")


# Disconnect from TinySoft
def logoff():
    if ts.Logined():
        ts.Disconnect()


# Check whether the connection is established
def logined():
    return ts.Logined()
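# A minimal usage sketch (an assumption about typical use of the helpers above):
#   login()
#   # ... issue TSLPy3 queries here ...
#   logoff()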
|
@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# In[34]:
|
||||
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
from datetime import datetime
|
||||
def to_last_season(datestamp):
'''Convert the current quarter-end date to the previous quarter-end date'''
last_season = (pd.Series(datestamp) - pd.DateOffset(months=3)).apply(
lambda x: x + pd.to_timedelta('1 days') if x.month == 3 else x)
return last_season
|
||||
|
||||
def intdate(datestamp):
'''Convert a date to an 8-digit integer (YYYYMMDD)'''
return int(datestamp.strftime('%Y%m%d'))
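# A quick check of the helpers above (illustrative values, not from the source):
#   to_last_season(pd.Timestamp('2019-06-30'))[0] -> Timestamp('2019-03-31')
#   (subtracting 3 months gives 2019-03-30; the "+1 day" correction applies when the month is March)
#   intdate(pd.Timestamp('2019-03-31')) -> 20190331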
|
||||
|
||||
def non_rec_deduct_profit_single_growth(df):
profit = df['扣非净利润'].values
ss1 = pd.to_datetime(df['截止日'], format='%Y%m%d')  # current quarter end
ss2 = to_last_season(ss1)  # previous quarter end
ss3 = ss1 - pd.DateOffset(years=1)  # same quarter last year
ss4 = to_last_season(ss3)  # previous quarter of last year
profit_mat = []  # [current-quarter, previous-quarter, last-year same-quarter, last-year previous-quarter, single-quarter, last-year single-quarter] profit
single_profit_pctchg = []  # single-quarter profit, year-over-year change
|
||||
for i in range(ss1.shape[0]):
s1 = ss1[:i+1].max()  # latest report period available as of the current announcement date
s2 = to_last_season(s1)[0]
s3 = s1 - pd.DateOffset(years=1)
s4 = to_last_season(s3)[0]
latest_seasons = [intdate(s1), intdate(s2), intdate(s3), intdate(s4)]  # report periods needed for the latest YoY comparison
adjust_season = np.where(np.in1d(latest_seasons, intdate(ss1[i])))[0]
|
||||
|
||||
if len(adjust_season):
|
||||
adjust_season = adjust_season[0]
|
||||
if adjust_season == 0:
# latest report period: compute the latest year-over-year change
if ss2[i].month != 12:
|
||||
s2 = np.where(ss1[:i] == ss2[i])[0]
|
||||
s4 = np.where(ss1[:i] == ss4[i])[0]
|
||||
p2 = profit[s2[-1]] if len(s2) else np.nan
|
||||
p4 = profit[s4[-1]] if len(s4) else np.nan
|
||||
else:
|
||||
p2 = 0
|
||||
p4 = 0
|
||||
s3 = np.where(ss1[:i] == ss3[i])[0]
|
||||
p3 = profit[s3[-1]] if len(s3) else np.nan
|
||||
p1 = profit[i]
|
||||
profit_mat.append([p1, p2, p3, p4, p1-p2, p3-p4])
|
||||
else:
# not the latest report period, but the adjusted data falls within the range used by the latest period
profit_adjust = profit_mat[-1].copy()
profit_adjust[adjust_season] = 0 if latest_seasons[adjust_season] % 10000 == 1231 and adjust_season in [1, 3] else profit[i]  # replace the corresponding quarter with the latest data
|
||||
if adjust_season <= 1:
|
||||
profit_adjust[4] = profit_adjust[0] - profit_adjust[1]
|
||||
else:
|
||||
profit_adjust[5] = profit_adjust[2] - profit_adjust[3]
|
||||
profit_mat.append(profit_adjust)
|
||||
else:
# not the latest report period, and the adjusted data is outside the range used by the latest period
profit_mat.append(profit_mat[-1])
|
||||
|
||||
single_profit_pctchg.append((profit_mat[-1][4] - profit_mat[-1][5]) / abs(profit_mat[-1][5])
|
||||
if profit_mat[-1][5] != 0 else np.nan)
|
||||
profit_mat = pd.DataFrame(profit_mat, columns=['当季',
|
||||
'上季',
|
||||
'去年当季',
|
||||
'去年上季',
|
||||
'单季',
|
||||
'去年单季'])
|
||||
single_profit_pctchg = pd.Series(single_profit_pctchg, name='同比')
|
||||
|
||||
return pd.concat([pd.DataFrame({'StockID': df['StockID'],
|
||||
'截止日': ss1.apply(intdate),
|
||||
'公布日': df['公布日']}),
|
||||
profit_mat,
|
||||
single_profit_pctchg,
|
||||
df['公告类型']], axis=1)
|
||||
|
||||
# stocklist = os.listdir('D:/数据/天软基本面数据/42.主要财务指标')
|
||||
# df_list = []
|
||||
# for stock in stocklist:
|
||||
# df = pd.read_csv('D:/数据/天软基本面数据/42.主要财务指标/' + stock)
|
||||
# factor = non_rec_deduct_profit_single_growth(df)
|
||||
# df_list.append(factor)
|
||||
# factor_df = pd.concat(df_list, axis=1, join='outer')
|
||||
# factor_df.fillna(method='ffill', inplace=True)
|
||||
# factor_df = factor_df.join(factor, how='outer')
|
||||
|
||||
|
||||
# In[35]:
|
||||
|
||||
|
||||
# Earnings forecast
|
||||
# data = pd.read_csv('D:/数据/天软基本面数据/40.业绩预测/SZ000048.csv')
|
||||
# df1 = pd.DataFrame({'StockID': data['StockID'],
|
||||
# '截止日': data['截止日'],
|
||||
# '公布日': data['公布日'],
|
||||
# '归母净利润': (data['盈利金额上限'] + data['盈利金额下限']) * 10000/2,
|
||||
# '公告类型': '业绩预测'})
|
||||
# Earnings flash report (preliminary results)
|
||||
# data = pd.read_csv('D:/数据/天软基本面数据/41.业绩快报/SZ000048.csv')
|
||||
# df2 = pd.DataFrame({'StockID': data['StockID'],
|
||||
# '截止日': data['截止日'],
|
||||
# '公布日': data['公布日'],
|
||||
# '归母净利润': data['归属于母公司所有者净利润'],
|
||||
# '公告类型': '业绩快报'})
|
||||
# Official report
|
||||
filelist = os.listdir('D:/数据/天软基本面数据/46.合并利润分配表')
|
||||
for file in filelist:
|
||||
data1 = pd.read_csv('D:/数据/天软基本面数据/42.主要财务指标/' + file)
|
||||
data2 = pd.read_csv('D:/数据/天软基本面数据/46.合并利润分配表/' + file)
|
||||
A = pd.DataFrame({'StockID': data1['StockID'],
|
||||
'截止日': data1['截止日'],
|
||||
'公布日': data1['公布日'],
|
||||
'首次公布扣非净利润': data1['扣除非经常性损益后的净利润']})
|
||||
B = pd.DataFrame({'StockID': data2['StockID'],
|
||||
'截止日': data2['截止日'],
|
||||
'公布日': data2['公布日'],
|
||||
'归母净利润': data2['归属于母公司所有者净利润']})
|
||||
A.index = A['截止日'].astype('str') + A['公布日'].astype('str')
|
||||
B.index = B['截止日'].astype('str') + B['公布日'].astype('str')
|
||||
df3 = B.join(A['首次公布扣非净利润'], how='left')
|
||||
df3.sort_values(['公布日', '截止日'], inplace=True)
|
||||
df3.reset_index(drop=True, inplace=True)
|
||||
df3['非经常性损益'] = df3['归母净利润'] - df3['首次公布扣非净利润']
|
||||
df3['非经常性损益'] = df3.groupby('截止日')['非经常性损益'].transform(lambda x: x.fillna(method='ffill'))
|
||||
df3['扣非净利润'] = df3['归母净利润'] - df3['非经常性损益']
|
||||
df3 = df3[['StockID', '截止日', '公布日', '扣非净利润']]
|
||||
df3['公告类型'] = '正式报告'
|
||||
|
||||
# df = pd.concat([df1, df2, df3], axis=0)
|
||||
# df.sort_values(['公布日', '截止日'], inplace=True)
|
||||
# df.reset_index(inplace=True)
|
||||
result = non_rec_deduct_profit_single_growth(df3)
|
||||
result.to_csv('D:/数据/天软基本面数据/单季度扣非同比/' + file, index=False)
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,58 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
export_cols_bs = {
|
||||
'存货' : 'INVENTORIES',
|
||||
'负债合计' : 'TOT_LIAB',
|
||||
'流动资产合计' : 'TOT_CUR_ASSETS'
|
||||
}
|
||||
|
||||
export_cols_pl = {
|
||||
'净利润' : 'NET_PROFIT_INCL_MIN_INT_INC',
|
||||
'营业外收入' : 'PLUS_NON_OPER_REV',
|
||||
'营业收入' : 'OPER_REV'
|
||||
}
|
||||
|
||||
def adj_split(df):
|
||||
df.reset_index(inplace=True)
|
||||
|
||||
df_ori = df[df['截止日'] == df['数据报告期']]
|
||||
df_adj = df[df['截止日'] != df['数据报告期']]
|
||||
return df_ori, df_adj
|
||||
|
||||
|
||||
def to_qlib_format_pl(df):
|
||||
index_cols = ['截止日', '数据报告期', '公布日']
|
||||
sel_cols = index_cols + list(export_cols_pl)
|
||||
|
||||
df_export = df[sel_cols]
|
||||
df_export.set_index(index_cols, inplace=True)
|
||||
|
||||
df_export_ori, df_export_adj = adj_split(df_export)
|
||||
df_export_ori.set_index(index_cols, inplace=True)
|
||||
df_export_adj.set_index(index_cols, inplace=True)
|
||||
|
||||
adj_col_rename = {name : name+'(调整)' for name in export_cols_pl.keys()}
|
||||
df_export_adj.rename(columns=adj_col_rename, inplace=True)
|
||||
|
||||
df_list = []
|
||||
|
||||
def _T(df, df_list):
|
||||
for col in list(df.columns):
|
||||
df_tmp = df[[col]].copy(deep=True)
|
||||
df_tmp['field'] = col
|
||||
df_tmp.rename(columns={col:'value'}, inplace=True)
|
||||
df_list.append(df_tmp)
|
||||
|
||||
_T(df_export_adj, df_list)
|
||||
_T(df_export_ori, df_list)
|
||||
|
||||
df = pd.concat(df_list, axis=0)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def single_quarter(df):
# NOTE: unfinished in the source -- the original body was only the pandas `sort_values` signature.
# Kept as a placeholder that returns the frame sorted by announcement date ('公布日') when present.
return df.sort_values(by='公布日', ascending=True) if '公布日' in df.columns else df
|
||||
|
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large