"""Background worker that generates illustration images for articles.

The worker polls the ``articles`` table for rows without an image, claims a
small batch, renders one Stable Diffusion image per article from its title,
writes the PNG to ``IMAGE_OUTPUT_DIR`` and stores it as a base64 data URI in
the ``image`` column.
"""
import base64
import logging
import os
import socket
import threading
import time
from contextlib import closing
from io import BytesIO
from pathlib import Path
from typing import Optional

import psycopg2
import torch
from diffusers import StableDiffusionPipeline

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# NOTE(review): the fallback URL embeds credentials; presumably meant for local
# development only -- confirm that deployments always set DATABASE_URL.
DB_URL = os.environ.get(
    "DATABASE_URL",
    "postgres://app:appsecret@database:5432/appdb?sslmode=disable",
)

# NOTE(review): the cache directory is called "sd15" although MODEL_ID below
# points at sd-turbo -- confirm the directory name is intentional.
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "models" / "sd15"
MODEL_PATH = Path(DEFAULT_MODEL_PATH)

# Directory that receives the rendered PNG files; created at import time.
IMAGE_OUTPUT_DIR = Path(os.environ.get("IMAGE_OUTPUT_DIR") or "/app/images")
IMAGE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_ID = "stabilityai/sd-turbo"

# Identifies this process in the image_claimed_by column.
WORKER_ID = os.environ.get("WORKER_ID") or f"{socket.gethostname()}:{os.getpid()}"
# Maximum number of articles claimed per polling round.
BATCH_SIZE = int(os.environ.get("IMAGE_BATCH_SIZE") or "3")
# A claim older than this many minutes may be taken over by another claim query.
CLAIM_TTL_MINUTES = int(os.environ.get("IMAGE_CLAIM_TTL_MINUTES") or "20")

# Lazily initialised pipeline shared by all callers; guarded by pipe_lock.
pipe = None
pipe_lock = threading.Lock()


def connect():
    """Open a new psycopg2 connection to the database at ``DB_URL``."""
    return psycopg2.connect(DB_URL)


def load_pipeline() -> Optional[StableDiffusionPipeline]:
    """Return the shared Stable Diffusion pipeline, loading it on first use.

    The model is read from ``MODEL_PATH`` when that directory exists and is
    non-empty; otherwise it is fetched via ``MODEL_ID`` and saved to
    ``MODEL_PATH``. The pipeline is moved to the CPU and attention slicing is
    enabled.

    Returns:
        The ready-to-use pipeline, or ``None`` if loading failed. A failure is
        logged and not cached, so the next call tries again.
    """
    global pipe
    if pipe is not None:
        return pipe

    with pipe_lock:
        # Another thread may have finished loading while we waited for the lock.
        if pipe is not None:
            return pipe

        try:
            # Check whether the model directory already contains something locally.
            has_local_model = MODEL_PATH.is_dir() and any(MODEL_PATH.iterdir())

            if has_local_model:
                logging.info("Lade Stable Diffusion Modell lokal aus %s", MODEL_PATH)
                loaded = StableDiffusionPipeline.from_pretrained(
                    str(MODEL_PATH),
                    torch_dtype=torch.float32,
                    local_files_only=True,
                )
            else:
                logging.info("Kein lokales Modell gefunden – lade von Hugging Face (%s)", MODEL_ID)
                loaded = StableDiffusionPipeline.from_pretrained(
                    MODEL_ID,
                    torch_dtype=torch.float32,
                )
                loaded.save_pretrained(MODEL_PATH)
                logging.info("Modell erfolgreich nach %s gespeichert", MODEL_PATH)

            loaded = loaded.to("cpu")
            loaded.enable_attention_slicing()
        except Exception as e:
            logging.error("Konnte Pipeline nicht laden: %s", e)
            return None

        # Publish the global only after the pipeline is fully configured, so
        # the lock-free check at the top never hands out a half-initialised
        # object to another thread.
        pipe = loaded
        return pipe


def fetch_articles_without_image(cur, limit=BATCH_SIZE):
    """Claim up to ``limit`` articles that still lack an image.

    Selects rows whose ``image`` is NULL and whose claim is absent or older
    than ``CLAIM_TTL_MINUTES``, newest first, skipping rows locked by other
    transactions, and stamps them with the current time and ``WORKER_ID``.

    Args:
        cur: Open database cursor; the caller owns the transaction.
        limit: Maximum number of rows to claim.

    Returns:
        A list of ``(id, title, article_id)`` rows for the claimed articles.
    """
    cur.execute(
        """
        WITH claim AS (
            SELECT id
            FROM articles
            WHERE image IS NULL
              AND (
                image_claimed_at IS NULL
                OR image_claimed_at < now() - (%s * INTERVAL '1 minute')
              )
            ORDER BY created_at DESC
            LIMIT %s
            FOR UPDATE SKIP LOCKED
        )
        UPDATE articles AS a
        SET image_claimed_at = now(),
            image_claimed_by = %s
        FROM claim
        WHERE a.id = claim.id
        RETURNING a.id, a.title, a.article_id
        """,
        (CLAIM_TTL_MINUTES, limit, WORKER_ID),
    )
    return cur.fetchall()


