@dataclass(frozen=True)
class HypothesisProposal:
    """Outcome of one ``HypothesisAgent.propose()`` call.

    As built by ``propose()``, ``completions`` holds the first attempt
    followed by one entry per retry that was actually performed, and
    ``n_attempts == len(completions)``. ``raw_text`` is the LAST LLM output
    observed (the one that yielded ``strategy``, or the one behind the final
    parse error).

    The list passed as ``completions`` is copied at construction time, so
    mutating the caller's list afterwards cannot alter this frozen record.
    An empty ``completions`` (the field default) is still accepted for
    backward compatibility; in that case ``n_attempts`` is not checked.

    Raises:
        ValueError: if ``completions`` is non-empty and ``n_attempts`` does
            not match its length.
    """

    # Parsed and validated strategy, or None when no attempt produced one.
    strategy: Strategy | None
    # Text of the last LLM completion that was inspected.
    raw_text: str
    # One CompletionResult per LLM call, in call order.
    completions: list[CompletionResult] = field(default_factory=list)
    # Failure description; None when ``strategy`` was produced.
    parse_error: str | None = None
    # Number of LLM calls made for this proposal.
    n_attempts: int = 1

    def __post_init__(self) -> None:
        # frozen=True blocks normal assignment, hence object.__setattr__.
        # Copying decouples the record from the caller's mutable list.
        object.__setattr__(self, "completions", list(self.completions))
        if self.completions and self.n_attempts != len(self.completions):
            raise ValueError(
                f"n_attempts ({self.n_attempts}) must equal "
                f"len(completions) ({len(self.completions)})"
            )
def _try_parse(text: str) -> tuple[Strategy | None, str | None]:
    """Extract, parse and validate a strategy from raw LLM output.

    Returns a ``(strategy, error)`` pair. On success the error is ``None``;
    on failure the strategy is ``None`` and the error is either a fixed
    "no JSON" message or the text of the parse/validation exception.
    """
    json_blob = _extract_json(text)
    if json_blob is not None:
        try:
            candidate = parse_strategy(json_blob)
            validate_strategy(candidate)
        except (ParseError, ValidationError) as exc:
            # Only these two failure types are turned into an error string.
            return None, str(exc)
        else:
            return candidate, None
    return None, "no JSON object found in output"
