feat(api): add validator, FastAPI app structure, and health endpoint

Wave 3 tasks complete:
- Task 7: Validator with 4 checks: pflichtfelder (required fields), betraege (amount calculations), ustid (VAT ID format), pdf_abgleich (XML vs. PDF comparison)
- Task 8: FastAPI app with CORS, exception handlers, JSON logging
- Task 9: Health endpoint returning status and version

Features:
- validate_invoice() runs selected validation checks
- Exception handlers for ExtractionError and generic errors
- GET /health returns {"status": "healthy", "version": "1.0.0"}

Tests: 52 validator tests covering all validation rules
This commit is contained in:
m3tm3re
2026-02-04 19:57:12 +01:00
parent c1f603cd46
commit 4791c91f06
6 changed files with 1795 additions and 6 deletions

View File

@@ -1,7 +1,37 @@
"""FastAPI application for ZUGFeRD invoice processing."""
import json
import logging
from datetime import datetime
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from src.extractor import ExtractionError
from src.models import HealthResponse
class JSONFormatter(logging.Formatter):
    """Render log records as single-line JSON objects.

    Every line carries ``timestamp`` (UTC, ISO 8601 with a ``Z`` suffix),
    ``level`` and ``message``. A ``data`` key is added when the record has a
    ``data`` attribute (for example via ``extra={"data": ...}``), and an
    ``exception`` key when the record carries exception information.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Serialize *record* to a JSON string.

        Args:
            record: The log record to render.

        Returns:
            The JSON document as a string.
        """
        # Local import: only ``datetime`` is imported at module level.
        from datetime import timezone

        # Use the moment the event was logged (record.created) rather than the
        # moment it is formatted, and avoid the deprecated datetime.utcnow().
        created = datetime.fromtimestamp(record.created, tz=timezone.utc)
        log_data = {
            # Drop tzinfo before isoformat() so the output keeps the "...Z"
            # shape instead of gaining a "+00:00" offset.
            "timestamp": created.replace(tzinfo=None).isoformat() + "Z",
            "level": record.levelname,
            "message": record.getMessage(),
        }
        if hasattr(record, "data"):
            log_data["data"] = record.data
        if record.exc_info:
            # Without this, tracebacks passed via exc_info / logger.exception()
            # would be dropped from the output.
            log_data["exception"] = self.formatException(record.exc_info)
        # default=str is deliberate: a value in ``data`` that JSON cannot
        # encode is stringified instead of making the log call fail.
        return json.dumps(log_data, default=str)
# Module-level logger that emits JSON lines at INFO level and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    # Attach the stream handler only when the logger has no handler yet.
    handler = logging.StreamHandler()
    handler.setFormatter(JSONFormatter())
    logger.addHandler(handler)
app = FastAPI(
title="ZUGFeRD Service",
@@ -9,6 +39,48 @@ app = FastAPI(
description="REST API for ZUGFeRD invoice extraction and validation",
)
# CORS: any origin may call the API with any method and header, but without
# credentials (cookies, Authorization headers). A wildcard origin must not be
# combined with allow_credentials=True; to support credentialed cross-origin
# requests, list the allowed origins explicitly and re-enable the flag.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.exception_handler(ExtractionError)
async def extraction_error_handler(request: Request, exc: ExtractionError):
    """Turn an ``ExtractionError`` into an HTTP 400 JSON response.

    Args:
        request: The request being handled (unused).
        exc: The raised extraction error.

    Returns:
        JSONResponse whose body carries the error's ``error_code`` (as
        ``error``), ``message`` and ``details``.
    """
    payload = {
        "error": exc.error_code,
        "message": exc.message,
        "details": exc.details,
    }
    return JSONResponse(content=payload, status_code=400)
@app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception):
    """Log an unhandled exception and answer with a generic HTTP 500.

    The client only receives a fixed, non-revealing message; the exception
    type, traceback and request method/path go to the log so the failure can
    be diagnosed.

    Args:
        request: The request during which the exception was raised.
        exc: The unhandled exception.

    Returns:
        JSONResponse with status 500 and an ``internal_error`` body.
    """
    # Local import: ``traceback`` is not part of the module's import list.
    import traceback

    # The details travel in ``extra["data"]``, which the module's JSON
    # formatter copies into the log line.
    logger.error(
        "Internal error: %s",
        exc,
        extra={
            "data": {
                "exception_type": type(exc).__name__,
                "method": request.method,
                "path": request.url.path,
                "traceback": "".join(
                    traceback.format_exception(type(exc), exc, exc.__traceback__)
                ),
            }
        },
    )
    return JSONResponse(
        status_code=500,
        content={
            "error": "internal_error",
            "message": "An internal error occurred",
        },
    )
