177 lines
5.2 KiB
Python
177 lines
5.2 KiB
Python
import base64
|
||
import logging
|
||
import os
|
||
import socket
|
||
import threading
|
||
import time
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
import psycopg2
|
||
import torch
|
||
from diffusers import StableDiffusionPipeline
|
||
|
||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# PostgreSQL DSN; DATABASE_URL overrides the built-in fallback.
# NOTE(review): the fallback embeds a plaintext password and disables TLS;
# consider requiring DATABASE_URL in deployed environments.
DB_URL = os.environ.get("DATABASE_URL", "postgres://app:appsecret@database:5432/appdb?sslmode=disable")

# Directory holding a locally saved copy of the diffusion pipeline.
# load_pipeline() reads from it when it is non-empty and saves into it after a
# Hub download.
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "models" / "sd15"
# Overridable via MODEL_PATH, like the other settings below. When the variable
# is unset or empty, the path falls back to DEFAULT_MODEL_PATH, which is the
# previous behaviour.
MODEL_PATH = Path(os.environ.get("MODEL_PATH") or DEFAULT_MODEL_PATH)

# Output directory for images. It is created at import time.
IMAGE_OUTPUT_DIR = Path(os.environ.get("IMAGE_OUTPUT_DIR") or "/app/images")
IMAGE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Hugging Face model id, downloaded when MODEL_PATH holds no local model.
# NOTE(review): the default local folder is named "sd15" but this id is
# sd-turbo; confirm which model the local copy is expected to be.
MODEL_ID = "stabilityai/sd-turbo"

# Identifies this worker in articles.image_claimed_by (defaults to host:pid).
WORKER_ID = os.environ.get("WORKER_ID") or f"{socket.gethostname()}:{os.getpid()}"

# Maximum number of articles claimed per polling round.
BATCH_SIZE = int(os.environ.get("IMAGE_BATCH_SIZE") or "3")

# Age in minutes after which another worker may take over a claim.
CLAIM_TTL_MINUTES = int(os.environ.get("IMAGE_CLAIM_TTL_MINUTES") or "20")

# Lazily loaded, process-wide pipeline, cached by load_pipeline().
pipe = None

# Serialises the first load of `pipe`.
pipe_lock = threading.Lock()
|
||
|
||
|
||
def connect():
    """Open and return a new psycopg2 connection to the database at DB_URL.

    The caller is responsible for committing and closing the connection.
    """
    connection = psycopg2.connect(DB_URL)
    return connection
|
||
|
||
|
||
def _load_unconfigured_pipeline() -> StableDiffusionPipeline:
    """Load the pipeline from MODEL_PATH if populated, else from the Hub (saving it locally)."""
    # Check whether the model folder already contains anything.
    has_local_model = MODEL_PATH.is_dir() and any(MODEL_PATH.iterdir())

    if has_local_model:
        logging.info("Lade Stable Diffusion Modell lokal aus %s", MODEL_PATH)
        return StableDiffusionPipeline.from_pretrained(
            str(MODEL_PATH),
            torch_dtype=torch.float32,
            local_files_only=True,
        )

    logging.info("Kein lokales Modell gefunden – lade von Hugging Face (%s)", MODEL_ID)
    downloaded = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
    )
    try:
        downloaded.save_pretrained(MODEL_PATH)
    except Exception as e:
        # Saving locally is best-effort. A read-only or full disk must not
        # throw away a pipeline that already downloaded successfully.
        # NOTE(review): a partially written MODEL_PATH makes the next start
        # take the local branch above and fail; confirm whether cleanup is needed.
        logging.warning("Konnte Modell nicht nach %s speichern: %s", MODEL_PATH, e)
    else:
        logging.info("Modell erfolgreich nach %s gespeichert", MODEL_PATH)
    return downloaded


def load_pipeline() -> Optional[StableDiffusionPipeline]:
    """Return the shared Stable Diffusion pipeline, loading it on first use.

    The pipeline is read from MODEL_PATH when that directory exists and is
    non-empty (offline, ``local_files_only=True``). Otherwise it is downloaded
    from the Hugging Face Hub as MODEL_ID and saved to MODEL_PATH for later
    starts. The result is moved to the CPU with attention slicing enabled and
    cached in the module-level ``pipe``.

    Returns:
        The configured pipeline, or None if loading failed. Failures are not
        cached, so the next call tries again.
    """
    global pipe

    # Lock-free fast path. This is safe because `pipe` is only ever assigned
    # a fully configured pipeline (see the end of this function).
    if pipe is not None:
        return pipe

    with pipe_lock:
        # Another thread may have finished loading while we waited for the lock.
        if pipe is not None:
            return pipe

        try:
            loaded = _load_unconfigured_pipeline()
            loaded = loaded.to("cpu")
            loaded.enable_attention_slicing()
        except Exception as e:
            logging.exception("Konnte Pipeline nicht laden: %s", e)
            return None

        # Publish only after configuration is complete, so the unlocked check
        # above never returns a half-initialised pipeline.
        pipe = loaded
        return pipe
|
||
|
||
|
||
# Claims a batch of imageless articles in one statement. Rows locked by other
# transactions are skipped (SKIP LOCKED). A row is eligible when it is
# unclaimed or its claim is older than the TTL passed as the first parameter.
_CLAIM_ARTICLES_SQL = """
    WITH claim AS (
        SELECT id
        FROM articles
        WHERE image IS NULL
          AND (
              image_claimed_at IS NULL
              OR image_claimed_at < now() - (%s * INTERVAL '1 minute')
          )
        ORDER BY created_at DESC
        LIMIT %s
        FOR UPDATE SKIP LOCKED
    )
    UPDATE articles AS a
    SET image_claimed_at = now(),
        image_claimed_by = %s
    FROM claim
    WHERE a.id = claim.id
    RETURNING a.id, a.title, a.article_id
"""


