"""Модуль взаимодействия с Stable Diffusion WebUI API (Automatic1111).""" import base64 import logging from typing import Optional import aiohttp from config import settings logger = logging.getLogger(__name__) class SDClient: """Клиент для работы с Stable Diffusion WebUI API.""" def __init__(self, api_url: str): self.api_url = api_url.rstrip("/") async def check_connection(self) -> bool: """Проверка соединения с API.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{self.api_url}/sdapi/v1/options", timeout=aiohttp.ClientTimeout(total=10) ) as response: return response.status == 200 except Exception as e: logger.error(f"Ошибка подключения к SD API: {e}") return False async def get_models(self) -> list[str]: """Получить список доступных моделей.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{self.api_url}/sdapi/v1/sd-models", timeout=aiohttp.ClientTimeout(total=30) ) as response: if response.status == 200: data = await response.json() return [model.get("title", model.get("model_name", "")) for model in data] except Exception as e: logger.error(f"Ошибка получения списка моделей: {e}") return [] async def get_current_model(self) -> str: """Получить текущую загруженную модель.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{self.api_url}/sdapi/v1/options", timeout=aiohttp.ClientTimeout(total=10) ) as response: if response.status == 200: data = await response.json() return data.get("sd_model_checkpoint", "Unknown") except Exception as e: logger.error(f"Ошибка получения текущей модели: {e}") return "Unknown" async def get_samplers(self) -> list[str]: """Получить список доступных сэмплеров.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{self.api_url}/sdapi/v1/samplers", timeout=aiohttp.ClientTimeout(total=10) ) as response: if response.status == 200: data = await response.json() return [sampler.get("name", "") for sampler in data] except Exception as e: logger.error(f"Ошибка получения списка сэмплеров: {e}") return ["Euler a", "Euler", "DPM++ 2M Karras", "DPM++ SDE Karras", "DDIM"] async def get_schedulers(self) -> list[str]: """Получить список доступных шедулеров.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{self.api_url}/sdapi/v1/schedulers", timeout=aiohttp.ClientTimeout(total=10) ) as response: if response.status == 200: data = await response.json() return [scheduler.get("name", "") for scheduler in data] except Exception as e: logger.error(f"Ошибка получения списка шедулеров: {e}") return ["automatic", "normal", "karras", "exponential", "SGM uniform", "simple", "DDIM"] async def get_loras(self) -> list[str]: """Получить список доступных LoRA.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{self.api_url}/sdapi/v1/loras", timeout=aiohttp.ClientTimeout(total=10) ) as response: if response.status == 200: data = await response.json() return [lora.get("name", "") for lora in data] except Exception as e: logger.error(f"Ошибка получения списка LoRA: {e}") return [] async def txt2img( self, prompt: str, negative_prompt: str = "", width: int = 512, height: int = 512, steps: int = 20, cfg_scale: float = 7.0, sampler: str = "Euler a", scheduler: str = "automatic", seed: int = -1, model: Optional[str] = None, lora: Optional[str] = None, lora_strength: float = 0.8, ) -> Optional[tuple[bytes, dict]]: """ Генерация изображения из текста. Возвращает кортеж (изображение в bytes, info_dict) или None при ошибке. """ payload = { "prompt": prompt, "negative_prompt": negative_prompt, "steps": steps, "width": width, "height": height, "cfg_scale": cfg_scale, "sampler_name": sampler, "scheduler": scheduler, "seed": seed, "save_images": False, "send_images": True, } # Если указана модель, переключаем её if model: await self._switch_model(model) # Формируем промпт с LoRA final_prompt = prompt if lora: final_prompt = f" {prompt}" payload["prompt"] = final_prompt try: async with aiohttp.ClientSession() as session: async with session.post( f"{self.api_url}/sdapi/v1/txt2img", json=payload, timeout=aiohttp.ClientTimeout(total=600) # 10 минут на генерацию ) as response: if response.status == 200: data = await response.json() if data.get("images"): image_bytes = base64.b64decode(data["images"][0]) info = { "prompt": final_prompt, "negative_prompt": negative_prompt, "width": width, "height": height, "steps": steps, "cfg_scale": cfg_scale, "sampler": sampler, "scheduler": scheduler, "seed": data.get("parameters", {}).get("seed", seed), "model": model or await self.get_current_model(), "lora": lora, } return image_bytes, info else: logger.error("API не вернул изображение") return None else: error_text = await response.text() logger.error(f"Ошибка API txt2img: {response.status} - {error_text}") return None except aiohttp.ClientTimeout: logger.error("Таймаут запроса txt2img (10 минут)") return None except Exception as e: logger.error(f"Ошибка txt2img: {e}") return None async def img2img( self, init_image_bytes: bytes, prompt: str, negative_prompt: str = "", width: int = 512, height: int = 512, steps: int = 20, cfg_scale: float = 7.0, sampler: str = "Euler a", scheduler: str = "automatic", denoising_strength: float = 0.75, seed: int = -1, model: Optional[str] = None, lora: Optional[str] = None, lora_strength: float = 0.8, ) -> Optional[tuple[bytes, dict]]: """ Генерация изображения на основе изображения. Возвращает кортеж (изображение в bytes, info_dict) или None при ошибке. """ init_image_base64 = base64.b64encode(init_image_bytes).decode("utf-8") payload = { "init_images": [init_image_base64], "prompt": prompt, "negative_prompt": negative_prompt, "steps": steps, "width": width, "height": height, "cfg_scale": cfg_scale, "sampler_name": sampler, "scheduler": scheduler, "denoising_strength": denoising_strength, "seed": seed, "save_images": False, "send_images": True, } if model: await self._switch_model(model) final_prompt = prompt if lora: final_prompt = f" {prompt}" payload["prompt"] = final_prompt try: async with aiohttp.ClientSession() as session: async with session.post( f"{self.api_url}/sdapi/v1/img2img", json=payload, timeout=aiohttp.ClientTimeout(total=600) ) as response: if response.status == 200: data = await response.json() if data.get("images"): image_bytes = base64.b64decode(data["images"][0]) info = { "prompt": final_prompt, "negative_prompt": negative_prompt, "width": width, "height": height, "steps": steps, "cfg_scale": cfg_scale, "sampler": sampler, "scheduler": scheduler, "seed": data.get("parameters", {}).get("seed", seed), "denoising_strength": denoising_strength, "model": model or await self.get_current_model(), "lora": lora, } return image_bytes, info else: logger.error("API не вернул изображение") return None else: error_text = await response.text() logger.error(f"Ошибка API img2img: {response.status} - {error_text}") return None except aiohttp.ClientTimeout: logger.error("Таймаут запроса img2img (10 минут)") return None except Exception as e: logger.error(f"Ошибка img2img: {e}") return None async def _switch_model(self, model_name: str) -> bool: """Переключить модель.""" try: async with aiohttp.ClientSession() as session: async with session.post( f"{self.api_url}/sdapi/v1/options", json={"sd_model_checkpoint": model_name}, timeout=aiohttp.ClientTimeout(total=120) ) as response: if response.status == 200: logger.info(f"Модель переключена на: {model_name}") return True else: logger.error(f"Ошибка переключения модели: {response.status}") return False except Exception as e: logger.error(f"Ошибка переключения модели: {e}") return False async def get_progress(self, skip_headers: bool = False) -> Optional[dict]: """Получить текущий прогресс генерации.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{self.api_url}/sdapi/v1/progress", params={"skip_current_image": skip_headers}, timeout=aiohttp.ClientTimeout(total=10) ) as response: if response.status == 200: return await response.json() except Exception as e: logger.error(f"Ошибка получения прогресса: {e}") return None async def interrupt(self) -> bool: """Прервать текущую генерацию.""" try: async with aiohttp.ClientSession() as session: async with session.post( f"{self.api_url}/sdapi/v1/interrupt", json={}, timeout=aiohttp.ClientTimeout(total=10) ) as response: return response.status == 200 except Exception as e: logger.error(f"Ошибка прерывания генерации: {e}") return False async def get_options(self) -> dict: """Получить текущие опции API.""" try: async with aiohttp.ClientSession() as session: async with session.get( f"{self.api_url}/sdapi/v1/options", timeout=aiohttp.ClientTimeout(total=10) ) as response: if response.status == 200: return await response.json() except Exception as e: logger.error(f"Ошибка получения опций: {e}") return {} # Глобальный экземпляр sd_client = SDClient(settings.SD_API_URL)