Files
NewsSite/ai-worker/worker.py
hubble_dubble 4e235c2f28 modified: .gitignore
modified:   ai-worker/worker.py
	modified:   server-app/main.go
2026-01-26 01:04:58 +01:00

177 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import base64
import logging
import os
import socket
import threading
import time
from io import BytesIO
from pathlib import Path
from typing import Optional
import psycopg2
import torch
from diffusers import StableDiffusionPipeline
# Timestamped, level-tagged log lines at INFO and above.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)

# --- Database ---------------------------------------------------------------
# PostgreSQL connection URL. The fallback is used only when DATABASE_URL is
# not set at all (an empty value is passed through unchanged).
DB_URL = os.getenv(
    "DATABASE_URL",
    "postgres://app:appsecret@database:5432/appdb?sslmode=disable",
)

# --- Model ------------------------------------------------------------------
# Directory next to this file that holds the locally stored pipeline.
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.joinpath("models", "sd15")
MODEL_PATH = Path(DEFAULT_MODEL_PATH)
# Hugging Face model id downloaded when the local model directory is empty.
MODEL_ID = "stabilityai/sd-turbo"

# --- Image output -----------------------------------------------------------
# Created at import time; an unset or empty IMAGE_OUTPUT_DIR falls back to
# /app/images.
IMAGE_OUTPUT_DIR = Path(os.getenv("IMAGE_OUTPUT_DIR") or "/app/images")
IMAGE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Work claiming ----------------------------------------------------------
# Identifier written to articles.image_claimed_by; defaults to "<host>:<pid>".
WORKER_ID = os.getenv("WORKER_ID") or f"{socket.gethostname()}:{os.getpid()}"
# Maximum number of articles claimed per polling cycle.
BATCH_SIZE = int(os.getenv("IMAGE_BATCH_SIZE") or "3")
# Age in minutes after which an existing claim may be taken over.
CLAIM_TTL_MINUTES = int(os.getenv("IMAGE_CLAIM_TTL_MINUTES") or "20")

# --- Pipeline singleton -----------------------------------------------------
# Populated on first use by load_pipeline(); pipe_lock serializes the loading.
pipe = None
pipe_lock = threading.Lock()
def connect():
    """Open and return a new psycopg2 connection to the database at ``DB_URL``."""
    connection = psycopg2.connect(DB_URL)
    return connection
def load_pipeline() -> Optional[StableDiffusionPipeline]:
    """Return the shared Stable Diffusion pipeline, loading it on first use.

    The pipeline is loaded from ``MODEL_PATH`` when that directory exists and
    is non-empty; otherwise it is downloaded by ``MODEL_ID`` and saved to
    ``MODEL_PATH``. In both cases it is moved to the CPU and attention slicing
    is enabled before it is stored in the module-level ``pipe``.

    Returns:
        The ready-to-use pipeline, or ``None`` if loading failed. A failure is
        logged (with traceback) and not cached, so a later call tries again.
    """
    global pipe
    # Fast path: already loaded, no locking needed.
    if pipe is not None:
        return pipe
    with pipe_lock:
        # Another thread may have finished loading while we waited for the lock.
        if pipe is not None:
            return pipe
        try:
            # Check whether something is already present in the local model directory.
            has_local_model = MODEL_PATH.is_dir() and any(MODEL_PATH.iterdir())
            if has_local_model:
                logging.info("Lade Stable Diffusion Modell lokal aus %s", MODEL_PATH)
                candidate = StableDiffusionPipeline.from_pretrained(
                    str(MODEL_PATH),
                    torch_dtype=torch.float32,
                    local_files_only=True,
                )
            else:
                logging.info("Kein lokales Modell gefunden lade von Hugging Face (%s)", MODEL_ID)
                candidate = StableDiffusionPipeline.from_pretrained(
                    MODEL_ID,
                    torch_dtype=torch.float32,
                )
                candidate.save_pretrained(MODEL_PATH)
                logging.info("Modell erfolgreich nach %s gespeichert", MODEL_PATH)
            candidate = candidate.to("cpu")
            candidate.enable_attention_slicing()
        except Exception as e:
            # logging.exception keeps the original message and adds the traceback.
            logging.exception("Konnte Pipeline nicht laden: %s", e)
            return None
        # Publish only after the pipeline is fully configured: the lock-free
        # fast path above must never observe a half-initialized object.
        pipe = candidate
        return pipe
