# NOTE: removed web-page chrome ("topics" help text, "211 lines", "7.2 KiB")
# that was accidentally captured above the script when it was copied from Gitea.
#!/usr/bin/env python
"""
Fetch embedding data from DolphinDB and save to parquet.
This script:
1. Connects to DolphinDB
2. Queries the dwm_1day_multicast_csencode table
3. Filters by version (default: DEFAULT_VERSION, currently 'csix_alpha158b_ext2_zscore_vae4')
4. Filters by date range
5. Transforms columns (m_nDate -> datetime, code -> instrument)
6. Saves to local parquet file
"""
import os
import polars as pl
import pandas as pd
from datetime import datetime
from typing import Optional
# DolphinDB config (from CLAUDE.md)
# NOTE(review): credentials are hard-coded in plaintext — consider moving them
# to environment variables or a secrets store before this file is shared.
DDB_CONFIG = {
    "host": "192.168.1.146",
    "port": 8848,
    "username": "admin",
    "password": "123456"
}
# Distributed database path and the source table queried by this script.
DB_PATH = "dfs://daily_stock_run_multicast"
TABLE_NAME = "dwm_1day_multicast_csencode"
# Default version tag used to filter rows in the source table.
DEFAULT_VERSION = "csix_alpha158b_ext2_zscore_vae4"
# Inclusive date-range defaults (YYYY-MM-DD).
DEFAULT_START_DATE = "2019-01-01"
DEFAULT_END_DATE = "2025-12-31"
# Output path, relative to this script's working directory.
OUTPUT_FILE = "../data/embeddings_from_ddb.parquet"
def _query_ddb(start_date: str, end_date: str, version: str):
    """Connect to DolphinDB, run the filtered select, and return a pandas frame.

    Raises whatever exception the connection or query produces, after logging it.
    The session is always closed, even on success or failure.
    """
    try:
        from qshare.io.ddb import get_ddb_sess
        sess = get_ddb_sess(host=DDB_CONFIG["host"], port=DDB_CONFIG["port"])
        print(f"Connected to DolphinDB at {DDB_CONFIG['host']}:{DDB_CONFIG['port']}")
    except Exception as e:
        print(f"Error connecting to DolphinDB: {e}")
        raise

    # DolphinDB date literals use dots: date(2019.01.01).
    start_ddb = start_date.replace("-", ".")
    end_ddb = end_date.replace("-", ".")
    # Single-line SQL to avoid DolphinDB parsing issues.
    # NOTE(review): values are interpolated directly into the SQL string; fine
    # for trusted internal parameters, but never pass untrusted input here.
    sql = (
        f'select * from loadTable("{DB_PATH}", "{TABLE_NAME}")'
        f' where version = "{version}"'
        f' and m_nDate >= date({start_ddb}) and m_nDate <= date({end_ddb})'
    )
    print(f"Executing SQL: {sql.strip()}")
    try:
        df_pd = sess.run(sql)
        print(f"Fetched {len(df_pd)} rows from DolphinDB")
        print(f"Columns: {df_pd.columns.tolist()}")
        if len(df_pd) > 0:
            print(f"Sample:\n{df_pd.head()}")
        return df_pd
    except Exception as e:
        print(f"Error executing query: {e}")
        raise
    finally:
        sess.close()


def _normalize_datetime(df: pl.DataFrame) -> pl.DataFrame:
    """Rename ``m_nDate`` -> ``datetime`` and convert it to uint32 YYYYMMDD."""
    if "m_nDate" not in df.columns:
        return df
    df = df.rename({"m_nDate": "datetime"})
    dtype = df["datetime"].dtype
    if dtype in (pl.Datetime, pl.Date):
        expr = pl.col("datetime").dt.strftime("%Y%m%d").cast(pl.UInt32)
    elif dtype in (pl.Utf8, pl.String):
        # BUGFIX: polars str.replace only removes the FIRST match, so
        # "2019-01-01" would become "201901-01" and the UInt32 cast would
        # fail.  replace_all strips every dash.
        expr = pl.col("datetime").str.replace_all("-", "", literal=True).cast(pl.UInt32)
    else:
        # Already numeric (e.g. int YYYYMMDD from the source table).
        expr = pl.col("datetime").cast(pl.UInt32)
    return df.with_columns(expr.alias("datetime"))


def _normalize_instrument(df: pl.DataFrame) -> pl.DataFrame:
    """Rename ``code`` -> ``instrument`` and convert 'SH600085'-style codes to uint32."""
    if "code" not in df.columns:
        return df
    df = df.rename({"code": "instrument"})
    # Anchored regex strips exactly one leading exchange prefix.  The original
    # chained replaces would also remove 'SH'/'SZ'/'BJ' occurring mid-string.
    return df.with_columns(
        pl.col("instrument")
        .str.replace(r"^(SH|SZ|BJ)", "")
        .cast(pl.UInt32)
        .alias("instrument")
    )


