You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
1.9 KiB
64 lines
1.9 KiB
|
3 weeks ago
|
"""Label blending utilities for CTA experiments."""
|
||
|
|
|
||
|
|
from typing import Union, List
|
||
|
|
|
||
|
|
|
||
|
|
# Predefined blend configurations, keyed by name.
# Every weight list is ordered as [zscore, cs_zscore, roll20, roll60]
# and sums to 1.0.
BLEND_CONFIGS = {
    'equal':         [0.25, 0.25, 0.25, 0.25],
    'zscore_heavy':  [0.5, 0.2, 0.15, 0.15],
    'rolling_heavy': [0.1, 0.1, 0.3, 0.5],
    'cs_heavy':      [0.2, 0.5, 0.15, 0.15],
    'short_term':    [0.1, 0.1, 0.4, 0.4],
    'long_term':     [0.4, 0.2, 0.2, 0.2],
}

# Blend used when no config is requested; same component order as above.
DEFAULT_BLEND = [0.2, 0.1, 0.3, 0.4]  # [zscore, cs_zscore, roll20, roll60]
|
||
|
|
|
||
|
|
|
||
|
|
def get_blend_weights(weights: Union[str, List[float], None]) -> List[float]:
    """Resolve blend weights from a config name, an explicit sequence, or None.

    Args:
        weights: Name of an entry in ``BLEND_CONFIGS``, a list/tuple of 4
            floats, or None to use ``DEFAULT_BLEND``.

    Returns:
        A new list of 4 weights summing to 1.0. A fresh copy is always
        returned so callers can mutate it without altering the module-level
        presets.

    Raises:
        ValueError: If the name is unknown, the sequence does not have exactly
            4 values, the values do not sum to 1.0 (within 1e-6; a NaN sum is
            rejected), or ``weights`` is of an unsupported type.
    """
    if weights is None:
        # Copy so the shared module-level default cannot be mutated by callers.
        return list(DEFAULT_BLEND)

    if isinstance(weights, str):
        if weights not in BLEND_CONFIGS:
            raise ValueError(f"Unknown blend config: {weights}. "
                             f"Available: {list(BLEND_CONFIGS.keys())}")
        # Copy for the same reason as above.
        return list(BLEND_CONFIGS[weights])

    if isinstance(weights, (list, tuple)):
        if len(weights) != 4:
            raise ValueError(f"Blend weights must have 4 values, got {len(weights)}")
        total = sum(weights)
        # Phrased as "not within tolerance" so a NaN total (from any NaN
        # weight) is rejected: `abs(nan - 1.0) > 1e-6` evaluates to False.
        if not abs(total - 1.0) <= 1e-6:
            raise ValueError(f"Blend weights must sum to 1.0, got {total}")
        return list(weights)

    raise ValueError(f"Invalid blend weights type: {type(weights)}")
|
||
|
|
|
||
|
|
|
||
|
|
def describe_blend_config(weights: Union[str, List[float], None]) -> str:
    """Get a human-readable description of a blend config.

    Args:
        weights: Config name (resolved via ``get_blend_weights``), an explicit
            list/tuple of weights, or None for the default blend (mirroring
            ``get_blend_weights``).

    Returns:
        ``"<name>: {...}"`` for a named config, ``"default: {...}"`` for None,
        or ``"custom: {...}"`` for explicit weights, where ``{...}`` maps each
        component name to its weight.

    Raises:
        ValueError: If ``weights`` is a config name that is not known.
    """
    names = ['zscore', 'cs_zscore', 'rolling_20', 'rolling_60']

    if weights is None:
        # get_blend_weights treats None as "use the default blend"; describe
        # it the same way instead of failing inside zip().
        label, resolved = 'default', get_blend_weights(None)
    elif isinstance(weights, str):
        label, resolved = weights, get_blend_weights(weights)
    else:
        # Explicit weights are described as given, without validation.
        # NOTE: zip() stops at the shorter input, so a sequence whose length
        # differs from len(names) is shown truncated rather than rejected.
        label, resolved = 'custom', weights

    return f"{label}: {dict(zip(names, resolved))}"
|