def fetch_articles_without_image(cur, limit=BATCH_SIZE):
    """Claim up to ``limit`` articles that still have no image and return them.

    A single statement selects articles whose ``image`` is NULL and whose claim
    is either absent or older than ``CLAIM_TTL_MINUTES`` (newest ``created_at``
    first, skipping rows locked by other transactions), stamps them with
    ``image_claimed_at = now()`` and ``image_claimed_by = WORKER_ID``, and
    returns the claimed rows.

    Args:
        cur: Open database cursor; the statement runs in its transaction.
        limit: Maximum number of articles to claim.

    Returns:
        The result of ``cur.fetchall()``: one ``(id, title, article_id)`` row
        per claimed article.
    """
    claim_sql = (
        "\n"
        "WITH claim AS (\n"
        "SELECT id\n"
        "FROM articles\n"
        "WHERE image IS NULL\n"
        "AND (\n"
        "image_claimed_at IS NULL\n"
        "OR image_claimed_at < now() - (%s * INTERVAL '1 minute')\n"
        ")\n"
        "ORDER BY created_at DESC\n"
        "LIMIT %s\n"
        "FOR UPDATE SKIP LOCKED\n"
        ")\n"
        "UPDATE articles AS a\n"
        "SET image_claimed_at = now(),\n"
        "image_claimed_by = %s\n"
        "FROM claim\n"
        "WHERE a.id = claim.id\n"
        "RETURNING a.id, a.title, a.article_id\n"
    )
    # Placeholder order: claim TTL, batch limit, claiming worker.
    params = (CLAIM_TTL_MINUTES, limit, WORKER_ID)
    cur.execute(claim_sql, params)
    claimed_rows = cur.fetchall()
    return claimed_rows
def update_image(cur, article_id: int, data_uri: str):
    """Store the generated image on an article and clear its claim.

    Args:
        cur: Open database cursor used to run the UPDATE.
        article_id: Value matched against ``articles.id``.
        data_uri: Image payload written to ``articles.image``.
    """
    store_sql = (
        "\n"
        "UPDATE articles\n"
        "SET image = %s,\n"
        "image_claimed_at = NULL,\n"
        "image_claimed_by = NULL\n"
        "WHERE id = %s\n"
    )
    # Placeholder order follows the statement: image first, then the row id.
    params = (data_uri, article_id)
    cur.execute(store_sql, params)
def release_claim(cur, article_id: int):
    """Clear the claim on an article that still has no image.

    Rows that already have an image are left untouched by the
    ``image IS NULL`` condition.

    Args:
        cur: Open database cursor used to run the UPDATE.
        article_id: Value matched against ``articles.id``.
    """
    release_sql = (
        "\n"
        "UPDATE articles\n"
        "SET image_claimed_at = NULL,\n"
        "image_claimed_by = NULL\n"
        "WHERE id = %s AND image IS NULL\n"
    )
    params = (article_id,)
    cur.execute(release_sql, params)
def safe_filename(name: str) -> str:
    """Return ``name`` with every character outside [alphanumeric, "-", "_"] replaced by "_".

    Alphanumeric is decided by ``str.isalnum``, so non-ASCII letters and digits
    are kept. An empty result is replaced by ``"image"``.
    """
    kept_punctuation = ("-", "_")

    def _keep_or_mask(ch: str) -> str:
        # Keep the character when it is allowed, otherwise mask it.
        if ch.isalnum() or ch in kept_punctuation:
            return ch
        return "_"

    cleaned = "".join(map(_keep_or_mask, name))
    if not cleaned:
        return "image"
    return cleaned
def generate_image(prompt: str) -> Optional[tuple[str, bytes]]:
    """Render one image for ``prompt`` and return it as a PNG.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        ``(data_uri, png_bytes)`` where ``data_uri`` is the base64-encoded PNG
        as a ``data:image/png;base64,...`` string, or ``None`` when the
        pipeline is unavailable or generation fails (the error is logged).
    """
    pipeline = load_pipeline()
    if pipeline is None:
        return None
    try:
        output = pipeline(prompt=prompt, num_inference_steps=8, guidance_scale=7)
        # Only the first generated image is used.
        image = output.images[0]
        png_buffer = BytesIO()
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
        encoded = base64.b64encode(png_bytes).decode("ascii")
        return f"data:image/png;base64,{encoded}", png_bytes
    except Exception as e:
        logging.error("Bildgenerierung fehlgeschlagen: %s", e)
        return None
def main():
    """Run the worker loop forever: claim articles, generate images, store them.

    Each cycle opens a fresh database connection, claims a batch of articles
    without an image, and for every claimed article either stores the generated
    image or releases the claim when generation failed. Any exception in a
    cycle is logged and the loop continues after a 30 second pause.
    """
    while True:
        try:
            conn = connect()
            try:
                with conn.cursor() as cur:
                    rows = fetch_articles_without_image(cur)
                    # Commit the claim right away so it is visible to other
                    # workers and no transaction stays open while images are
                    # being generated.
                    conn.commit()
                    if not rows:
                        logging.info("keine neuen Artikel ohne Bild")
                    for aid, title, _article_id in rows:
                        prompt = title or "news illustration"
                        result = generate_image(prompt)
                        if result:
                            data_uri, _img_bytes = result
                            update_image(cur, aid, data_uri)
                            logging.info("Bild gesetzt für Artikel %s", aid)
                        else:
                            release_claim(cur, aid)
                        # Commit per article: a later failure must not roll
                        # back images that were already generated.
                        conn.commit()
            finally:
                # psycopg2's "with connection" only ends the transaction; the
                # connection has to be closed explicitly.
                conn.close()
        except Exception as e:
            # Top-level boundary: log with traceback and keep the worker alive.
            logging.exception("Fehler im Worker: %s", e)
        time.sleep(30)


if __name__ == "__main__":
    main()