def _expand_embeddings(df: pl.DataFrame) -> pl.DataFrame:
    """Expand a list-typed ``values`` column into embedding_0..embedding_{d-1}.

    Falls back to selecting whatever feature columns already exist when there
    is no usable ``values`` list column.
    """
    core_cols = ["datetime", "instrument"]
    if "values" in df.columns and df["values"].dtype == pl.List:
        first_val = df["values"][0]
        if first_val is not None:
            emb_dim = len(first_val)
            print(f"Detected embedding dimension: {emb_dim}")
            embedding_cols = [f"embedding_{i}" for i in range(emb_dim)]
            # One with_columns call for all dimensions — the original rebuilt
            # the frame once per dimension.
            df = df.with_columns([
                pl.col("values").list.get(i).alias(name)
                for i, name in enumerate(embedding_cols)
            ])
            df = df.drop("values")
            print(f"Expanded embeddings into {emb_dim} columns")
            return df.select(core_cols + embedding_cols)
        # 'values' exists but the first row is null — fall through to the
        # generic selection.  NOTE(review): the original silently kept the raw
        # frame in this case; confirm whether this branch is ever hit.
    # Embeddings already stored as separate columns (feature_0, emb_0, ...):
    # keep the core columns plus everything else, sorted for stable ordering.
    embedding_cols = sorted(c for c in df.columns if c not in core_cols + ["version"])
    return df.select(core_cols + embedding_cols)


def fetch_embeddings(
    start_date: str = DEFAULT_START_DATE,
    end_date: str = DEFAULT_END_DATE,
    version: str = DEFAULT_VERSION,
    output_file: str = OUTPUT_FILE
) -> pl.DataFrame:
    """
    Fetch embedding data from DolphinDB, normalise it, and save to parquet.

    Args:
        start_date: Start date filter (YYYY-MM-DD), inclusive.
        end_date: End date filter (YYYY-MM-DD), inclusive.
        version: Version string to filter by.
        output_file: Output parquet file path.

    Returns:
        Polars DataFrame with columns [datetime, instrument, embedding_0, ...]
        where ``datetime`` is uint32 YYYYMMDD and ``instrument`` is the uint32
        stock code with its exchange prefix stripped.

    Raises:
        Exception: re-raises any connection or query error from DolphinDB.
    """
    print("=" * 60)
    print("Fetching embedding data from DolphinDB")
    print("=" * 60)
    print(f"Database: {DB_PATH}")
    print(f"Table: {TABLE_NAME}")
    print(f"Version: {version}")
    print(f"Date range: {start_date} to {end_date}")

    df_pd = _query_ddb(start_date, end_date, version)

    df = pl.from_pandas(df_pd)
    print(f"Columns in result: {df.columns}")

    df = _normalize_datetime(df)
    df = _normalize_instrument(df)

    # Drop version column if present — constant after filtering, no longer needed.
    if "version" in df.columns:
        df = df.drop("version")

    df = _expand_embeddings(df)

    print(f"\nTransformed data:")
    print(f"  Shape: {df.shape}")
    print(f"  Columns: {df.columns[:10]}..." if len(df.columns) > 10 else f"  Columns: {df.columns}")
    print(f"  Date range: {df['datetime'].min()} to {df['datetime'].max()}")
    print(f"  Instrument count: {df['instrument'].n_unique()}")
    print(f"  Sample:\n{df.head()}")

    # BUGFIX: os.path.dirname("plain_name.parquet") is "", and os.makedirs("")
    # raises FileNotFoundError — only create a directory when one is present.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.write_parquet(output_file)
    print(f"\nSaved to: {output_file}")
    return df
if __name__ == "__main__":
    import argparse

    # Table-driven CLI definition: every option is a string with a module-level
    # default, so one loop covers them all.
    parser = argparse.ArgumentParser(description="Fetch embedding data from DolphinDB")
    for flag, default, help_text in (
        ("--start-date", DEFAULT_START_DATE, "Start date (YYYY-MM-DD)"),
        ("--end-date", DEFAULT_END_DATE, "End date (YYYY-MM-DD)"),
        ("--version", DEFAULT_VERSION, "Version string to filter by"),
        ("--output", OUTPUT_FILE, "Output parquet file"),
    ):
        parser.add_argument(flag, type=str, default=default, help=help_text)
    opts = parser.parse_args()

    fetch_embeddings(
        start_date=opts.start_date,
        end_date=opts.end_date,
        version=opts.version,
        output_file=opts.output,
    )
    print("\nDone!")