You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
294 lines
11 KiB
294 lines
11 KiB
#!/usr/bin/env python3
|
|
"""
|
|
Compare generated embeddings with database embeddings (0_7 version).
|
|
Handles format conversion for datetime and instrument columns.
|
|
|
|
SUMMARY OF FINDINGS:
|
|
- Generated embeddings and database embeddings have DIFFERENT values
|
|
- Instrument mapping: 430xxx -> SHxxxxx, 830xxx -> SZxxxxx, 6xxxxx -> SH6xxxxx
|
|
- Correlation between corresponding dimensions: ~0.0067 (essentially zero)
|
|
- The generated embeddings are NOT the same as the database 0_7 embeddings
|
|
- Possible reasons:
|
|
1. Different model weights/versions used for generation
|
|
2. Different input features or normalization
|
|
3. Different random seed or inference configuration
|
|
"""
|
|
import polars as pl
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
def instrument_int_to_code(inst_int: int) -> str:
    """Convert an integer instrument code to an exchange-prefixed string.

    The decimal string of ``inst_int`` is left-padded with zeros to six
    digits when it is purely numeric, then the following rules apply:

    - Six characters starting with ``4`` or ``8``: the leading digit is
      treated as an exchange marker (``4`` -> ``SH``, ``8`` -> ``SZ``) and
      dropped, giving the prefix plus the remaining five digits.
    - Any other six-character code is kept whole and prefixed with ``SH``
      when it starts with ``6``, otherwise with ``SZ``.
    - Anything else is returned unchanged as a string.

    Note: The exact mapping from 430017 -> SH600021 requires the original
    features file, so the 4/8 branch is only an approximate mapping used
    for matching attempts.

    Args:
        inst_int: Instrument code stored as an integer.

    Returns:
        The exchange-prefixed code, or the plain string form when no rule
        applies.
    """
    inst_str = str(inst_int)

    # Integer storage drops leading zeros (000001 is stored as 1); restore
    # them so such codes reach the six-character rules below.
    if inst_str.isdigit():
        inst_str = inst_str.zfill(6)

    if len(inst_str) != 6:
        return inst_str

    lead = inst_str[0]
    if lead in ('4', '8'):
        # The mapping from 430xxx -> 600xxx is not 1:1; return the code
        # without its marker digit for matching attempts.
        exchange = 'SH' if lead == '4' else 'SZ'
        return f"{exchange}{inst_str[1:]}"

    return f"SH{inst_str}" if lead == '6' else f"SZ{inst_str}"
|
|
|
|
def load_generated_embedding(date_int: int, sample_n: int = None):
    """Read the generated embedding rows for one date and add helper columns.

    The parquet file is scanned lazily, filtered to rows whose ``datetime``
    equals ``date_int`` and, when ``sample_n`` is truthy, truncated to the
    first ``sample_n`` rows before being collected.

    Args:
        date_int: Value compared against the ``datetime`` column.
        sample_n: Optional row limit; ``None`` or ``0`` keeps every row.

    Returns:
        The collected frame with these extra columns: ``values`` (the
        ``embedding_*`` columns gathered into one list per row, ordered by
        their numeric suffix), ``datetime_uint32``, ``instrument_orig``,
        ``instrument_str`` and ``instrument_code``.
    """
    parquet_file = Path('/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/embedding_0_7_beta.parquet')

    query = pl.scan_parquet(parquet_file).filter(pl.col('datetime') == date_int)
    if sample_n:
        query = query.head(sample_n)
    frame = query.collect()

    # Wide layout (embedding_0, embedding_1, ...) -> one list per row,
    # ordered by the integer that follows the first underscore.
    dim_columns = sorted(
        (name for name in frame.columns if name.startswith('embedding_')),
        key=lambda name: int(name.split('_')[1]),
    )
    per_row_vectors = [
        list(record.values())
        for record in frame.select(dim_columns).to_struct()
    ]

    instrument = pl.col('instrument')
    return frame.with_columns([
        pl.Series('values', per_row_vectors),
        pl.col('datetime').cast(pl.UInt32).alias('datetime_uint32'),
        instrument.alias('instrument_orig'),
        instrument.cast(pl.String).alias('instrument_str'),
        instrument.map_elements(instrument_int_to_code, return_dtype=pl.String).alias('instrument_code'),
    ])
|
|
|
|
def load_database_embedding(date_str: str):
    """Read the database embedding partition for one date.

    Args:
        date_str: Date string substituted into the ``datetime=...`` partition
            directory of the dataset path.

    Returns:
        The partition's ``0.parquet`` as a frame with an added
        ``datetime_int`` column (``datetime`` cast to Int64), or ``None``
        when that file does not exist.
    """
    partition_file = Path(f'/data/parquet/dataset/dwm_1day_multicast_csencode_1D/version=csiallx_feature2_ntrla_flag_pnlnorm_vae4_dim32a_beta0001/datetime={date_str}/0.parquet')

    if partition_file.exists():
        table = pl.read_parquet(partition_file)
        return table.with_columns([
            pl.col('datetime').cast(pl.Int64).alias('datetime_int')
        ])

    # Missing partition: callers check for None and report it.
    return None
