fix(moa): count reference (advisor) fan-out token usage + cost (#56087)

MoA ran the reference models before the aggregator but returned only the
aggregator's usage to the loop — _run_reference discarded each advisor
response's .usage entirely. Session accounting (state.db, /insights, cost)
therefore undercounted every MoA turn by the whole reference fan-out, which
is usually the bulk of the spend and scales with advisor count.

- _run_reference normalizes each advisor's usage with ITS OWN resolved
  provider/api_mode and prices it at ITS OWN model rate (correct cache-read/
  cache-write split), returning a _RefAccounting(usage, cost).
- create() sums advisor usage + cost once per turn (cache MISS only, so a
  repeat tool-iteration reusing cached advice does not double-charge) and
  exposes it via MoAClient.consume_reference_usage().
- conversation_loop folds advisor tokens into the reported/persisted token
  counts and adds advisor cost (priced per-advisor) on top of the
  aggregator cost, in both the in-memory session totals and the state.db
  per-call delta. Aggregator cost is still priced on aggregator-only usage
  so advisor tokens are never repriced at the aggregator rate.
- CanonicalUsage gains __add__ for per-bucket summing.

Tests: advisor usage/cost capture, per-turn sum + consume-clears +
cache-hit no-double-charge, CanonicalUsage.__add__.
This commit is contained in:
Teknium 2026-06-30 23:08:37 -07:00 committed by GitHub
parent 44ddc552f5
commit 3bdb23de10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 339 additions and 19 deletions

View file

@ -1922,6 +1922,25 @@ def run_conversation(
provider=agent.provider,
api_mode=agent.api_mode,
)
# Aggregator-only usage is retained for cost pricing: MoA
# advisor tokens must be priced at each advisor's OWN model
# rate, not the aggregator's, so they are added as dollars
# (below) rather than folded into the priced usage.
aggregator_usage = canonical_usage
# MoA: fold the reference (advisor) fan-out's token usage
# into this turn's REPORTED token counts. MoA runs advisors
# before the aggregator and returns only the aggregator's
# usage, so without this the entire advisor spend — usually
# the bulk of a MoA turn — is invisible in token counts.
_moa_ref_cost = None
_moa_client = getattr(agent, "client", None)
if _moa_client is not None and hasattr(_moa_client, "consume_reference_usage"):
try:
_ref_usage, _moa_ref_cost = _moa_client.consume_reference_usage()
if _ref_usage is not None:
canonical_usage = canonical_usage + _ref_usage
except Exception as _moa_acct_exc: # pragma: no cover - defensive
logger.debug("MoA reference usage accounting failed: %s", _moa_acct_exc)
prompt_tokens = canonical_usage.prompt_tokens
completion_tokens = canonical_usage.output_tokens
total_tokens = canonical_usage.total_tokens
@ -1975,13 +1994,20 @@ def run_conversation(
cost_result = estimate_usage_cost(
agent.model,
canonical_usage,
aggregator_usage,
provider=agent.provider,
base_url=agent.base_url,
api_key=getattr(agent, "api_key", ""),
)
if cost_result.amount_usd is not None:
agent.session_estimated_cost_usd += float(cost_result.amount_usd)
# Add MoA advisor cost (already priced per-advisor at each
# advisor's own model rate) on top of the aggregator cost.
if _moa_ref_cost is not None:
try:
agent.session_estimated_cost_usd += float(_moa_ref_cost)
except (TypeError, ValueError): # pragma: no cover - defensive
pass
agent.session_cost_status = cost_result.status
agent.session_cost_source = cost_result.source
@ -2002,6 +2028,18 @@ def run_conversation(
# affects 0 rows without error).
if not agent._session_db_created:
agent._ensure_db_session()
# Per-call cost delta = aggregator cost + MoA
# advisor cost (each priced at its own rate). Folded
# here so state.db's estimated_cost_usd includes the
# full MoA spend, matching the folded token counts.
_cost_delta = None
if cost_result.amount_usd is not None:
_cost_delta = float(cost_result.amount_usd)
if _moa_ref_cost is not None:
try:
_cost_delta = (_cost_delta or 0.0) + float(_moa_ref_cost)
except (TypeError, ValueError): # pragma: no cover
pass
agent._session_db.update_token_counts(
agent.session_id,
input_tokens=canonical_usage.input_tokens,
@ -2009,8 +2047,7 @@ def run_conversation(
cache_read_tokens=canonical_usage.cache_read_tokens,
cache_write_tokens=canonical_usage.cache_write_tokens,
reasoning_tokens=canonical_usage.reasoning_tokens,
estimated_cost_usd=float(cost_result.amount_usd)
if cost_result.amount_usd is not None else None,
estimated_cost_usd=_cost_delta,
cost_status=cost_result.status,
cost_source=cost_result.source,
billing_provider=agent.provider,

View file

@ -26,6 +26,27 @@ logger = logging.getLogger(__name__)
# opening dozens of sockets at once.
_MAX_REFERENCE_WORKERS = 8
class _RefAccounting:
"""Per-reference token usage + estimated cost, carried as the third slot
of a reference-output tuple.
Kept as a tiny object (not a bare CanonicalUsage) because an advisor may
run on a different model/provider than the aggregator, so its cost MUST be
priced at its OWN model's rate — folding advisor tokens into the
aggregator's usage and pricing the sum at the aggregator's rate would
misprice every advisor. ``usage`` feeds accurate token counts;
``cost_usd`` feeds accurate cost.
"""
__slots__ = ("usage", "cost_usd", "cost_status", "cost_source")
def __init__(self, usage: Any, cost_usd: Any = None, cost_status: str | None = None, cost_source: str | None = None):
self.usage = usage
self.cost_usd = cost_usd
self.cost_status = cost_status
self.cost_source = cost_source
# Per-tool-result character budget for the advisory reference view. Tool
# results can be huge (a full diff, a 5000-line file dump); replaying them
# verbatim per reference per tool-loop step would blow the reference model's
@ -125,8 +146,8 @@ def _run_reference(
*,
temperature: float | None = None,
max_tokens: int | None = None,
) -> tuple[str, str]:
"""Call one reference model and return ``(label, text)``.
) -> tuple[str, str, Any]:
"""Call one reference model and return ``(label, text, accounting)``, where ``accounting`` is a ``_RefAccounting``.
The slot is resolved to its provider's real runtime (via ``_slot_runtime``)
and called through the same ``call_llm`` request-building path any model
@ -137,12 +158,23 @@ def _run_reference(
real maximum); ``temperature`` is only the user's configured preset value,
which call_llm may still override per model.
The reference's token usage is normalized with the slot's OWN resolved
provider/api_mode (advisors may run on a different provider than the
aggregator, with different usage wire shapes), priced at the slot's own
model rate, and returned inside a ``_RefAccounting`` (a ``CanonicalUsage``
plus the estimated cost) so the caller can fold advisor spend into session
accounting. Without this, the entire reference fan-out — often the bulk of
a MoA turn's token spend — is invisible to cost tracking, which only ever
saw the aggregator's usage.
Never raises: a failed reference becomes a labelled note so the aggregator
can still act with partial context. Designed to run inside a thread pool —
``call_llm`` is synchronous/blocking, so threads (not asyncio) are the right
concurrency primitive, mirroring ``delegate_task``'s batch fan-out.
"""
from agent.usage_pricing import CanonicalUsage, estimate_usage_cost, normalize_usage
label = _slot_label(slot)
runtime = _slot_runtime(slot)
try:
# Prepend the advisory-role system prompt so the reference understands
# it is analyzing state for an aggregator, not acting on the task. The
@ -154,12 +186,44 @@ def _run_reference(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
**_slot_runtime(slot),
**runtime,
)
return label, _extract_text(response) or "(empty response)"
usage = CanonicalUsage()
raw_usage = getattr(response, "usage", None)
if raw_usage:
try:
usage = normalize_usage(
raw_usage,
provider=runtime.get("provider"),
api_mode=runtime.get("api_mode"),
)
except Exception: # pragma: no cover - defensive
usage = CanonicalUsage()
# Price this advisor at ITS OWN model/provider rate (with correct
# cache-read/cache-write split), not the aggregator's. This is why
# advisor cost is summed as dollars rather than by folding tokens into
# the aggregator's usage.
cost_usd = None
cost_status = None
cost_source = None
try:
cost = estimate_usage_cost(
slot.get("model") or "",
usage,
provider=runtime.get("provider"),
base_url=runtime.get("base_url"),
api_key=runtime.get("api_key"),
)
cost_usd = cost.amount_usd
cost_status = cost.status
cost_source = cost.source
except Exception: # pragma: no cover - defensive
pass
acct = _RefAccounting(usage, cost_usd, cost_status, cost_source)
return label, _extract_text(response) or "(empty response)", acct
except Exception as exc:
logger.warning("MoA reference model %s failed: %s", label, exc)
return label, f"[failed: {exc}]"
return label, f"[failed: {exc}]", _RefAccounting(CanonicalUsage())
def _run_references_parallel(
@ -168,7 +232,7 @@ def _run_references_parallel(
*,
temperature: float | None = None,
max_tokens: int | None = None,
) -> list[tuple[str, str]]:
) -> list[tuple[str, str, Any]]:
"""Fan out all reference models in parallel, returning outputs in order.
Like ``delegate_task``'s batch mode, every reference is dispatched at once
@ -176,11 +240,16 @@ def _run_references_parallel(
the aggregator. Output order matches ``reference_models`` so the
``Reference {idx}`` labelling stays stable. MoA presets that reference
another MoA preset are skipped here (recursion guard) with a labelled note.
Each element is ``(label, text, accounting)`` where accounting is a
``_RefAccounting`` (zeroed usage and no cost for skipped/failed references).
"""
from agent.usage_pricing import CanonicalUsage
if not reference_models:
return []
results: list[tuple[str, str] | None] = [None] * len(reference_models)
results: list[tuple[str, str, Any] | None] = [None] * len(reference_models)
futures = {}
workers = min(_MAX_REFERENCE_WORKERS, len(reference_models))
with ThreadPoolExecutor(max_workers=workers) as executor:
@ -189,6 +258,7 @@ def _run_references_parallel(
results[idx] = (
_slot_label(slot),
"[skipped: MoA presets cannot recursively reference MoA]",
_RefAccounting(CanonicalUsage()),
)
continue
futures[
@ -390,7 +460,7 @@ def aggregate_moa_context(
sidesteps providers that reject ``max_tokens`` outright. A hardcoded cap
here previously truncated long aggregator syntheses.
"""
reference_outputs: list[tuple[str, str]] = []
reference_outputs: list[tuple[str, str, Any]] = []
ref_messages = _reference_messages(api_messages)
reference_outputs = _run_references_parallel(
reference_models,
@ -401,7 +471,7 @@ def aggregate_moa_context(
joined = "\n\n".join(
f"Reference {idx}{label}:\n{text}"
for idx, (label, text) in enumerate(reference_outputs, start=1)
for idx, (label, text, _usage) in enumerate(reference_outputs, start=1)
)
synth_prompt = (
"You are the aggregator in a Mixture of Agents process. Synthesize the "
@ -465,7 +535,33 @@ class MoAChatCompletions:
# re-run, no re-emit). This gives "fire on every user/tool response"
# for free, without re-firing on a pure no-op re-call.
self._ref_cache_key: tuple | None = None
self._ref_cache_outputs: list[tuple[str, str]] = []
self._ref_cache_outputs: list[tuple[str, str, Any]] = []
# Token usage + estimated cost of the reference fan-out from the most
# recent cache-MISS create() call, awaiting consumption by session
# accounting. Set on every create() (zeroed on a cache HIT so per-turn
# advisor spend is counted exactly once). Consumed via
# ``consume_reference_usage``.
from agent.usage_pricing import CanonicalUsage
self._pending_reference_usage: Any = CanonicalUsage()
self._pending_reference_cost: Any = None
def consume_reference_usage(self) -> tuple[Any, Any]:
    """Hand over the pending advisor usage + cost and clear both.

    The returned pair is ``(usage, cost)``. ``usage`` is whatever is
    currently pending, or a fresh ``CanonicalUsage()`` when the pending
    value is falsy. ``cost`` is the pending cost as stored, which may be
    ``None``.

    After the call the pending usage is a fresh ``CanonicalUsage()`` and the
    pending cost is ``None``, so calling again without an intervening
    ``create()`` yields empty values rather than the same spend twice.
    """
    from agent.usage_pricing import CanonicalUsage

    # Snapshot first, then reset, so the caller owns the only copy.
    pending_usage = self._pending_reference_usage
    pending_cost = self._pending_reference_cost
    if not pending_usage:
        pending_usage = CanonicalUsage()

    self._pending_reference_usage = CanonicalUsage()
    self._pending_reference_cost = None
    return pending_usage, pending_cost
def _emit(self, event: str, **kwargs: Any) -> None:
cb = self.reference_callback
@ -497,7 +593,9 @@ class MoAChatCompletions:
if not preset.get("enabled", True):
reference_models = []
reference_outputs: list[tuple[str, str]] = []
from agent.usage_pricing import CanonicalUsage
reference_outputs: list[tuple[str, str, Any]] = []
ref_messages = _reference_messages(messages)
# Turn-scoped cache: only run + display references when the advisory
@ -514,6 +612,12 @@ class MoAChatCompletions:
if _refs_from_cache:
reference_outputs = list(self._ref_cache_outputs)
# References already ran (and were accounted) earlier this turn;
# this create() is a repeat tool-iteration reusing the cached
# advice. Charging their tokens/cost again here would multiply
# advisor spend by the tool-iteration count, so pending is zero.
self._pending_reference_usage = CanonicalUsage()
self._pending_reference_cost = None
else:
reference_outputs = _run_references_parallel(
reference_models,
@ -523,6 +627,24 @@ class MoAChatCompletions:
)
self._ref_cache_key = _cache_key
self._ref_cache_outputs = list(reference_outputs)
# Sum the advisor fan-out's token usage AND cost so the caller can
# fold advisor spend into session accounting exactly once per turn.
# Only the freshly run references (cache MISS) contribute; a cache
# HIT above zeroes this. Token counts sum directly (each already
# normalized per-advisor provider/api_mode); cost sums in dollars
# because each advisor was priced at its OWN model rate — advisors
# may be cheaper/pricier than the aggregator, so their tokens must
# NOT be repriced at the aggregator's rate.
_ref_usage = CanonicalUsage()
_ref_cost: Any = None
for _lbl, _txt, _acct in reference_outputs:
if isinstance(_acct, _RefAccounting):
if isinstance(_acct.usage, CanonicalUsage):
_ref_usage = _ref_usage + _acct.usage
if _acct.cost_usd is not None:
_ref_cost = (_ref_cost or 0) + _acct.cost_usd
self._pending_reference_usage = _ref_usage
self._pending_reference_cost = _ref_cost
# Surface each reference model's answer to the display BEFORE the
# aggregator acts — once per turn (only on the iteration that
@ -531,7 +653,7 @@ class MoAChatCompletions:
# visible rather than a silent pause. Best-effort: never blocks the
# turn.
_ref_count = len(reference_outputs)
for _idx, (_label, _text) in enumerate(reference_outputs, start=1):
for _idx, (_label, _text, _usage) in enumerate(reference_outputs, start=1):
self._emit(
"moa.reference",
index=_idx,
@ -550,13 +672,13 @@ class MoAChatCompletions:
if reference_outputs:
joined = "\n\n".join(
f"Reference {idx}{label}:\n{text}"
for idx, (label, text) in enumerate(reference_outputs, start=1)
for idx, (label, text, _usage) in enumerate(reference_outputs, start=1)
)
guidance = (
"[Mixture of Agents reference context]\n"
f"Preset: {self.preset_name}\n"
f"Aggregator/acting model: {_slot_label(aggregator)}\n"
f"References: {', '.join(label for label, _ in reference_outputs)}\n\n"
f"References: {', '.join(label for label, _, _ in reference_outputs)}\n\n"
"Use the reference responses below as private context. You are the aggregator and acting model: "
"answer the user directly or call tools as needed.\n\n"
f"{joined}"
@ -614,3 +736,11 @@ class MoAClient:
def __init__(self, preset_name: str, reference_callback: Any = None):
self.chat = type("_MoAChat", (), {})()
self.chat.completions = MoAChatCompletions(preset_name, reference_callback=reference_callback)
def consume_reference_usage(self) -> tuple[Any, Any]:
    """Pop the pending reference-fan-out usage and cost from the completions facade.

    Delegates to ``self.chat.completions.consume_reference_usage()`` and
    returns its result unchanged: a ``(usage, cost)`` pair, which callers
    unpack into two values. Lets session accounting fold the MoA advisor
    tokens and cost into the turn's totals without reaching into
    ``.chat.completions`` internals.
    """
    return self.chat.completions.consume_reference_usage()

View file

@ -45,6 +45,25 @@ class CanonicalUsage:
def total_tokens(self) -> int:
return self.prompt_tokens + self.output_tokens
def __add__(self, other: "CanonicalUsage") -> "CanonicalUsage":
    """Return a new usage whose buckets are the per-bucket sums of both operands.

    Used to combine, for example, the MoA advisor fan-out with the
    aggregator. ``raw_usage`` is dropped on the sum — it describes a single
    API response and cannot be meaningfully merged. ``request_count`` adds
    so callers can see how many underlying API calls a combined figure
    covers. Returns ``NotImplemented`` when ``other`` is not a
    ``CanonicalUsage``.
    """
    if isinstance(other, CanonicalUsage):
        summed_buckets = {
            bucket: getattr(self, bucket) + getattr(other, bucket)
            for bucket in (
                "input_tokens",
                "output_tokens",
                "cache_read_tokens",
                "cache_write_tokens",
                "reasoning_tokens",
                "request_count",
            )
        }
        return CanonicalUsage(raw_usage=None, **summed_buckets)
    return NotImplemented
@dataclass(frozen=True)
class BillingRoute:

View file

@ -410,7 +410,7 @@ def test_run_reference_prepends_advisory_system_prompt(monkeypatch):
monkeypatch.setattr("agent.moa_loop.call_llm", fake_call_llm)
label, text = _run_reference(
label, text, _acct = _run_reference(
{"provider": "openai-codex", "model": "gpt-5.5"},
[{"role": "user", "content": "review this PR"}],
)
@ -568,7 +568,7 @@ def test_references_run_in_parallel(monkeypatch):
# Two 0.5s sleeps run concurrently → well under the 1.0s serial floor.
assert elapsed < 0.9, f"references did not run in parallel (took {elapsed:.2f}s)"
# Output order matches input order (stable Reference N labelling).
assert [label for label, _ in out] == ["p1:ok", "moa:preset", "p2:boom", "p3:ok"]
assert [label for label, _, _ in out] == ["p1:ok", "moa:preset", "p2:boom", "p3:ok"]
assert "recursively reference MoA" in out[1][1]
assert out[2][1].startswith("[failed:")
assert out[0][1] == "resp-p1"
@ -750,3 +750,137 @@ def test_slot_runtime_anthropic_oauth_routes_through_provider_branch(monkeypatch
assert other_rt["model"] == "some-model"
assert other_rt["base_url"] == "https://resolved.example/v1"
assert other_rt["api_key"] == "resolved-key"
def _response_with_usage(content="advice", *, prompt=100, completion=50, cached=0, cache_write=0):
    """Build a fake response carrying OpenAI-style usage so normalize_usage works.

    ``prompt`` and ``completion`` become ``usage.prompt_tokens`` and
    ``usage.completion_tokens``. ``cached`` and ``cache_write`` populate
    ``usage.prompt_tokens_details.cached_tokens`` and
    ``.cache_write_tokens``. ``cache_write`` defaults to 0, the value this
    helper previously hard-coded, so existing callers are unaffected.

    The result has one choice whose message holds ``content`` and an empty
    ``tool_calls`` list, ``finish_reason="stop"`` and ``model="fake-model"``.
    """
    details = SimpleNamespace(cached_tokens=cached, cache_write_tokens=cache_write)
    usage = SimpleNamespace(
        prompt_tokens=prompt,
        completion_tokens=completion,
        prompt_tokens_details=details,
        output_tokens_details=None,
    )
    message = SimpleNamespace(content=content, tool_calls=[])
    choice = SimpleNamespace(message=message, finish_reason="stop")
    return SimpleNamespace(choices=[choice], usage=usage, model="fake-model")
def test_run_reference_captures_usage_and_cost(monkeypatch):
    """A reference call returns per-advisor CanonicalUsage + priced cost.

    Before this, _run_reference discarded response.usage entirely, so the
    advisor fan-out was invisible to cost tracking. All four accounting
    fields (usage, cost, cost status, cost source) are checked.
    """
    from agent.moa_loop import _RefAccounting, _run_reference
    from agent.usage_pricing import CanonicalUsage

    monkeypatch.setattr(
        "agent.moa_loop.call_llm",
        lambda **kw: _response_with_usage(prompt=1000, completion=200, cached=400),
    )
    # Keep runtime resolution + pricing deterministic.
    monkeypatch.setattr(
        "agent.moa_loop._slot_runtime",
        lambda slot: {"provider": "openrouter", "model": slot.get("model")},
    )
    monkeypatch.setattr(
        "agent.usage_pricing.estimate_usage_cost",
        lambda *a, **k: SimpleNamespace(amount_usd=0.0123, status="estimated", source="table"),
    )

    label, text, acct = _run_reference(
        {"provider": "openrouter", "model": "vendor/adv-model"},
        [{"role": "user", "content": "state?"}],
    )

    assert text == "advice"
    assert isinstance(acct, _RefAccounting)
    assert isinstance(acct.usage, CanonicalUsage)
    # prompt_tokens=1000 with 400 cached → 600 fresh input + 400 cache_read.
    assert acct.usage.input_tokens == 600
    assert acct.usage.cache_read_tokens == 400
    assert acct.usage.output_tokens == 200
    assert acct.cost_usd == 0.0123
    # The estimate's status and source must travel with the dollar amount.
    assert acct.cost_status == "estimated"
    assert acct.cost_source == "table"
def test_references_parallel_sum_and_consume(monkeypatch, tmp_path):
"""create() sums advisor usage + cost once per turn; consume clears it.
Repeat tool-iterations within a turn reuse the cache and contribute ZERO
additional advisor spend (otherwise advisor cost multiplies by iteration
count).
"""
# Write a throwaway config with a "review" preset: two reference models
# (adv-a, adv-b) and one aggregator, then point HERMES_HOME at it.
# NOTE(review): the YAML literal below appears with its nesting indentation
# flattened in this view — confirm the real file's indentation before editing.
home = tmp_path / ".hermes"
home.mkdir()
(home / "config.yaml").write_text(
"""
moa:
default_preset: review
presets:
review:
reference_models:
- provider: openrouter
model: adv-a
- provider: openrouter
model: adv-b
aggregator:
provider: openrouter
model: anthropic/claude-opus-4.8
""".strip(),
encoding="utf-8",
)
monkeypatch.setenv("HERMES_HOME", str(home))
# Calls with task == "moa_reference" get a response carrying usage
# (1000 prompt / 100 completion / 0 cached); every other call gets the
# plain ``_response`` fake.
def fake_call_llm(**kwargs):
if kwargs["task"] == "moa_reference":
return _response_with_usage(prompt=1000, completion=100, cached=0)
return _response("aggregator acted")
monkeypatch.setattr("agent.moa_loop.call_llm", fake_call_llm)
# Stub runtime resolution and pricing so every priced call costs a fixed
# $0.01 regardless of model.
monkeypatch.setattr(
"agent.moa_loop._slot_runtime",
lambda slot: {"provider": "openrouter", "model": slot.get("model")},
)
monkeypatch.setattr(
"agent.usage_pricing.estimate_usage_cost",
lambda *a, **k: SimpleNamespace(amount_usd=0.01, status="estimated", source="table"),
)
from agent.moa_loop import MoAChatCompletions
# First create(): the totals asserted below assume both advisors run here.
facade = MoAChatCompletions("review")
facade.create(messages=[{"role": "user", "content": "turn one"}], tools=[])
usage, cost = facade.consume_reference_usage()
# Two advisors × (1000 input, 100 output) = 2000 input, 200 output.
assert usage.input_tokens == 2000
assert usage.output_tokens == 200
# Two advisors × $0.01 each = $0.02.
assert cost == pytest.approx(0.02)
# consume clears — a second consume with no new create() is zeroed.
usage2, cost2 = facade.consume_reference_usage()
assert usage2.input_tokens == 0
assert cost2 is None
# A repeat create() with the SAME advisory view is a cache HIT: advisors
# do not re-run, so pending advisor spend is zero (no double-charge).
facade.create(messages=[{"role": "user", "content": "turn one"}], tools=[])
usage3, cost3 = facade.consume_reference_usage()
assert usage3.input_tokens == 0
assert cost3 is None
def test_canonical_usage_add():
    """CanonicalUsage sums per bucket (used to fold advisor tokens in).

    Also pins the rest of the ``__add__`` contract: reasoning tokens are
    summed, ``raw_usage`` is reset to ``None`` on the result, and a
    non-``CanonicalUsage`` operand yields ``NotImplemented``.
    """
    from agent.usage_pricing import CanonicalUsage

    a = CanonicalUsage(input_tokens=100, output_tokens=20, cache_read_tokens=5)
    b = CanonicalUsage(input_tokens=50, output_tokens=10, cache_write_tokens=3)
    total = a + b
    assert total.input_tokens == 150
    assert total.output_tokens == 30
    assert total.cache_read_tokens == 5
    assert total.cache_write_tokens == 3
    assert total.request_count == 2
    # A merged figure no longer describes one API response.
    assert total.raw_usage is None

    # Reasoning tokens are a summed bucket too.
    reasoning = CanonicalUsage(reasoning_tokens=7) + CanonicalUsage(reasoning_tokens=4)
    assert reasoning.reasoning_tokens == 11

    # Foreign operands are declined rather than coerced.
    assert a.__add__(1) is NotImplemented

    # Summing builds a new object; the operands keep their own counts.
    assert a.input_tokens == 100
    assert b.input_tokens == 50