You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
alpha_lab/compare_embeddings.py

294 lines
11 KiB

#!/usr/bin/env python3
"""
Compare generated embeddings with database embeddings (0_7 version).
Handles format conversion for datetime and instrument columns.
SUMMARY OF FINDINGS:
- Generated embeddings and database embeddings have DIFFERENT values
- Instrument mapping: 430xxx -> SHxxxxx, 830xxx -> SZxxxxx, 6xxxxx -> SH6xxxxx
- Correlation between corresponding dimensions: ~0.0067 (essentially zero)
- The generated embeddings are NOT the same as the database 0_7 embeddings
- Possible reasons:
1. Different model weights/versions used for generation
2. Different input features or normalization
3. Different random seed or inference configuration
"""
import polars as pl
import numpy as np
from pathlib import Path
def instrument_int_to_code(inst_int: int) -> str:
    """Convert integer instrument code to exchange-prefixed string.

    The encoding in the embedding file uses:
    - 4xxxxx -> SHxxxxxx (Shanghai A-shares, but code mapping is non-trivial)
    - 8xxxxx -> SZxxxxxx (Shenzhen A-shares)
    - Direct 6-digit codes are also present (600xxx, 000xxx, 300xxx)

    Note: The exact mapping from 430017 -> SH600021 requires the original
    features file. We attempt an approximate mapping here.
    """
    code = str(inst_int)
    if len(code) != 6:
        # Unexpected width: hand the raw digits back for caller inspection.
        return code
    leading = code[0]
    if leading in ('4', '8'):
        # Leading digit encodes the exchange (4=SH, 8=SZ). The 430xxx->600xxx
        # mapping is not 1:1, so just strip the marker and keep the rest.
        prefix = 'SH' if leading == '4' else 'SZ'
        return f"{prefix}{code[1:]}"
    # Plain 6-digit code: 6xxxxx trades on Shanghai, everything else Shenzhen.
    return f"SH{code}" if leading == '6' else f"SZ{code}"
def load_generated_embedding(date_int: int, sample_n: "int | None" = None):
    """Load locally generated embedding rows for a single date.

    Args:
        date_int: Trading date encoded as an integer (e.g. 20190102).
        sample_n: Optional cap on the number of rows loaded; None/0 loads all.

    Returns:
        A polars DataFrame with the file's original columns plus:
        - 'values': the per-row embedding as a list of floats
        - 'datetime_uint32', 'instrument_orig', 'instrument_str',
          'instrument_code': helper columns for matching against the database.
    """
    gen_path = Path('/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/embedding_0_7_beta.parquet')
    lf = pl.scan_parquet(gen_path).filter(pl.col('datetime') == date_int)
    if sample_n:
        lf = lf.head(sample_n)
    df = lf.collect()
    # Convert wide format (embedding_0, embedding_1, ...) to list format.
    # Sort numerically so 'embedding_10' does not order before 'embedding_2'.
    embedding_cols = sorted(
        (c for c in df.columns if c.startswith('embedding_')),
        key=lambda c: int(c.split('_')[1]),
    )
    # FIX: rows() yields one tuple per row in selected-column order -- simpler
    # and cheaper than materializing a struct Series and iterating dicts.
    embeddings_list = [list(row) for row in df.select(embedding_cols).rows()]
    df = df.with_columns([
        pl.Series('values', embeddings_list),
        pl.col('datetime').cast(pl.UInt32).alias('datetime_uint32'),
        pl.col('instrument').alias('instrument_orig'),
        pl.col('instrument').cast(pl.String).alias('instrument_str'),
        pl.col('instrument').map_elements(instrument_int_to_code, return_dtype=pl.String).alias('instrument_code'),
    ])
    return df
def load_database_embedding(date_str: str):
    """Load the database-side (0_7 version) embedding for one date.

    Returns the DataFrame with an extra integer 'datetime_int' column, or
    None when no parquet file exists for the requested date.
    """
    db_path = Path(f'/data/parquet/dataset/dwm_1day_multicast_csencode_1D/version=csiallx_feature2_ntrla_flag_pnlnorm_vae4_dim32a_beta0001/datetime={date_str}/0.parquet')
    if not db_path.exists():
        return None
    # Keep an integer copy of the datetime so it can be compared against the
    # generated file's integer dates.
    return pl.read_parquet(db_path).with_columns(
        pl.col('datetime').cast(pl.Int64).alias('datetime_int')
    )