@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Report that the service is up.

    Returns:
        HealthResponse with status ``healthy`` and version ``1.0.0``.
    """
    response = HealthResponse(status="healthy", version="1.0.0")
    return response
def run(host: str = "0.0.0.0", port: int = 5000) -> None:
"""Run the FastAPI application.

View File

@@ -160,6 +160,13 @@ class ValidateResponse(BaseModel):
result: ValidationResult = Field(description="Validation result")
class HealthResponse(BaseModel):
    """Health check response."""

    # Both fields are required strings; ``...`` marks them as required explicitly.
    status: str = Field(..., description="Service status")
    version: str = Field(..., description="Service version")
class ErrorResponse(BaseModel):
"""Error response."""

View File

@@ -1,3 +1,333 @@
"""Validation module for ZUGFeRD invoices."""
"""Validation functions for ZUGFeRD invoices."""
pass
import re
import time
from typing import Any
from src.models import (
ErrorDetail,
ValidateRequest,
ValidationResult,
XmlData,
)
from src.utils import amounts_match
def validate_pflichtfelder(xml_data: XmlData) -> list[ErrorDetail]:
    """Report required invoice fields that are missing, blank or zero.

    Critical findings: blank invoice number/date, supplier name, supplier VAT
    ID or buyer name; a zero net, gross or VAT total; no line items; and, per
    line item, a blank description or a zero quantity, unit price or line
    total. Warnings: a due date or IBAN that is present but blank, and a line
    item without a VAT rate.

    Args:
        xml_data: Invoice data to inspect.

    Returns:
        One ``ErrorDetail`` per finding, in the order the fields are checked.
    """
    findings: list[ErrorDetail] = []

    def report(field: str, severity: str) -> None:
        findings.append(
            ErrorDetail(
                check="pflichtfelder",
                field=field,
                error_code="missing_required",
                message=f"Required field '{field}' is missing or empty",
                severity=severity,
            )
        )

    def is_blank(text) -> bool:
        # None, "" and whitespace-only strings all count as blank.
        return not (text and text.strip())

    # Mandatory text fields, listed in reporting order.
    mandatory_text = (
        ("invoice_number", xml_data.invoice_number),
        ("invoice_date", xml_data.invoice_date),
        ("supplier.name", xml_data.supplier.name),
        ("supplier.vat_id", xml_data.supplier.vat_id),
        ("buyer.name", xml_data.buyer.name),
    )
    for name, value in mandatory_text:
        if is_blank(value):
            report(name, "critical")

    # A total of exactly zero is treated as "not provided".
    totals = xml_data.totals
    for name in ("net", "gross", "vat_total"):
        if getattr(totals, name) == 0:
            report(f"totals.{name}", "critical")

    # Optional fields only warn when they are present but blank.
    due_date = xml_data.due_date
    if due_date is not None and not due_date.strip():
        report("due_date", "warning")

    payment_terms = xml_data.payment_terms
    if payment_terms is not None:
        iban = payment_terms.iban
        if iban is not None and not iban.strip():
            report("payment_terms.iban", "warning")

    if not xml_data.line_items:
        report("line_items", "critical")
        return findings

    for position, item in enumerate(xml_data.line_items):
        prefix = f"line_items[{position}]"
        if is_blank(item.description):
            report(f"{prefix}.description", "critical")
        for attr in ("quantity", "unit_price", "line_total"):
            if getattr(item, attr) == 0:
                report(f"{prefix}.{attr}", "critical")
        if item.vat_rate is None:
            report(f"{prefix}.vat_rate", "warning")

    return findings