def fetch_articles_without_image(cur, limit=BATCH_SIZE):
    """Claim up to ``limit`` articles without an image for this worker.

    Eligible rows have ``image IS NULL`` and are either unclaimed or hold a
    claim older than CLAIM_TTL_MINUTES. Newest ``created_at`` comes first.
    Each claimed row gets ``image_claimed_at = now()`` and
    ``image_claimed_by = WORKER_ID``. Nothing is committed here.

    Note: the default ``limit`` is BATCH_SIZE as evaluated when the module is
    defined.

    Returns:
        The ``(id, title, article_id)`` rows from ``cur.fetchall()``.
    """
    bind_values = (CLAIM_TTL_MINUTES, limit, WORKER_ID)
    cur.execute(_CLAIM_ARTICLES_SQL, bind_values)
    claimed_rows = cur.fetchall()
    return claimed_rows
|
||
|
||
|
||
def update_image(cur, article_id: int, data_uri: str):
    """Store a generated image on an article and clear its claim columns.

    Args:
        cur: Open database cursor. This function does not commit.
        article_id: Value matched against ``articles.id``.
        data_uri: Image as a ``data:`` URI, written to ``articles.image``.
    """
    statement = (
        "UPDATE articles "
        "SET image = %s, image_claimed_at = NULL, image_claimed_by = NULL "
        "WHERE id = %s"
    )
    bind_values = (data_uri, article_id)
    cur.execute(statement, bind_values)
|
||
|
||
|
||
def release_claim(cur, article_id: int):
    """Clear the claim columns of an article that still has no image.

    The ``image IS NULL`` guard leaves rows that already received an image
    untouched. This function does not commit.

    Args:
        cur: Open database cursor.
        article_id: Value matched against ``articles.id``.
    """
    query = " ".join((
        "UPDATE articles",
        "SET image_claimed_at = NULL, image_claimed_by = NULL",
        "WHERE id = %s AND image IS NULL",
    ))
    cur.execute(query, (article_id,))
|
||
|
||
|
||
def safe_filename(name: str) -> str:
    """Return ``name`` with every character except alphanumerics, '-' and '_' replaced by '_'.

    "Alphanumeric" follows ``str.isalnum``, so non-ASCII letters and digits
    are kept. An empty input yields ``"image"``.
    """
    keep_as_is = frozenset({"-", "_"})
    pieces = []
    for ch in name:
        if ch.isalnum() or ch in keep_as_is:
            pieces.append(ch)
        else:
            pieces.append("_")
    cleaned = "".join(pieces)
    if not cleaned:
        return "image"
    return cleaned
|
||
|
||
|
||
def generate_image(
    prompt: str,
    *,
    num_inference_steps: int = 8,
    guidance_scale: float = 7,
) -> Optional[tuple[str, bytes]]:
    """Render one image for ``prompt`` and return it as a PNG.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps (default 8).
        guidance_scale: Classifier-free guidance scale (default 7).
            NOTE(review): the sd-turbo model card recommends
            ``guidance_scale=0.0`` and 1-4 steps. Confirm these defaults
            against the model actually loaded.

    Returns:
        ``(data_uri, png_bytes)``, where ``data_uri`` is
        ``data:image/png;base64,...``. Returns None if the pipeline could not
        be loaded, or if generation or encoding raised (the error is logged
        with its traceback).
    """
    model = load_pipeline()
    if model is None:
        return None

    try:
        img = model(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
        buf = BytesIO()
        img.save(buf, format="PNG")
        img_bytes = buf.getvalue()
        data = base64.b64encode(img_bytes).decode("ascii")
        return f"data:image/png;base64,{data}", img_bytes
    except Exception as e:
        # Best-effort per article: the caller releases the claim on None.
        logging.exception("Bildgenerierung fehlgeschlagen: %s", e)
        return None
|
||
|
||
|
||
def _process_batch(conn) -> None:
    """Claim one batch of articles, generate their images and store the results.

    The claim is committed before any image is generated, for two reasons:
    other workers then see image_claimed_at/image_claimed_by and skip these
    rows, and no row locks are held during slow CPU inference. Each
    article's outcome is committed on its own, so a later failure does not
    roll back images already stored. Claims left unfinished by an error
    expire after CLAIM_TTL_MINUTES and are picked up again.
    """
    with conn.cursor() as cur:
        rows = fetch_articles_without_image(cur)
    conn.commit()

    if not rows:
        logging.info("keine neuen Artikel ohne Bild")
        return

    for aid, title, _article_id in rows:
        prompt = title or "news illustration"
        result = generate_image(prompt)
        with conn.cursor() as cur:
            if result:
                data_uri, _img_bytes = result
                update_image(cur, aid, data_uri)
            else:
                release_claim(cur, aid)
        conn.commit()
        if result:
            logging.info("Bild gesetzt für Artikel %s", aid)


def main():
    """Run the worker loop forever: process one batch, then sleep 30 seconds.

    Errors are logged with their traceback and never stop the loop.
    """
    while True:
        try:
            conn = connect()
            try:
                _process_batch(conn)
            finally:
                # psycopg2's `with conn:` only ends the transaction; it does
                # not close the connection. Close explicitly, which also rolls
                # back anything left uncommitted.
                conn.close()
        except Exception as e:
            logging.exception("Fehler im Worker: %s", e)

        time.sleep(30)
|
||
|
||
|
||
# Script entry point: start the endless worker loop.
if __name__ == "__main__":
    main()
|