def analyze_instrument_mapping(date_int: int):
    """Analyze the instrument mapping between generated and database embeddings.

    Prints row counts, samples, and set-overlap statistics for one date.
    When common instruments exist, joins the two frames and reports
    per-instrument embedding differences; otherwise prints samples from both
    sides to help discover the mapping pattern manually.
    """
    date_str = str(date_int)
    print(f"\n{'='*80}")
    print(f"Analyzing instrument mapping for date: {date_int}")
    print(f"{'='*80}")
    gen_df = load_generated_embedding(date_int)
    db_df = load_database_embedding(date_str)
    if db_df is None:
        print(f"ERROR: Database embedding not found for {date_str}")
        return
    print(f"\nGenerated embeddings: {gen_df.shape[0]} rows")
    print(f"Database embeddings: {db_df.shape[0]} rows")
    # Show samples
    print("\n--- Generated Embedding Sample ---")
    sample_gen = gen_df.select(['datetime', 'instrument_orig', 'instrument_str', 'instrument_code', 'values']).head(10)
    print(sample_gen)
    print("\n--- Database Embedding Sample ---")
    print(db_df.head(10))
    # Set overlap after converting generated integer codes to exchange form.
    gen_insts_set = set(gen_df['instrument_code'].to_list())
    db_insts_set = set(db_df['instrument'].to_list())
    common = gen_insts_set & db_insts_set
    gen_only = gen_insts_set - db_insts_set
    db_only = db_insts_set - gen_insts_set
    print(f"\n--- Matching Results (with code conversion) ---")
    print(f"Common instruments: {len(common)}")
    print(f"Generated only: {len(gen_only)}")
    print(f"Database only: {len(db_only)}")
    if len(common) == 0:
        print("\nNo common instruments found with code conversion!")
        print("\nTrying to find mapping patterns...")
        # Show some samples for analysis
        print("\nGenerated instrument samples (original, converted):")
        gen_samples = list(zip(gen_df['instrument_orig'].head(20).to_list(),
                               gen_df['instrument_code'].head(20).to_list()))
        for orig, conv in gen_samples:
            print(f" {orig} -> {conv}")
        print("\nDatabase instrument samples:")
        db_samples = db_df['instrument'].head(20).to_list()
        for inst in db_samples:
            print(f" {inst}")
        # Check if there's a position-based alignment possible:
        # sort both sides and compare by position.
        gen_sorted = sorted(gen_df['instrument_orig'].to_list())
        db_sorted = sorted([int(inst[2:]) for inst in db_df['instrument'].to_list()])
        print("\n--- Attempting position-based matching ---")
        print(f"Generated sorted (first 10): {gen_sorted[:10]}")
        print(f"Database sorted (first 10): {db_sorted[:10]}")
    else:
        # We have matches, compare embeddings
        print(f"\n--- Comparing embeddings for {len(common)} common instruments ---")
        gen_common = gen_df.filter(pl.col('instrument_code').is_in(list(common)))
        db_common = db_df.filter(pl.col('instrument').is_in(list(common)))
        # Join on converted code; the db side's clashing columns get '_db'.
        # NOTE(review): assumes the database frame also has a 'values' column
        # (renamed 'values_db' by the join suffix) -- confirm against schema.
        comparison = gen_common.join(
            db_common,
            left_on='instrument_code',
            right_on='instrument',
            how='inner',
            suffix='_db'
        )
        # FIX: column-index lookups are loop-invariant -- hoist them out of
        # the per-row loop instead of recomputing them for every row.
        gen_vals_idx = comparison.columns.index('values')
        db_vals_idx = comparison.columns.index('values_db')
        inst_idx = comparison.columns.index('instrument_code')
        # Per-instrument difference metrics between the two embeddings.
        diffs = []
        for row in comparison.iter_rows():
            gen_emb = np.array(row[gen_vals_idx])
            db_emb = np.array(row[db_vals_idx])
            diff = gen_emb - db_emb
            diff_norm = np.linalg.norm(diff)
            # Epsilon guards against division by zero for an all-zero db row.
            rel_diff = diff_norm / (np.linalg.norm(db_emb) + 1e-10)
            diffs.append({
                'instrument': row[inst_idx],
                'l2_norm_diff': diff_norm,
                'relative_diff': rel_diff,
                'max_abs_diff': np.max(np.abs(diff)),
                'gen_emb_norm': np.linalg.norm(gen_emb),
                'db_emb_norm': np.linalg.norm(db_emb)
            })
        if diffs:
            diff_df = pl.DataFrame(diffs)
            print("\nDifference statistics:")
            print(diff_df.select(['l2_norm_diff', 'relative_diff', 'max_abs_diff']).describe())
            max_rel_diff = diff_df['relative_diff'].max()
            print(f"\nMax relative difference: {max_rel_diff:.6e}")
            if max_rel_diff < 1e-5:
                print("✓ Embeddings match within numerical precision!")
            elif max_rel_diff < 0.01:
                print("~ Embeddings are very similar")
            else:
                print("✗ Embeddings differ significantly")
            # Show some comparison samples
            print("\nSample comparison:")
            for i in range(min(5, len(diffs))):
                d = diffs[i]
                print(f" {d['instrument']}: gen_norm={d['gen_emb_norm']:.4f}, "
                      f"db_norm={d['db_emb_norm']:.4f}, rel_diff={d['relative_diff']:.6e}")