def validate_betraege(xml_data: XmlData) -> list[ErrorDetail]:
    """Verify the arithmetic relationships between the invoice amounts.

    Checked, in this order:
      * each line total equals quantity times unit price,
      * the net total equals the sum of the line totals,
      * each VAT breakdown amount equals base times rate / 100,
      * the VAT total equals the sum of the breakdown amounts,
      * the gross total equals net total plus VAT total.

    Values are compared with ``amounts_match``; every mismatch is critical.

    Args:
        xml_data: Invoice data to inspect.

    Returns:
        One ``ErrorDetail`` per mismatch.
    """
    mismatches = []

    def check(field: str, expected: float, actual: float) -> None:
        # Record a finding only when the actual value deviates from the expectation.
        if amounts_match(actual, expected):
            return
        mismatches.append(
            ErrorDetail(
                check="betraege",
                field=field,
                error_code="calculation_mismatch",
                message=f"Calculation mismatch for '{field}': expected {expected}, got {actual}",
                severity="critical",
            )
        )

    for position, item in enumerate(xml_data.line_items):
        check(
            f"line_items[{position}].line_total",
            item.quantity * item.unit_price,
            item.line_total,
        )

    totals = xml_data.totals
    check(
        "totals.net",
        sum(item.line_total for item in xml_data.line_items),
        totals.net,
    )

    for position, entry in enumerate(totals.vat_breakdown):
        check(
            f"totals.vat_breakdown[{position}].amount",
            entry.base * (entry.rate / 100),
            entry.amount,
        )

    check(
        "totals.vat_total",
        sum(entry.amount for entry in totals.vat_breakdown),
        totals.vat_total,
    )
    check("totals.gross", totals.net + totals.vat_total, totals.gross)

    return mismatches
def validate_ustid(vat_id: str) -> ErrorDetail | None:
    """Validate the format of a VAT ID.

    Accepted formats, after stripping surrounding whitespace:
      * ``DE`` followed by 9 digits,
      * ``ATU`` followed by 8 digits,
      * ``CHE`` followed by 9 digits and one of ``MWST``, ``TVA``, ``IVA``.

    Args:
        vat_id: The VAT ID to check.

    Returns:
        ``None`` when the ID is valid, otherwise a critical ``ErrorDetail``
        describing why it was rejected.
    """

    def invalid(message: str) -> ErrorDetail:
        return ErrorDetail(
            check="ustid",
            field="vat_id",
            error_code="invalid_format",
            message=message,
            severity="critical",
        )

    if not (vat_id and vat_id.strip()):
        return invalid("VAT ID is empty")

    candidate = vat_id.strip()

    # (country prefix, full pattern, country label used in the error message)
    rules = (
        ("DE", r"^DE[0-9]{9}$", "German"),
        ("AT", r"^ATU[0-9]{8}$", "Austrian"),
        ("CH", r"^CHE[0-9]{9}(MWST|TVA|IVA)$", "Swiss"),
    )
    for prefix, pattern, label in rules:
        if not candidate.startswith(prefix):
            continue
        # The prefix selects the country; the pattern then decides validity.
        if re.match(pattern, candidate):
            return None
        return invalid(f"Invalid {label} VAT ID format: {candidate}")

    return invalid(f"Unknown country code or invalid VAT ID format: {candidate}")
def validate_pdf_abgleich(xml_data: XmlData, pdf_values: dict) -> list[ErrorDetail]:
    """Compare invoice values from the XML with values extracted from the PDF.

    Only keys present in ``pdf_values`` are compared. ``invoice_number`` must
    match exactly; ``totals.gross``, ``totals.net`` and ``totals.vat_total``
    are converted to ``float`` and compared with ``amounts_match``. A PDF
    amount that cannot be converted or compared is skipped silently.

    Args:
        xml_data: Invoice data parsed from the XML.
        pdf_values: Values extracted from the PDF, keyed by field name.

    Returns:
        One warning-level ``ErrorDetail`` per mismatch.
    """
    mismatches = []

    def record(field: str, xml_value: Any, pdf_value: Any) -> None:
        mismatches.append(
            ErrorDetail(
                check="pdf_abgleich",
                field=field,
                error_code="pdf_mismatch",
                message=f"PDF mismatch for '{field}': XML has {xml_value}, PDF has {pdf_value}",
                severity="warning",
            )
        )

    if "invoice_number" in pdf_values:
        pdf_number = pdf_values["invoice_number"]
        if xml_data.invoice_number != pdf_number:
            record("invoice_number", xml_data.invoice_number, pdf_number)

    # The three totals share the same tolerance-based comparison.
    for attr in ("gross", "net", "vat_total"):
        key = f"totals.{attr}"
        if key not in pdf_values:
            continue
        try:
            pdf_amount = float(pdf_values[key])
            xml_amount = getattr(xml_data.totals, attr)
            if not amounts_match(xml_amount, pdf_amount):
                record(key, xml_amount, pdf_amount)
        except (ValueError, TypeError):
            # Best effort: an amount that cannot be parsed or compared is ignored.
            continue

    return mismatches