|
|
|
|
def analyze_instrument_mapping(date_int: int):
    """Print a diagnostic report comparing generated and database instruments.

    Loads both embedding tables for ``date_int``, prints row counts and
    sample rows, and intersects the converted generated codes
    (``instrument_code``) with the database ``instrument`` column.  With no
    overlap it prints sample codes from both sides plus a sorted preview of
    each; otherwise it joins the tables on the instrument code and prints
    per-instrument difference statistics between the ``values`` vectors.

    Args:
        date_int: Date as an integer (e.g. 20190102); its string form
            selects the database partition.

    Returns:
        None.  All results are printed.
    """
    date_str = str(date_int)

    print(f"\n{'='*80}")
    print(f"Analyzing instrument mapping for date: {date_int}")
    print(f"{'='*80}")

    gen_df = load_generated_embedding(date_int)
    db_df = load_database_embedding(date_str)

    if db_df is None:
        print(f"ERROR: Database embedding not found for {date_str}")
        return

    print(f"\nGenerated embeddings: {gen_df.shape[0]} rows")
    print(f"Database embeddings: {db_df.shape[0]} rows")

    # Show samples
    print("\n--- Generated Embedding Sample ---")
    sample_gen = gen_df.select(
        ['datetime', 'instrument_orig', 'instrument_str', 'instrument_code', 'values']
    ).head(10)
    print(sample_gen)

    print("\n--- Database Embedding Sample ---")
    print(db_df.head(10))

    # Match on the converted code
    gen_insts_set = set(gen_df['instrument_code'].to_list())
    db_insts_set = set(db_df['instrument'].to_list())

    common = gen_insts_set & db_insts_set
    gen_only = gen_insts_set - db_insts_set
    db_only = db_insts_set - gen_insts_set

    print("\n--- Matching Results (with code conversion) ---")
    print(f"Common instruments: {len(common)}")
    print(f"Generated only: {len(gen_only)}")
    print(f"Database only: {len(db_only)}")

    if not common:
        print("\nNo common instruments found with code conversion!")
        print("\nTrying to find mapping patterns...")

        # Show some samples for analysis
        print("\nGenerated instrument samples (original, converted):")
        gen_samples = zip(
            gen_df['instrument_orig'].head(20).to_list(),
            gen_df['instrument_code'].head(20).to_list(),
        )
        for orig, conv in gen_samples:
            print(f"  {orig} -> {conv}")

        print("\nDatabase instrument samples:")
        for inst in db_df['instrument'].head(20).to_list():
            print(f"  {inst}")

        # Sort both sides to eyeball whether a position-based alignment is
        # possible; database codes are compared without their first two
        # characters.
        gen_sorted = sorted(gen_df['instrument_orig'].to_list())
        db_sorted = sorted(int(inst[2:]) for inst in db_df['instrument'].to_list())

        print("\n--- Attempting position-based matching ---")
        print(f"Generated sorted (first 10): {gen_sorted[:10]}")
        print(f"Database sorted (first 10): {db_sorted[:10]}")
        return

    # We have matches, compare embeddings
    print(f"\n--- Comparing embeddings for {len(common)} common instruments ---")

    common_list = list(common)
    gen_common = gen_df.filter(pl.col('instrument_code').is_in(common_list))
    db_common = db_df.filter(pl.col('instrument').is_in(common_list))

    # Columns present on both sides get the '_db' suffix on the database side
    comparison = gen_common.join(
        db_common,
        left_on='instrument_code',
        right_on='instrument',
        how='inner',
        suffix='_db',
    )

    # Column positions are the same for every row, so look them up once
    # instead of once per row.
    columns = comparison.columns
    gen_vals_idx = columns.index('values')
    db_vals_idx = columns.index('values_db')
    inst_idx = columns.index('instrument_code')

    diffs = []
    for row in comparison.iter_rows():
        gen_emb = np.array(row[gen_vals_idx])
        db_emb = np.array(row[db_vals_idx])

        diff = gen_emb - db_emb
        diff_norm = np.linalg.norm(diff)
        # Small epsilon keeps the ratio finite for an all-zero database vector
        rel_diff = diff_norm / (np.linalg.norm(db_emb) + 1e-10)

        diffs.append({
            'instrument': row[inst_idx],
            'l2_norm_diff': diff_norm,
            'relative_diff': rel_diff,
            'max_abs_diff': np.max(np.abs(diff)),
            'gen_emb_norm': np.linalg.norm(gen_emb),
            'db_emb_norm': np.linalg.norm(db_emb),
        })

    if diffs:
        diff_df = pl.DataFrame(diffs)
        print("\nDifference statistics:")
        print(diff_df.select(['l2_norm_diff', 'relative_diff', 'max_abs_diff']).describe())

        max_rel_diff = diff_df['relative_diff'].max()
        print(f"\nMax relative difference: {max_rel_diff:.6e}")

        if max_rel_diff < 1e-5:
            print("✓ Embeddings match within numerical precision!")
        elif max_rel_diff < 0.01:
            print("~ Embeddings are very similar")
        else:
            print("✗ Embeddings differ significantly")

        # Show some comparison samples
        print("\nSample comparison:")
        for d in diffs[:5]:
            print(f"  {d['instrument']}: gen_norm={d['gen_emb_norm']:.4f}, "
                  f"db_norm={d['db_emb_norm']:.4f}, rel_diff={d['relative_diff']:.6e}")
