"""Tests for chart binding validation rules.""" import pytest from dataclasses import dataclass, field # --------------------------------------------------------------------------- # Stub domain objects -- replace with real imports later # --------------------------------------------------------------------------- @dataclass class Binding: channel: str # e.g. "x", "y", "color", "size" field_id: str @dataclass class ChartSpec: chart_type: str bindings: list[Binding] = field(default_factory=list) def validate(self) -> list[str]: """Return a list of validation error messages (empty == valid).""" errors: list[str] = [] required = REQUIRED_BINDINGS.get(self.chart_type, []) bound_channels = {b.channel for b in self.bindings} for channel in required: if channel not in bound_channels: errors.append( f"Chart type '{self.chart_type}' requires binding for '{channel}'." ) return errors REQUIRED_BINDINGS: dict[str, list[str]] = { "bar": ["x", "y"], "line": ["x", "y"], "pie": ["theta"], "scatter": ["x", "y"], } # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- class TestBindingValidation: """Ensure chart specs enforce required bindings per chart type.""" def test_bar_chart_requires_x_and_y(self): spec = ChartSpec(chart_type="bar", bindings=[]) errors = spec.validate() assert any("'x'" in e for e in errors) assert any("'y'" in e for e in errors) def test_bar_chart_valid_with_x_and_y(self): spec = ChartSpec( chart_type="bar", bindings=[ Binding(channel="x", field_id="category"), Binding(channel="y", field_id="amount"), ], ) assert spec.validate() == [] def test_pie_chart_requires_theta(self): spec = ChartSpec(chart_type="pie", bindings=[]) errors = spec.validate() assert any("'theta'" in e for e in errors) def test_pie_chart_valid_with_theta(self): spec = ChartSpec( chart_type="pie", bindings=[Binding(channel="theta", field_id="sales")], ) assert spec.validate() == [] def test_line_chart_missing_y(self): spec = ChartSpec( chart_type="line", bindings=[Binding(channel="x", field_id="date")], ) errors = spec.validate() assert any("'y'" in e for e in errors) assert not any("'x'" in e for e in errors) def test_scatter_chart_requires_x_and_y(self): spec = ChartSpec(chart_type="scatter", bindings=[]) errors = spec.validate() assert len(errors) == 2