def calculate_correlation(date_int: int):
    """Calculate correlation between generated and database embeddings.

    Aligns the two sources on converted instrument codes, computes the
    Pearson correlation per embedding dimension and over all values
    flattened, then prints summary statistics and a verdict.
    """
    # NOTE: removed redundant in-function 'import numpy as np' -- numpy is
    # already imported at module level.
    date_str = str(date_int)
    print(f"\n{'='*80}")
    print(f"Correlation Analysis for date: {date_int}")
    print(f"{'='*80}")
    gen_df = load_generated_embedding(date_int)
    db_df = load_database_embedding(date_str)
    if db_df is None:
        print(f"ERROR: Database embedding not found for {date_str}")
        return
    # Find common instruments
    gen_insts = set(gen_df['instrument_code'].to_list())
    db_insts = set(db_df['instrument'].to_list())
    common = list(gen_insts & db_insts)
    print(f"\nCommon instruments: {len(common)}")
    if len(common) == 0:
        print("No common instruments found!")
        return
    # Filter to common instruments and sort both sides by the same key so
    # row i of each matrix refers to the same instrument.
    gen_common = gen_df.filter(pl.col('instrument_code').is_in(common)).sort('instrument_code')
    db_common = db_df.filter(pl.col('instrument').is_in(common)).sort('instrument')
    # Extract embedding matrices (rows = instruments, cols = dimensions).
    gen_embs = np.array(gen_common['values'].to_list())
    db_embs = np.array(db_common['values'].to_list())
    print(f"Generated embeddings shape: {gen_embs.shape}")
    print(f"Database embeddings shape: {db_embs.shape}")
    # FIX: derive the dimension count from the data instead of hard-coding 32,
    # so the analysis works for any embedding width.
    n_dims = gen_embs.shape[1]
    # Calculate correlation per dimension
    correlations = []
    for i in range(n_dims):
        corr = np.corrcoef(gen_embs[:, i], db_embs[:, i])[0, 1]
        correlations.append(corr)
    print(f"\nCorrelation statistics across {n_dims} dimensions:")
    print(f" Mean: {np.mean(correlations):.4f}")
    print(f" Median: {np.median(correlations):.4f}")
    print(f" Min: {np.min(correlations):.4f}")
    print(f" Max: {np.max(correlations):.4f}")
    # Overall correlation
    overall_corr = np.corrcoef(gen_embs.flatten(), db_embs.flatten())[0, 1]
    print(f"\nOverall correlation (all dims flattened): {overall_corr:.4f}")
    # Interpretation
    mean_corr = np.mean(correlations)
    if abs(mean_corr) < 0.1:
        print("\n✗ CONCLUSION: Embeddings are NOT correlated (essentially independent)")
    elif abs(mean_corr) < 0.5:
        print("\n~ CONCLUSION: Weak correlation between embeddings")
    else:
        print(f"\n✓ CONCLUSION: {'Strong' if abs(mean_corr) > 0.8 else 'Moderate'} correlation")
if __name__ == '__main__':
    # Run the comparison for a few representative trading dates; keep going
    # on failure so one bad date does not abort the remaining analyses.
    for date in (20190102, 20200102, 20240102):
        try:
            analyze_instrument_mapping(date)
            calculate_correlation(date)
        except Exception as e:
            print(f"\nError analyzing date {date}: {e}")
            import traceback
            traceback.print_exc()