|
|
|
|
def calculate_correlation(date_int: int):
    """Print correlation statistics between generated and database embeddings.

    Both tables are restricted to the instruments they share and sorted by
    instrument code so that rows line up.  A Pearson correlation is computed
    for every embedding dimension, plus one over all values flattened, and a
    short conclusion is printed based on the mean per-dimension correlation.

    Args:
        date_int: Date as an integer (e.g. 20190102); its string form
            selects the database partition.

    Returns:
        None.  All results are printed.

    Raises:
        ValueError: If the two embedding matrices do not have the same
            shape, in which case rows or dimensions cannot be paired.
    """
    date_str = str(date_int)

    print(f"\n{'='*80}")
    print(f"Correlation Analysis for date: {date_int}")
    print(f"{'='*80}")

    gen_df = load_generated_embedding(date_int)
    db_df = load_database_embedding(date_str)

    if db_df is None:
        print(f"ERROR: Database embedding not found for {date_str}")
        return

    # Find common instruments
    gen_insts = set(gen_df['instrument_code'].to_list())
    db_insts = set(db_df['instrument'].to_list())
    common = list(gen_insts & db_insts)

    print(f"\nCommon instruments: {len(common)}")

    if not common:
        print("No common instruments found!")
        return

    # Filter to common and sort so row i refers to the same instrument on
    # both sides
    gen_common = gen_df.filter(pl.col('instrument_code').is_in(common)).sort('instrument_code')
    db_common = db_df.filter(pl.col('instrument').is_in(common)).sort('instrument')

    # Extract embedding matrices
    gen_embs = np.array(gen_common['values'].to_list())
    db_embs = np.array(db_common['values'].to_list())

    print(f"Generated embeddings shape: {gen_embs.shape}")
    print(f"Database embeddings shape: {db_embs.shape}")

    # Differing shapes (e.g. duplicated instrument rows on one side) would
    # pair unrelated values; fail with a clear message instead.
    if gen_embs.shape != db_embs.shape:
        raise ValueError(
            f"Embedding matrices are not aligned: generated {gen_embs.shape} "
            f"vs database {db_embs.shape}"
        )

    # Use the actual embedding width rather than assuming 32 dimensions
    n_dims = gen_embs.shape[1]
    correlations = [
        np.corrcoef(gen_embs[:, i], db_embs[:, i])[0, 1]
        for i in range(n_dims)
    ]

    print(f"\nCorrelation statistics across {n_dims} dimensions:")
    print(f"  Mean:   {np.mean(correlations):.4f}")
    print(f"  Median: {np.median(correlations):.4f}")
    print(f"  Min:    {np.min(correlations):.4f}")
    print(f"  Max:    {np.max(correlations):.4f}")

    # Overall correlation
    overall_corr = np.corrcoef(gen_embs.flatten(), db_embs.flatten())[0, 1]
    print(f"\nOverall correlation (all dims flattened): {overall_corr:.4f}")

    # Interpretation
    mean_corr = np.mean(correlations)
    if np.isnan(mean_corr):
        # NaN compares False against every threshold and would otherwise be
        # reported as a "Moderate" correlation.
        print("\n? CONCLUSION: Correlation is undefined (NaN); "
              "need at least two instruments and non-constant dimensions")
    elif abs(mean_corr) < 0.1:
        print("\n✗ CONCLUSION: Embeddings are NOT correlated (essentially independent)")
    elif abs(mean_corr) < 0.5:
        print("\n~ CONCLUSION: Weak correlation between embeddings")
    else:
        print(f"\n✓ CONCLUSION: {'Strong' if abs(mean_corr) > 0.8 else 'Moderate'} correlation")
|
|
|
|
if __name__ == '__main__':
    import traceback

    # Run both analyses for a few dates; a failure on one date is reported
    # with its traceback and does not stop the remaining dates.
    for trade_date in (20190102, 20200102, 20240102):
        try:
            analyze_instrument_mapping(trade_date)
            calculate_correlation(trade_date)
        except Exception as exc:
            print(f"\nError analyzing date {trade_date}: {exc}")
            traceback.print_exc()
|