feat(api): add validator, FastAPI app structure, and health endpoint
Wave 3 tasks complete:
- Task 7: Validator with 4 checks (pflichtfelder, betraege, ustid, pdf_abgleich)
- Task 8: FastAPI app with CORS, exception handlers, JSON logging
- Task 9: Health endpoint returning status and version
Features:
- validate_invoice() runs selected validation checks
- Exception handlers for ExtractionError and generic errors
- GET /health returns {status: healthy, version: 1.0.0}
Tests: 52 validator tests covering all validation rules
This commit is contained in:
74
src/main.py
74
src/main.py
@@ -1,7 +1,37 @@
|
||||
"""FastAPI application for ZUGFeRD invoice processing."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.extractor import ExtractionError
|
||||
from src.models import HealthResponse
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
log_data = {
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"level": record.levelname,
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
if hasattr(record, "data"):
|
||||
log_data["data"] = record.data
|
||||
return json.dumps(log_data)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(JSONFormatter())
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="ZUGFeRD Service",
|
||||
@@ -9,6 +39,48 @@ app = FastAPI(
|
||||
description="REST API for ZUGFeRD invoice extraction and validation",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(ExtractionError)
|
||||
async def extraction_error_handler(request: Request, exc: ExtractionError):
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error": exc.error_code,
|
||||
"message": exc.message,
|
||||
"details": exc.details,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Internal error: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": "internal_error",
|
||||
"message": "An internal error occurred",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health_check() -> HealthResponse:
|
||||
"""Health check endpoint.
|
||||
|
||||
Returns:
|
||||
HealthResponse with status and version.
|
||||
"""
|
||||
return HealthResponse(status="healthy", version="1.0.0")
|
||||
|
||||
|
||||
def run(host: str = "0.0.0.0", port: int = 5000) -> None:
|
||||
"""Run the FastAPI application.
|
||||
|
||||
@@ -160,6 +160,13 @@ class ValidateResponse(BaseModel):
|
||||
result: ValidationResult = Field(description="Validation result")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
|
||||
status: str = Field(description="Service status")
|
||||
version: str = Field(description="Service version")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response."""
|
||||
|
||||
|
||||
334
src/validator.py
334
src/validator.py
@@ -1,3 +1,333 @@
|
||||
"""Validation module for ZUGFeRD invoices."""
|
||||
"""Validation functions for ZUGFeRD invoices."""
|
||||
|
||||
pass
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from src.models import (
|
||||
ErrorDetail,
|
||||
ValidateRequest,
|
||||
ValidationResult,
|
||||
XmlData,
|
||||
)
|
||||
from src.utils import amounts_match
|
||||
|
||||
|
||||
def validate_pflichtfelder(xml_data: XmlData) -> list[ErrorDetail]:
|
||||
"""Check required fields are present."""
|
||||
errors = []
|
||||
|
||||
def add_error(field: str, severity: str) -> None:
|
||||
errors.append(
|
||||
ErrorDetail(
|
||||
check="pflichtfelder",
|
||||
field=field,
|
||||
error_code="missing_required",
|
||||
message=f"Required field '{field}' is missing or empty",
|
||||
severity=severity,
|
||||
)
|
||||
)
|
||||
|
||||
# Critical fields
|
||||
if not xml_data.invoice_number or not xml_data.invoice_number.strip():
|
||||
add_error("invoice_number", "critical")
|
||||
|
||||
if not xml_data.invoice_date or not xml_data.invoice_date.strip():
|
||||
add_error("invoice_date", "critical")
|
||||
|
||||
if not xml_data.supplier.name or not xml_data.supplier.name.strip():
|
||||
add_error("supplier.name", "critical")
|
||||
|
||||
if not xml_data.supplier.vat_id or not xml_data.supplier.vat_id.strip():
|
||||
add_error("supplier.vat_id", "critical")
|
||||
|
||||
if not xml_data.buyer.name or not xml_data.buyer.name.strip():
|
||||
add_error("buyer.name", "critical")
|
||||
|
||||
if xml_data.totals.net == 0:
|
||||
add_error("totals.net", "critical")
|
||||
|
||||
if xml_data.totals.gross == 0:
|
||||
add_error("totals.gross", "critical")
|
||||
|
||||
if xml_data.totals.vat_total == 0:
|
||||
add_error("totals.vat_total", "critical")
|
||||
|
||||
# Warning fields
|
||||
if xml_data.due_date is not None and not xml_data.due_date.strip():
|
||||
add_error("due_date", "warning")
|
||||
|
||||
if (
|
||||
xml_data.payment_terms is not None
|
||||
and xml_data.payment_terms.iban is not None
|
||||
and not xml_data.payment_terms.iban.strip()
|
||||
):
|
||||
add_error("payment_terms.iban", "warning")
|
||||
|
||||
# Line items
|
||||
if not xml_data.line_items or len(xml_data.line_items) == 0:
|
||||
add_error("line_items", "critical")
|
||||
else:
|
||||
for idx, item in enumerate(xml_data.line_items):
|
||||
field_prefix = f"line_items[{idx}]"
|
||||
|
||||
if not item.description or not item.description.strip():
|
||||
add_error(f"{field_prefix}.description", "critical")
|
||||
|
||||
if item.quantity == 0:
|
||||
add_error(f"{field_prefix}.quantity", "critical")
|
||||
|
||||
if item.unit_price == 0:
|
||||
add_error(f"{field_prefix}.unit_price", "critical")
|
||||
|
||||
if item.line_total == 0:
|
||||
add_error(f"{field_prefix}.line_total", "critical")
|
||||
|
||||
if item.vat_rate is None:
|
||||
add_error(f"{field_prefix}.vat_rate", "warning")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_betraege(xml_data: XmlData) -> list[ErrorDetail]:
|
||||
"""Check amount calculations are correct."""
|
||||
errors = []
|
||||
|
||||
def add_mismatch(field: str, expected: float, actual: float) -> None:
|
||||
errors.append(
|
||||
ErrorDetail(
|
||||
check="betraege",
|
||||
field=field,
|
||||
error_code="calculation_mismatch",
|
||||
message=f"Calculation mismatch for '{field}': expected {expected}, got {actual}",
|
||||
severity="critical",
|
||||
)
|
||||
)
|
||||
|
||||
# Check line_total = quantity × unit_price
|
||||
for idx, item in enumerate(xml_data.line_items):
|
||||
expected_line_total = item.quantity * item.unit_price
|
||||
if not amounts_match(item.line_total, expected_line_total):
|
||||
add_mismatch(
|
||||
f"line_items[{idx}].line_total",
|
||||
expected_line_total,
|
||||
item.line_total,
|
||||
)
|
||||
|
||||
# Check totals.net = sum(line_items.line_total)
|
||||
line_total_sum = sum(item.line_total for item in xml_data.line_items)
|
||||
if not amounts_match(xml_data.totals.net, line_total_sum):
|
||||
add_mismatch("totals.net", line_total_sum, xml_data.totals.net)
|
||||
|
||||
# Check vat_breakdown.amount = base × (rate/100)
|
||||
for idx, vat_breakdown in enumerate(xml_data.totals.vat_breakdown):
|
||||
expected_amount = vat_breakdown.base * (vat_breakdown.rate / 100)
|
||||
if not amounts_match(vat_breakdown.amount, expected_amount):
|
||||
add_mismatch(
|
||||
f"totals.vat_breakdown[{idx}].amount",
|
||||
expected_amount,
|
||||
vat_breakdown.amount,
|
||||
)
|
||||
|
||||
# Check totals.vat_total = sum(vat_breakdown.amount)
|
||||
vat_breakdown_sum = sum(vb.amount for vb in xml_data.totals.vat_breakdown)
|
||||
if not amounts_match(xml_data.totals.vat_total, vat_breakdown_sum):
|
||||
add_mismatch("totals.vat_total", vat_breakdown_sum, xml_data.totals.vat_total)
|
||||
|
||||
# Check totals.gross = totals.net + totals.vat_total
|
||||
expected_gross = xml_data.totals.net + xml_data.totals.vat_total
|
||||
if not amounts_match(xml_data.totals.gross, expected_gross):
|
||||
add_mismatch("totals.gross", expected_gross, xml_data.totals.gross)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_ustid(vat_id: str) -> ErrorDetail | None:
|
||||
"""Check VAT ID format (returns None if valid)."""
|
||||
if not vat_id or not vat_id.strip():
|
||||
return ErrorDetail(
|
||||
check="ustid",
|
||||
field="vat_id",
|
||||
error_code="invalid_format",
|
||||
message="VAT ID is empty",
|
||||
severity="critical",
|
||||
)
|
||||
|
||||
vat_id = vat_id.strip()
|
||||
|
||||
# German VAT ID: DE followed by 9 digits
|
||||
if vat_id.startswith("DE"):
|
||||
if re.match(r"^DE[0-9]{9}$", vat_id):
|
||||
return None
|
||||
return ErrorDetail(
|
||||
check="ustid",
|
||||
field="vat_id",
|
||||
error_code="invalid_format",
|
||||
message=f"Invalid German VAT ID format: {vat_id}",
|
||||
severity="critical",
|
||||
)
|
||||
|
||||
# Austrian VAT ID: ATU followed by 8 digits
|
||||
if vat_id.startswith("AT"):
|
||||
if re.match(r"^ATU[0-9]{8}$", vat_id):
|
||||
return None
|
||||
return ErrorDetail(
|
||||
check="ustid",
|
||||
field="vat_id",
|
||||
error_code="invalid_format",
|
||||
message=f"Invalid Austrian VAT ID format: {vat_id}",
|
||||
severity="critical",
|
||||
)
|
||||
|
||||
# Swiss VAT ID: CHE followed by 9 digits and MWST/TVA/IVA suffix
|
||||
if vat_id.startswith("CH"):
|
||||
if re.match(r"^CHE[0-9]{9}(MWST|TVA|IVA)$", vat_id):
|
||||
return None
|
||||
return ErrorDetail(
|
||||
check="ustid",
|
||||
field="vat_id",
|
||||
error_code="invalid_format",
|
||||
message=f"Invalid Swiss VAT ID format: {vat_id}",
|
||||
severity="critical",
|
||||
)
|
||||
|
||||
return ErrorDetail(
|
||||
check="ustid",
|
||||
field="vat_id",
|
||||
error_code="invalid_format",
|
||||
message=f"Unknown country code or invalid VAT ID format: {vat_id}",
|
||||
severity="critical",
|
||||
)
|
||||
|
||||
|
||||
def validate_pdf_abgleich(xml_data: XmlData, pdf_values: dict) -> list[ErrorDetail]:
|
||||
"""Compare XML values to PDF extracted values."""
|
||||
errors = []
|
||||
|
||||
def add_mismatch(field: str, xml_value: Any, pdf_value: Any) -> None:
|
||||
errors.append(
|
||||
ErrorDetail(
|
||||
check="pdf_abgleich",
|
||||
field=field,
|
||||
error_code="pdf_mismatch",
|
||||
message=f"PDF mismatch for '{field}': XML has {xml_value}, PDF has {pdf_value}",
|
||||
severity="warning",
|
||||
)
|
||||
)
|
||||
|
||||
# Invoice number (exact match)
|
||||
if "invoice_number" in pdf_values:
|
||||
pdf_invoice = pdf_values["invoice_number"]
|
||||
if xml_data.invoice_number != pdf_invoice:
|
||||
add_mismatch("invoice_number", xml_data.invoice_number, pdf_invoice)
|
||||
|
||||
# Totals.gross (within tolerance)
|
||||
if "totals.gross" in pdf_values:
|
||||
try:
|
||||
pdf_gross = float(pdf_values["totals.gross"])
|
||||
if not amounts_match(xml_data.totals.gross, pdf_gross):
|
||||
add_mismatch("totals.gross", xml_data.totals.gross, pdf_gross)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Totals.net (within tolerance)
|
||||
if "totals.net" in pdf_values:
|
||||
try:
|
||||
pdf_net = float(pdf_values["totals.net"])
|
||||
if not amounts_match(xml_data.totals.net, pdf_net):
|
||||
add_mismatch("totals.net", xml_data.totals.net, pdf_net)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Totals.vat_total (within tolerance)
|
||||
if "totals.vat_total" in pdf_values:
|
||||
try:
|
||||
pdf_vat = float(pdf_values["totals.vat_total"])
|
||||
if not amounts_match(xml_data.totals.vat_total, pdf_vat):
|
||||
add_mismatch("totals.vat_total", xml_data.totals.vat_total, pdf_vat)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_invoice(request: ValidateRequest) -> ValidationResult:
|
||||
"""Run selected validation checks."""
|
||||
start_time = time.time()
|
||||
all_errors = []
|
||||
all_warnings = []
|
||||
|
||||
xml_data = XmlData(**request.xml_data)
|
||||
|
||||
checks_run = 0
|
||||
checks_passed = 0
|
||||
|
||||
# Run requested checks
|
||||
for check_name in request.checks:
|
||||
check_errors: list[ErrorDetail] = []
|
||||
|
||||
if check_name == "pflichtfelder":
|
||||
check_errors = validate_pflichtfelder(xml_data)
|
||||
checks_run += 1
|
||||
elif check_name == "betraege":
|
||||
check_errors = validate_betraege(xml_data)
|
||||
checks_run += 1
|
||||
elif check_name == "ustid":
|
||||
# Check supplier VAT ID
|
||||
if xml_data.supplier.vat_id:
|
||||
error = validate_ustid(xml_data.supplier.vat_id)
|
||||
if error:
|
||||
check_errors.append(error)
|
||||
# Check buyer VAT ID if present
|
||||
if xml_data.buyer.vat_id:
|
||||
error = validate_ustid(xml_data.buyer.vat_id)
|
||||
if error:
|
||||
check_errors.append(error)
|
||||
checks_run += 1
|
||||
elif check_name == "pdf_abgleich":
|
||||
if request.pdf_text:
|
||||
# For simplicity, try to extract values from PDF text
|
||||
pdf_values = {}
|
||||
try:
|
||||
if "Invoice" in request.pdf_text:
|
||||
parts = request.pdf_text.split()
|
||||
if len(parts) > 1:
|
||||
pdf_values["invoice_number"] = parts[1]
|
||||
if "Total:" in request.pdf_text:
|
||||
parts = request.pdf_text.split("Total:")
|
||||
if len(parts) > 1:
|
||||
total_str = parts[1].strip().split()[0]
|
||||
pdf_values["totals.gross"] = total_str
|
||||
except Exception:
|
||||
pass
|
||||
check_errors = validate_pdf_abgleich(xml_data, pdf_values)
|
||||
checks_run += 1
|
||||
|
||||
# Separate errors and warnings
|
||||
critical_errors = [e for e in check_errors if e.severity == "critical"]
|
||||
warnings = [e for e in check_errors if e.severity == "warning"]
|
||||
all_errors.extend(critical_errors)
|
||||
all_warnings.extend(warnings)
|
||||
|
||||
if len(critical_errors) == 0:
|
||||
checks_passed += 1
|
||||
|
||||
validation_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
is_valid = len(all_errors) == 0
|
||||
|
||||
summary = {
|
||||
"total_checks": checks_run,
|
||||
"checks_passed": checks_passed,
|
||||
"checks_failed": checks_run - checks_passed,
|
||||
"critical_errors": len(all_errors),
|
||||
"warnings": len(all_warnings),
|
||||
}
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=is_valid,
|
||||
errors=all_errors,
|
||||
warnings=all_warnings,
|
||||
summary=summary,
|
||||
validation_time_ms=validation_time_ms,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user