You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
211 lines
7.2 KiB
211 lines
7.2 KiB
|
4 days ago
|
#!/usr/bin/env python
|
||
|
|
"""
|
||
|
|
Fetch embedding data from DolphinDB and save to parquet.
|
||
|
|
|
||
|
|
This script:
|
||
|
|
1. Connects to DolphinDB
|
||
|
|
2. Queries the dwm_1day_multicast_csencode table
|
||
|
|
3. Filters by version (default: DEFAULT_VERSION, currently 'csix_alpha158b_ext2_zscore_vae4')
|
||
|
|
4. Filters by date range
|
||
|
|
5. Transforms columns (m_nDate -> datetime, code -> instrument)
|
||
|
|
6. Saves to local parquet file
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import polars as pl
|
||
|
|
import pandas as pd
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
# DolphinDB config (from CLAUDE.md)
DDB_CONFIG = {
    "host": "192.168.1.146",
    "port": 8848,
    # NOTE: username/password are not passed to get_ddb_sess() in
    # fetch_embeddings() below — presumably the helper applies its own
    # defaults; confirm if auth ever fails.
    "username": "admin",
    "password": "123456"
}

# Distributed database path and the table holding daily embedding rows.
DB_PATH = "dfs://daily_stock_run_multicast"
TABLE_NAME = "dwm_1day_multicast_csencode"
# Default `version` filter applied in the WHERE clause of the query.
DEFAULT_VERSION = "csix_alpha158b_ext2_zscore_vae4"
# Inclusive date-range defaults, YYYY-MM-DD (converted to YYYY.MM.DD for DolphinDB).
DEFAULT_START_DATE = "2019-01-01"
DEFAULT_END_DATE = "2025-12-31"
# Output parquet path, relative to the process working directory.
OUTPUT_FILE = "../data/embeddings_from_ddb.parquet"
|
||
|
|
|
||
|
|
|
||
|
|
def _connect_ddb():
    """Open a DolphinDB session using DDB_CONFIG; re-raises on failure."""
    try:
        # Project-local helper; imported lazily so the module imports even
        # when qshare is absent.
        from qshare.io.ddb import get_ddb_sess
        sess = get_ddb_sess(host=DDB_CONFIG["host"], port=DDB_CONFIG["port"])
        print(f"Connected to DolphinDB at {DDB_CONFIG['host']}:{DDB_CONFIG['port']}")
        return sess
    except Exception as e:
        print(f"Error connecting to DolphinDB: {e}")
        raise


def _normalize_datetime(df: pl.DataFrame) -> pl.DataFrame:
    """Rename m_nDate -> datetime and normalize it to uint32 YYYYMMDD."""
    if 'm_nDate' in df.columns:
        df = df.rename({"m_nDate": "datetime"})

    dtype = df["datetime"].dtype
    if dtype in (pl.Datetime, pl.Date):
        # Temporal -> "YYYYMMDD" string -> uint32.
        expr = pl.col("datetime").dt.strftime("%Y%m%d").cast(pl.UInt32)
    elif dtype in (pl.Utf8, pl.String):
        # replace_all, not replace: a "YYYY-MM-DD" string contains two dashes,
        # and str.replace only strips the first one, which would make the
        # UInt32 cast fail on "201901-01".
        expr = pl.col("datetime").str.replace_all("-", "").cast(pl.UInt32)
    else:
        # Already numeric (e.g. int YYYYMMDD) — just cast.
        expr = pl.col("datetime").cast(pl.UInt32)
    return df.with_columns(expr.alias("datetime"))


def _normalize_instrument(df: pl.DataFrame) -> pl.DataFrame:
    """Rename code -> instrument and strip the exchange prefix (SH/SZ/BJ),
    leaving the numeric ticker as uint32. No-op if 'code' is absent."""
    if 'code' not in df.columns:
        return df
    df = df.rename({"code": "instrument"})
    # Each code carries exactly one prefix (e.g. 'SH600085'), so a single
    # first-occurrence replace per prefix is sufficient.
    return df.with_columns(
        pl.col("instrument")
        .str.replace("SH", "")
        .str.replace("SZ", "")
        .str.replace("BJ", "")
        .cast(pl.UInt32)
        .alias("instrument")
    )


def _shape_embedding_columns(df: pl.DataFrame) -> pl.DataFrame:
    """Produce the final [datetime, instrument, embedding_*] column layout.

    If a list-typed 'values' column is present, it is exploded into
    embedding_0..embedding_{d-1} (dimension taken from the first row);
    otherwise all non-core columns are kept, sorted by name.
    """
    if 'values' in df.columns and df['values'].dtype == pl.List:
        first_val = df['values'][0]
        if first_val is not None:
            emb_dim = len(first_val)
            print(f"Detected embedding dimension: {emb_dim}")

            embedding_cols = [f"embedding_{i}" for i in range(emb_dim)]
            # Single with_columns pass instead of one pass per dimension.
            df = df.with_columns([
                pl.col('values').list.get(i).alias(name)
                for i, name in enumerate(embedding_cols)
            ])
            df = df.drop('values')
            df = df.select(['datetime', 'instrument'] + embedding_cols)
            print(f"Expanded embeddings into {emb_dim} columns")
        # If the first value is null we cannot infer the dimension; leave the
        # frame untouched (original behavior).
        return df

    # Already-wide layout: keep core columns plus everything else, sorted.
    core_cols = ['datetime', 'instrument']
    embedding_cols = [c for c in df.columns if c not in core_cols + ['version']]
    return df.select(core_cols + sorted(embedding_cols))


def fetch_embeddings(
    start_date: str = DEFAULT_START_DATE,
    end_date: str = DEFAULT_END_DATE,
    version: str = DEFAULT_VERSION,
    output_file: str = OUTPUT_FILE
) -> pl.DataFrame:
    """
    Fetch embedding data from DolphinDB, transform it, and save to parquet.

    Args:
        start_date: Inclusive start date filter (YYYY-MM-DD)
        end_date: Inclusive end date filter (YYYY-MM-DD)
        version: Version string to filter by
        output_file: Output parquet file path

    Returns:
        Polars DataFrame with columns: [datetime, instrument, embedding_0, embedding_1, ...]

    Raises:
        Exception: connection or query failures are printed and re-raised.
    """
    print("=" * 60)
    print("Fetching embedding data from DolphinDB")
    print("=" * 60)
    print(f"Database: {DB_PATH}")
    print(f"Table: {TABLE_NAME}")
    print(f"Version: {version}")
    print(f"Date range: {start_date} to {end_date}")

    sess = _connect_ddb()

    # DolphinDB date literals use dots: YYYY.MM.DD.
    start_ddb = start_date.replace("-", ".")
    end_ddb = end_date.replace("-", ".")

    # NOTE(review): the query is built by string interpolation. `version` is
    # placed inside a double-quoted literal, so a value containing '"' would
    # break (or alter) the query. Acceptable for trusted CLI input; do not
    # expose this function to untrusted callers without escaping.
    sql = f'select * from loadTable("{DB_PATH}", "{TABLE_NAME}") where version = "{version}" and m_nDate >= date({start_ddb}) and m_nDate <= date({end_ddb})'

    print(f"Executing SQL: {sql.strip()}")

    try:
        # Execute query; DolphinDB returns a pandas DataFrame.
        df_pd = sess.run(sql)
        print(f"Fetched {len(df_pd)} rows from DolphinDB")
        print(f"Columns: {df_pd.columns.tolist()}")
        if len(df_pd) > 0:
            print(f"Sample:\n{df_pd.head()}")
    except Exception as e:
        print(f"Error executing query: {e}")
        raise
    finally:
        # Always release the session, even on query failure.
        sess.close()

    # Convert to Polars for the column transforms below.
    df = pl.from_pandas(df_pd)
    print(f"Columns in result: {df.columns}")

    df = _normalize_datetime(df)
    df = _normalize_instrument(df)

    # Drop version column if present (constant after filtering; no longer needed).
    if 'version' in df.columns:
        df = df.drop('version')

    df = _shape_embedding_columns(df)

    print(f"\nTransformed data:")
    print(f" Shape: {df.shape}")
    print(f" Columns: {df.columns[:10]}..." if len(df.columns) > 10 else f" Columns: {df.columns}")
    print(f" Date range: {df['datetime'].min()} to {df['datetime'].max()}")
    print(f" Instrument count: {df['instrument'].n_unique()}")
    print(f" Sample:\n{df.head()}")

    # Save to parquet. Guard dirname(): it is "" for a bare filename, and
    # os.makedirs("") raises.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.write_parquet(output_file)
    print(f"\nSaved to: {output_file}")

    return df
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    import argparse

    # CLI flags mirror fetch_embeddings() keyword arguments one-to-one.
    cli = argparse.ArgumentParser(description="Fetch embedding data from DolphinDB")
    cli.add_argument("--start-date", type=str, default=DEFAULT_START_DATE,
                     help="Start date (YYYY-MM-DD)")
    cli.add_argument("--end-date", type=str, default=DEFAULT_END_DATE,
                     help="End date (YYYY-MM-DD)")
    cli.add_argument("--version", type=str, default=DEFAULT_VERSION,
                     help="Version string to filter by")
    cli.add_argument("--output", type=str, default=OUTPUT_FILE,
                     help="Output parquet file")
    opts = cli.parse_args()

    df = fetch_embeddings(
        start_date=opts.start_date,
        end_date=opts.end_date,
        version=opts.version,
        output_file=opts.output,
    )

    print("\nDone!")
|