def _extract_pdf_values(pdf_text: str) -> dict:
    """Pull comparison values out of raw PDF text with a simple heuristic.

    If the text contains ``Invoice``, its second whitespace-separated token is
    taken as the invoice number. If it contains ``Total:``, the first token
    following the first occurrence is taken as the gross total (kept as a
    string; conversion happens in ``validate_pdf_abgleich``).
    """
    values: dict = {}
    if "Invoice" in pdf_text:
        tokens = pdf_text.split()
        if len(tokens) > 1:
            values["invoice_number"] = tokens[1]
    if "Total:" in pdf_text:
        amount_tokens = pdf_text.split("Total:")[1].split()
        # Nothing usable after "Total:" -> leave the value out.
        if amount_tokens:
            values["totals.gross"] = amount_tokens[0]
    return values


def validate_invoice(request: ValidateRequest) -> ValidationResult:
    """Run the requested validation checks and aggregate their findings.

    Supported check names are ``pflichtfelder``, ``betraege``, ``ustid`` and
    ``pdf_abgleich``. Unknown names, and ``pdf_abgleich`` when the request has
    no PDF text, are skipped: they count neither as run nor as passed. A check
    passes when it yields no critical finding; warnings do not fail a check.

    Args:
        request: Carries ``xml_data`` (keyword arguments for ``XmlData``),
            ``checks`` (names of the checks to run) and ``pdf_text``.

    Returns:
        ValidationResult with the critical errors, the warnings, a summary of
        the check counts and the elapsed time in milliseconds. ``is_valid`` is
        true when there are no critical errors.
    """
    # perf_counter is monotonic, so the duration cannot be distorted by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    xml_data = XmlData(**request.xml_data)

    all_errors = []
    all_warnings = []
    checks_run = 0
    checks_passed = 0

    for check_name in request.checks:
        check_errors: list[ErrorDetail] = []

        if check_name == "pflichtfelder":
            check_errors = validate_pflichtfelder(xml_data)
        elif check_name == "betraege":
            check_errors = validate_betraege(xml_data)
        elif check_name == "ustid":
            # Supplier first, then buyer; a missing VAT ID is not validated here.
            for vat_id in (xml_data.supplier.vat_id, xml_data.buyer.vat_id):
                if vat_id:
                    error = validate_ustid(vat_id)
                    if error:
                        check_errors.append(error)
        elif check_name == "pdf_abgleich" and request.pdf_text:
            pdf_values = _extract_pdf_values(request.pdf_text)
            check_errors = validate_pdf_abgleich(xml_data, pdf_values)
        else:
            # The check did not run. Counting it as passed while leaving it
            # out of checks_run would make checks_failed negative.
            continue

        checks_run += 1

        critical_errors = [e for e in check_errors if e.severity == "critical"]
        warnings = [e for e in check_errors if e.severity == "warning"]
        all_errors.extend(critical_errors)
        all_warnings.extend(warnings)

        if not critical_errors:
            checks_passed += 1

    validation_time_ms = int((time.perf_counter() - start_time) * 1000)

    summary = {
        "total_checks": checks_run,
        "checks_passed": checks_passed,
        "checks_failed": checks_run - checks_passed,
        "critical_errors": len(all_errors),
        "warnings": len(all_warnings),
    }

    return ValidationResult(
        is_valid=not all_errors,
        errors=all_errors,
        warnings=all_warnings,
        summary=summary,
        validation_time_ms=validation_time_ms,
    )