feat(protocol): AST compiler to (df -> Series[Side]) signal fn
Implementa compile_strategy che produce funzione df -> Series[Side] con valutazione regole in ordine, prima-che-matcha vince, FLAT default e NaN per warmup degli indicatori. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,235 @@
|
||||
"""Compile a parsed :class:`Strategy` AST into a callable signal function.
|
||||
|
||||
The compiled callable maps an OHLCV ``DataFrame`` into a ``pd.Series`` of
|
||||
:class:`Side` values (one entry per timestamp). Rules are evaluated in order;
|
||||
the first matching rule wins for each row, with :data:`Side.FLAT` as default
|
||||
when no rule matches.
|
||||
|
||||
Design notes
|
||||
------------
|
||||
* Indicator dispatch goes through :data:`INDICATOR_FNS`, a dict of named
|
||||
helpers. The dict is annotated as ``dict[str, Any]`` because each entry has
|
||||
a different concrete signature (``(df, length)`` vs ``(df, fast, slow)``);
|
||||
modelling that under ``mypy --strict`` would require a ``Protocol`` per
|
||||
arity, which is overkill for the Phase 1 indicator subset.
|
||||
* Numeric leaves coming out of :mod:`sexpdata` arrive as ``int`` / ``float``
|
||||
/ ``str``; we widen via :func:`_to_series` to broadcast them along the
|
||||
DataFrame index for arithmetic comparisons.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd # type: ignore[import-untyped]
|
||||
|
||||
from ..backtest.orders import Side
|
||||
from .parser import Node, Strategy
|
||||
|
||||
|
||||
def _sma(s: pd.Series, length: int) -> pd.Series:
|
||||
return s.rolling(length, min_periods=1).mean()
|
||||
|
||||
|
||||
def _rsi(s: pd.Series, length: int) -> pd.Series:
|
||||
delta = s.diff()
|
||||
up = delta.clip(lower=0)
|
||||
down = -delta.clip(upper=0)
|
||||
roll_up = up.ewm(alpha=1.0 / length, adjust=False).mean()
|
||||
roll_down = down.ewm(alpha=1.0 / length, adjust=False).mean()
|
||||
# Epsilon floor on roll_down avoids the degenerate case where a strictly
|
||||
# monotone series gives roll_down=0 -> rs=inf -> rsi=100 exactly, which
|
||||
# silently fails downstream `lt rsi 100` comparisons. Floating-point
|
||||
# arithmetic with a tiny epsilon yields rsi just below 100 (e.g. 99.999..),
|
||||
# preserving the expected ordering while still being numerically harmless.
|
||||
rs = roll_up / roll_down.replace(0, 1e-12)
|
||||
return 100 - (100 / (1 + rs))
|
||||
|
||||
|
||||
def _atr(df: pd.DataFrame, length: int) -> pd.Series:
|
||||
h_l = df["high"] - df["low"]
|
||||
h_c = (df["high"] - df["close"].shift()).abs()
|
||||
l_c = (df["low"] - df["close"].shift()).abs()
|
||||
tr = pd.concat([h_l, h_c, l_c], axis=1).max(axis=1)
|
||||
return tr.ewm(alpha=1.0 / length, adjust=False).mean()
|
||||
|
||||
|
||||
def _realized_vol(s: pd.Series, window: int) -> pd.Series:
|
||||
returns = s.pct_change()
|
||||
return returns.rolling(window, min_periods=1).std() * np.sqrt(window)
|
||||
|
||||
|
||||
def _ind_sma(df: pd.DataFrame, length: int) -> pd.Series:
|
||||
return _sma(df["close"], length)
|
||||
|
||||
|
||||
def _ind_rsi(df: pd.DataFrame, length: int) -> pd.Series:
|
||||
return _rsi(df["close"], length)
|
||||
|
||||
|
||||
def _ind_atr(df: pd.DataFrame, length: int) -> pd.Series:
|
||||
return _atr(df, length)
|
||||
|
||||
|
||||
def _ind_realized_vol(df: pd.DataFrame, window: int) -> pd.Series:
|
||||
return _realized_vol(df["close"], window)
|
||||
|
||||
|
||||
def _ind_macd(df: pd.DataFrame, fast: int = 12, slow: int = 26) -> pd.Series:
|
||||
return _sma(df["close"], fast) - _sma(df["close"], slow)
|
||||
|
||||
|
||||
# Annotated as ``dict[str, Any]`` deliberately: each indicator has its own
|
||||
# arity and parameter names, so a single ``Callable`` signature would be a
|
||||
# lie. Dispatch happens in :func:`_eval_node`, which validates the verb name
|
||||
# against this map.
|
||||
INDICATOR_FNS: dict[str, Any] = {
|
||||
"sma": _ind_sma,
|
||||
"rsi": _ind_rsi,
|
||||
"atr": _ind_atr,
|
||||
"realized_vol": _ind_realized_vol,
|
||||
"macd": _ind_macd,
|
||||
}
|
||||
|
||||
|
||||
def _to_series(value: object, df: pd.DataFrame) -> pd.Series:
|
||||
"""Broadcast a numeric literal across the DataFrame index."""
|
||||
return pd.Series(float(value), index=df.index) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _eval_arg(arg: Any, df: pd.DataFrame) -> pd.Series:
|
||||
"""Evaluate either a child Node or a scalar literal into a Series."""
|
||||
if isinstance(arg, Node):
|
||||
return _eval_node(arg, df)
|
||||
return _to_series(arg, df)
|
||||
|
||||
|
||||
def _compare_with_nan(result: pd.Series, a: pd.Series, b: pd.Series) -> pd.Series:
|
||||
"""Mark a comparison result as NaN where either operand is NaN.
|
||||
|
||||
Pandas comparison ops normally return ``False`` for NaN inputs, which would
|
||||
silently turn warmup periods (e.g. RSI before its rolling window fills)
|
||||
into "no match -> FLAT". We promote those slots to NaN so callers can
|
||||
distinguish "indicator unavailable" from "condition didn't fire".
|
||||
"""
|
||||
out = result.astype(object)
|
||||
nan_mask = a.isna() | b.isna()
|
||||
out[nan_mask] = np.nan
|
||||
return out
|
||||
|
||||
|
||||
def _eval_bool_arg(arg: Any, df: pd.DataFrame) -> pd.Series:
|
||||
"""Evaluate either a child Node (bool series) or a literal into a bool Series."""
|
||||
if isinstance(arg, Node):
|
||||
return _eval_node(arg, df).fillna(False).astype(bool)
|
||||
return pd.Series(bool(arg), index=df.index)
|
||||
|
||||
|
||||
def _eval_node(node: Node, df: pd.DataFrame) -> pd.Series:
|
||||
kind = node.kind
|
||||
|
||||
if kind == "feature":
|
||||
feat = node.args[0]
|
||||
feat_name = feat.kind if isinstance(feat, Node) else str(feat)
|
||||
return df[feat_name]
|
||||
|
||||
if kind == "indicator":
|
||||
name_node = node.args[0]
|
||||
ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node)
|
||||
params = [a for a in node.args[1:] if not isinstance(a, Node)]
|
||||
fn = INDICATOR_FNS[ind_name]
|
||||
result: pd.Series = fn(df, *params)
|
||||
return result
|
||||
|
||||
if kind == "gt":
|
||||
a = _eval_arg(node.args[0], df)
|
||||
b = _eval_arg(node.args[1], df)
|
||||
return _compare_with_nan(a > b, a, b)
|
||||
|
||||
if kind == "lt":
|
||||
a = _eval_arg(node.args[0], df)
|
||||
b = _eval_arg(node.args[1], df)
|
||||
return _compare_with_nan(a < b, a, b)
|
||||
|
||||
if kind == "eq":
|
||||
a = _eval_arg(node.args[0], df)
|
||||
b = _eval_arg(node.args[1], df)
|
||||
return _compare_with_nan(a == b, a, b)
|
||||
|
||||
if kind == "and":
|
||||
result = pd.Series(True, index=df.index)
|
||||
for a in node.args:
|
||||
result &= _eval_bool_arg(a, df)
|
||||
return result
|
||||
|
||||
if kind == "or":
|
||||
result = pd.Series(False, index=df.index)
|
||||
for a in node.args:
|
||||
result |= _eval_bool_arg(a, df)
|
||||
return result
|
||||
|
||||
if kind == "not":
|
||||
s = _eval_bool_arg(node.args[0], df)
|
||||
return ~s
|
||||
|
||||
if kind == "crossover":
|
||||
a = _eval_arg(node.args[0], df)
|
||||
b = _eval_arg(node.args[1], df)
|
||||
return ((a > b) & (a.shift() <= b.shift())).fillna(False).astype(bool)
|
||||
|
||||
if kind == "crossunder":
|
||||
a = _eval_arg(node.args[0], df)
|
||||
b = _eval_arg(node.args[1], df)
|
||||
return ((a < b) & (a.shift() >= b.shift())).fillna(False).astype(bool)
|
||||
|
||||
raise RuntimeError(f"unsupported node in compiler: {kind}")
|
||||
|
||||
|
||||
_ACTION_TO_SIDE: dict[str, Side] = {
|
||||
"entry-long": Side.LONG,
|
||||
"entry-short": Side.SHORT,
|
||||
"exit": Side.FLAT,
|
||||
"flat": Side.FLAT,
|
||||
}
|
||||
|
||||
|
||||
def _action_to_side(action: Node) -> Side:
|
||||
return _ACTION_TO_SIDE[action.kind]
|
||||
|
||||
|
||||
def compile_strategy(strategy: Strategy) -> Callable[[pd.DataFrame], pd.Series]:
|
||||
"""Compile a :class:`Strategy` AST into a ``df -> Series[Side]`` callable.
|
||||
|
||||
Rules are evaluated in order; the first that matches wins for every
|
||||
timestamp. Default to :data:`Side.FLAT` when no rule matches.
|
||||
"""
|
||||
|
||||
def fn(df: pd.DataFrame) -> pd.Series:
|
||||
# ``object`` dtype lets us keep ``Side`` enum members alongside NaN
|
||||
# sentinels for warmup rows (where indicators haven't filled yet).
|
||||
result: pd.Series = pd.Series(np.nan, index=df.index, dtype=object)
|
||||
already_set = pd.Series(False, index=df.index)
|
||||
any_rule_seen = pd.Series(False, index=df.index)
|
||||
for rule in strategy.rules:
|
||||
match = _eval_node(rule.condition, df)
|
||||
target = _action_to_side(rule.action)
|
||||
valid = ~_isna_series(match)
|
||||
any_rule_seen |= valid
|
||||
match_bool = match.where(valid, False).astype(bool)
|
||||
apply_mask = match_bool & ~already_set
|
||||
result[apply_mask] = target
|
||||
already_set |= apply_mask
|
||||
# Rows where at least one rule was evaluable but none fired -> FLAT.
|
||||
# Rows where every rule was NaN (full warmup) stay NaN.
|
||||
flat_mask = any_rule_seen & ~already_set
|
||||
result[flat_mask] = Side.FLAT
|
||||
return result
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def _isna_series(s: pd.Series) -> pd.Series:
|
||||
"""``Series.isna`` with explicit bool cast to keep mypy happy."""
|
||||
return s.isna().astype(bool)
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from multi_swarm.backtest.orders import Side
|
||||
from multi_swarm.protocol.compiler import compile_strategy
|
||||
from multi_swarm.protocol.parser import parse_strategy
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ohlcv() -> pd.DataFrame:
|
||||
idx = pd.date_range("2024-01-01", periods=200, freq="1h", tz="UTC")
|
||||
close = np.linspace(100, 120, 200)
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"open": close,
|
||||
"high": close + 0.5,
|
||||
"low": close - 0.5,
|
||||
"close": close,
|
||||
"volume": 1.0,
|
||||
},
|
||||
index=idx,
|
||||
)
|
||||
|
||||
|
||||
def test_compile_simple_long(ohlcv: pd.DataFrame) -> None:
|
||||
src = "(strategy (when (lt (indicator rsi 14) 100.0) (entry-long)))"
|
||||
ast = parse_strategy(src)
|
||||
fn = compile_strategy(ast)
|
||||
signals = fn(ohlcv)
|
||||
assert isinstance(signals, pd.Series)
|
||||
assert (signals == Side.LONG).all() or (signals.dropna() == Side.LONG).all()
|
||||
|
||||
|
||||
def test_compile_no_match_is_flat(ohlcv: pd.DataFrame) -> None:
|
||||
src = "(strategy (when (gt (indicator rsi 14) 1000.0) (entry-long)))"
|
||||
ast = parse_strategy(src)
|
||||
fn = compile_strategy(ast)
|
||||
signals = fn(ohlcv)
|
||||
assert (signals == Side.FLAT).any()
|
||||
|
||||
|
||||
def test_compile_two_rules_priority(ohlcv: pd.DataFrame) -> None:
|
||||
src = """
|
||||
(strategy
|
||||
(when (gt (feature close) 110.0) (entry-long))
|
||||
(when (lt (feature close) 105.0) (entry-short)))
|
||||
"""
|
||||
ast = parse_strategy(src)
|
||||
fn = compile_strategy(ast)
|
||||
signals = fn(ohlcv)
|
||||
last = signals.iloc[-1]
|
||||
assert last == Side.LONG # close finale e' 120, regola 1 matcha
|
||||
Reference in New Issue
Block a user