range(max_attempts): + if attempt == 0: + user = original_user + else: + truncated = last_raw[:_RETRY_RAW_TRUNCATE] + user = _RETRY_TEMPLATE.format( + original_user=original_user, + previous_raw=truncated, + previous_error=errors[-1], + ) + + completion = self._llm.complete(genome, system=system, user=user) + completions.append(completion) + last_raw = completion.text + + strategy, err = _try_parse(completion.text) + if strategy is not None: + return HypothesisProposal( + strategy=strategy, + raw_text=completion.text, + completions=completions, + parse_error=None, + n_attempts=len(completions), + ) + assert err is not None + errors.append(err) + + chained = " | ".join( + f"attempt {i + 1}: {e}" for i, e in enumerate(errors) + ) + return HypothesisProposal( + strategy=None, + raw_text=last_raw, + completions=completions, + parse_error=chained, + n_attempts=len(completions), + ) diff --git a/src/multi_swarm/orchestrator/run.py b/src/multi_swarm/orchestrator/run.py index 9b7504b..ee23e67 100644 --- a/src/multi_swarm/orchestrator/run.py +++ b/src/multi_swarm/orchestrator/run.py @@ -99,21 +99,23 @@ def run_phase1( continue # elite gia' valutata in generazione precedente repo.save_genome(run_id=run_id, generation_idx=gen, genome=genome) proposal = hypothesis_agent.propose(genome, market) - cost_record = cost_tracker.record( - input_tokens=proposal.completion.input_tokens, - output_tokens=proposal.completion.output_tokens, - tier=proposal.completion.tier, - run_id=run_id, - agent_id=genome.id, - ) - repo.save_cost_record( - run_id=run_id, - agent_id=genome.id, - tier=cost_record.tier.value, - input_tokens=cost_record.input_tokens, - output_tokens=cost_record.output_tokens, - cost_usd=cost_record.cost_usd, - ) + # Registra costo per OGNI completion (incluse retry). 
+ for completion in proposal.completions: + cost_record = cost_tracker.record( + input_tokens=completion.input_tokens, + output_tokens=completion.output_tokens, + tier=completion.tier, + run_id=run_id, + agent_id=genome.id, + ) + repo.save_cost_record( + run_id=run_id, + agent_id=genome.id, + tier=cost_record.tier.value, + input_tokens=cost_record.input_tokens, + output_tokens=cost_record.output_tokens, + cost_usd=cost_record.cost_usd, + ) if proposal.strategy is None: repo.save_evaluation( diff --git a/tests/unit/test_hypothesis_agent.py b/tests/unit/test_hypothesis_agent.py index c37dd54..7050473 100644 --- a/tests/unit/test_hypothesis_agent.py +++ b/tests/unit/test_hypothesis_agent.py @@ -60,7 +60,8 @@ def test_hypothesis_agent_calls_llm_and_parses(mocker): # type: ignore[no-untyp agent = HypothesisAgent(llm=fake_llm) proposal = agent.propose(make_genome(), make_summary()) assert proposal.strategy is not None - assert proposal.completion.input_tokens == 200 + assert proposal.completions[0].input_tokens == 200 + assert proposal.n_attempts == 1 fake_llm.complete.assert_called_once() @@ -73,10 +74,12 @@ def test_hypothesis_agent_returns_none_on_parse_error(mocker): # type: ignore[n tier=ModelTier.C, model="qwen", ) - agent = HypothesisAgent(llm=fake_llm) + agent = HypothesisAgent(llm=fake_llm, max_retries=0) proposal = agent.propose(make_genome(), make_summary()) assert proposal.strategy is None assert proposal.parse_error is not None + assert proposal.n_attempts == 1 + assert fake_llm.complete.call_count == 1 def test_hypothesis_agent_extracts_json_from_markdown_fence(mocker): # type: ignore[no-untyped-def] @@ -123,8 +126,91 @@ def test_hypothesis_agent_returns_error_on_invalid_strategy(mocker): # type: ig tier=ModelTier.C, model="qwen", ) - agent = HypothesisAgent(llm=fake_llm) + agent = HypothesisAgent(llm=fake_llm, max_retries=0) proposal = agent.propose(make_genome(), make_summary()) assert proposal.strategy is None assert proposal.parse_error is not None 
assert "wibble" in proposal.parse_error or "unknown" in proposal.parse_error + + +def test_hypothesis_agent_retries_on_parse_error_and_succeeds(mocker): # type: ignore[no-untyped-def] + """Primo output malformato → secondo output valido → strategia accettata.""" + fake_llm = mocker.MagicMock() + fake_llm.complete.side_effect = [ + CompletionResult( + text="this is not JSON at all", + input_tokens=200, + output_tokens=80, + tier=ModelTier.C, + model="qwen", + ), + CompletionResult( + text="```json\n" + VALID_STRATEGY_JSON + "\n```", + input_tokens=300, + output_tokens=120, + tier=ModelTier.C, + model="qwen", + ), + ] + agent = HypothesisAgent(llm=fake_llm, max_retries=1) + proposal = agent.propose(make_genome(), make_summary()) + assert proposal.strategy is not None + assert proposal.n_attempts == 2 + assert len(proposal.completions) == 2 + assert proposal.completions[0].input_tokens == 200 + assert proposal.completions[1].input_tokens == 300 + assert fake_llm.complete.call_count == 2 + # Il secondo prompt user deve contenere il marker corrective. 
def test_hypothesis_agent_gives_up_after_max_retries(mocker):  # type: ignore[no-untyped-def]
    """Both attempts fail: strategy is None and the errors are chained."""
    # (text, input_tokens, output_tokens) for each scripted LLM reply.
    scripted_replies = [
        ("garbage attempt 1", 200, 50),
        ("garbage attempt 2", 250, 60),
    ]
    fake_llm = mocker.MagicMock()
    fake_llm.complete.side_effect = [
        CompletionResult(
            text=reply_text,
            input_tokens=tokens_in,
            output_tokens=tokens_out,
            tier=ModelTier.C,
            model="qwen",
        )
        for reply_text, tokens_in, tokens_out in scripted_replies
    ]

    agent = HypothesisAgent(llm=fake_llm, max_retries=1)
    proposal = agent.propose(make_genome(), make_summary())

    assert proposal.strategy is None
    assert proposal.n_attempts == 2
    assert len(proposal.completions) == 2
    assert fake_llm.complete.call_count == 2
    assert proposal.parse_error is not None
    for marker in ("attempt 1", "attempt 2"):
        assert marker in proposal.parse_error
    # raw_text must mirror the LAST output, not the first one.
    assert proposal.raw_text == "garbage attempt 2"