def update_image(cur, article_id: int, data_uri: str):
    """Store ``data_uri`` as the article's image and clear its claim.

    Args:
        cur: Open database cursor; the caller commits.
        article_id: Value of the row's ``id`` column.
        data_uri: Image encoded as a ``data:`` URI.
    """
    cur.execute(
        """
        UPDATE articles
        SET image = %s,
            image_claimed_at = NULL,
            image_claimed_by = NULL
        WHERE id = %s
        """,
        (data_uri, article_id),
    )


def release_claim(cur, article_id: int):
    """Clear the claim on an article that still has no image.

    Args:
        cur: Open database cursor; the caller commits.
        article_id: Value of the row's ``id`` column.
    """
    cur.execute(
        """
        UPDATE articles
        SET image_claimed_at = NULL,
            image_claimed_by = NULL
        WHERE id = %s AND image IS NULL
        """,
        (article_id,),
    )


def safe_filename(name: str) -> str:
    """Turn ``name`` into a string that is safe to use as a file name.

    Every character that is not alphanumeric, ``-`` or ``_`` is replaced by
    ``_``. Non-string values are converted with ``str()`` first, because the
    caller passes a raw database value. ``None`` and the empty string yield
    ``"image"``.
    """
    text = "" if name is None else str(name)
    safe = "".join(c if c.isalnum() or c in ("-", "_") else "_" for c in text)
    return safe or "image"


def generate_image(prompt: str) -> Optional[tuple[str, bytes]]:
    """Render one image for ``prompt`` and return it in two encodings.

    Args:
        prompt: Text prompt handed to the pipeline.

    Returns:
        ``(data_uri, png_bytes)`` where ``data_uri`` is the PNG as a base64
        ``data:image/png`` URI, or ``None`` if the pipeline is unavailable or
        generation failed (the failure is logged).
    """
    model = load_pipeline()
    if model is None:
        return None

    try:
        img = model(
            prompt=prompt,
            num_inference_steps=8,
            guidance_scale=7,
        ).images[0]
        buf = BytesIO()
        img.save(buf, format="PNG")
        img_bytes = buf.getvalue()
        data = base64.b64encode(img_bytes).decode("ascii")
        return f"data:image/png;base64,{data}", img_bytes
    except Exception as e:
        logging.error("Bildgenerierung fehlgeschlagen: %s", e)
        return None


def main():
    """Run the polling loop forever.

    Each round opens a connection, claims a batch of articles, generates and
    stores an image per article (or releases the claim when generation
    fails), then sleeps 30 seconds. Any error in a round is logged and the
    loop continues.
    """
    while True:
        try:
            # psycopg2's ``with conn`` only commits or rolls back the
            # transaction; ``closing`` is what actually closes the connection,
            # so the loop does not leak one connection per round.
            with closing(connect()) as conn:
                with conn:
                    with conn.cursor() as cur:
                        rows = fetch_articles_without_image(cur)
                        if not rows:
                            logging.info("keine neuen Artikel ohne Bild")

                        for aid, title, article_id in rows:
                            prompt = title or "news illustration"
                            result = generate_image(prompt)

                            if result:
                                data_uri, img_bytes = result
                                filename = f"{safe_filename(article_id)}.png"
                                out_path = IMAGE_OUTPUT_DIR / filename
                                # Writing the file is best effort: the image is
                                # stored in the database even if this fails.
                                try:
                                    out_path.write_bytes(img_bytes)
                                    logging.info("Bild gespeichert unter %s", out_path)
                                except Exception as e:
                                    logging.error(
                                        "Konnte Bild nicht speichern (%s): %s",
                                        out_path,
                                        e,
                                    )
                                update_image(cur, aid, data_uri)
                                logging.info("Bild gesetzt für Artikel %s", aid)
                            else:
                                release_claim(cur, aid)

                            # Commit per article so a finished image survives a
                            # failure on a later article of the same batch.
                            conn.commit()
        except Exception as e:
            logging.error("Fehler im Worker: %s", e)

        time.sleep(30)


if __name__ == "__main__":
    main()