342 lines
14 KiB
Python
342 lines
14 KiB
Python
|
|
"""Модуль взаимодействия с Stable Diffusion WebUI API (Automatic1111)."""
|
||
|
|
|
||
|
|
import base64
|
||
|
|
import logging
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
import aiohttp
|
||
|
|
|
||
|
|
from config import settings
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class SDClient:
|
||
|
|
"""Клиент для работы с Stable Diffusion WebUI API."""
|
||
|
|
|
||
|
|
def __init__(self, api_url: str):
|
||
|
|
self.api_url = api_url.rstrip("/")
|
||
|
|
|
||
|
|
async def check_connection(self) -> bool:
|
||
|
|
"""Проверка соединения с API."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.get(
|
||
|
|
f"{self.api_url}/sdapi/v1/options",
|
||
|
|
timeout=aiohttp.ClientTimeout(total=10)
|
||
|
|
) as response:
|
||
|
|
return response.status == 200
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка подключения к SD API: {e}")
|
||
|
|
return False
|
||
|
|
|
||
|
|
async def get_models(self) -> list[str]:
|
||
|
|
"""Получить список доступных моделей."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.get(
|
||
|
|
f"{self.api_url}/sdapi/v1/sd-models",
|
||
|
|
timeout=aiohttp.ClientTimeout(total=30)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
data = await response.json()
|
||
|
|
return [model.get("title", model.get("model_name", "")) for model in data]
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка получения списка моделей: {e}")
|
||
|
|
return []
|
||
|
|
|
||
|
|
async def get_current_model(self) -> str:
|
||
|
|
"""Получить текущую загруженную модель."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.get(
|
||
|
|
f"{self.api_url}/sdapi/v1/options",
|
||
|
|
timeout=aiohttp.ClientTimeout(total=10)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
data = await response.json()
|
||
|
|
return data.get("sd_model_checkpoint", "Unknown")
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка получения текущей модели: {e}")
|
||
|
|
return "Unknown"
|
||
|
|
|
||
|
|
async def get_samplers(self) -> list[str]:
|
||
|
|
"""Получить список доступных сэмплеров."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.get(
|
||
|
|
f"{self.api_url}/sdapi/v1/samplers",
|
||
|
|
timeout=aiohttp.ClientTimeout(total=10)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
data = await response.json()
|
||
|
|
return [sampler.get("name", "") for sampler in data]
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка получения списка сэмплеров: {e}")
|
||
|
|
return ["Euler a", "Euler", "DPM++ 2M Karras", "DPM++ SDE Karras", "DDIM"]
|
||
|
|
|
||
|
|
async def get_schedulers(self) -> list[str]:
|
||
|
|
"""Получить список доступных шедулеров."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.get(
|
||
|
|
f"{self.api_url}/sdapi/v1/schedulers",
|
||
|
|
timeout=aiohttp.ClientTimeout(total=10)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
data = await response.json()
|
||
|
|
return [scheduler.get("name", "") for scheduler in data]
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка получения списка шедулеров: {e}")
|
||
|
|
return ["automatic", "normal", "karras", "exponential", "SGM uniform", "simple", "DDIM"]
|
||
|
|
|
||
|
|
async def get_loras(self) -> list[str]:
|
||
|
|
"""Получить список доступных LoRA."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.get(
|
||
|
|
f"{self.api_url}/sdapi/v1/loras",
|
||
|
|
timeout=aiohttp.ClientTimeout(total=10)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
data = await response.json()
|
||
|
|
return [lora.get("name", "") for lora in data]
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка получения списка LoRA: {e}")
|
||
|
|
return []
|
||
|
|
|
||
|
|
async def txt2img(
|
||
|
|
self,
|
||
|
|
prompt: str,
|
||
|
|
negative_prompt: str = "",
|
||
|
|
width: int = 512,
|
||
|
|
height: int = 512,
|
||
|
|
steps: int = 20,
|
||
|
|
cfg_scale: float = 7.0,
|
||
|
|
sampler: str = "Euler a",
|
||
|
|
scheduler: str = "automatic",
|
||
|
|
seed: int = -1,
|
||
|
|
model: Optional[str] = None,
|
||
|
|
lora: Optional[str] = None,
|
||
|
|
lora_strength: float = 0.8,
|
||
|
|
) -> Optional[tuple[bytes, dict]]:
|
||
|
|
"""
|
||
|
|
Генерация изображения из текста.
|
||
|
|
Возвращает кортеж (изображение в bytes, info_dict) или None при ошибке.
|
||
|
|
"""
|
||
|
|
payload = {
|
||
|
|
"prompt": prompt,
|
||
|
|
"negative_prompt": negative_prompt,
|
||
|
|
"steps": steps,
|
||
|
|
"width": width,
|
||
|
|
"height": height,
|
||
|
|
"cfg_scale": cfg_scale,
|
||
|
|
"sampler_name": sampler,
|
||
|
|
"scheduler": scheduler,
|
||
|
|
"seed": seed,
|
||
|
|
"save_images": False,
|
||
|
|
"send_images": True,
|
||
|
|
}
|
||
|
|
|
||
|
|
# Если указана модель, переключаем её
|
||
|
|
if model:
|
||
|
|
await self._switch_model(model)
|
||
|
|
|
||
|
|
# Формируем промпт с LoRA
|
||
|
|
final_prompt = prompt
|
||
|
|
if lora:
|
||
|
|
final_prompt = f"<lora:{lora}:{lora_strength}> {prompt}"
|
||
|
|
payload["prompt"] = final_prompt
|
||
|
|
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.post(
|
||
|
|
f"{self.api_url}/sdapi/v1/txt2img",
|
||
|
|
json=payload,
|
||
|
|
timeout=aiohttp.ClientTimeout(total=600) # 10 минут на генерацию
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
data = await response.json()
|
||
|
|
if data.get("images"):
|
||
|
|
image_bytes = base64.b64decode(data["images"][0])
|
||
|
|
info = {
|
||
|
|
"prompt": final_prompt,
|
||
|
|
"negative_prompt": negative_prompt,
|
||
|
|
"width": width,
|
||
|
|
"height": height,
|
||
|
|
"steps": steps,
|
||
|
|
"cfg_scale": cfg_scale,
|
||
|
|
"sampler": sampler,
|
||
|
|
"scheduler": scheduler,
|
||
|
|
"seed": data.get("parameters", {}).get("seed", seed),
|
||
|
|
"model": model or await self.get_current_model(),
|
||
|
|
"lora": lora,
|
||
|
|
}
|
||
|
|
return image_bytes, info
|
||
|
|
else:
|
||
|
|
logger.error("API не вернул изображение")
|
||
|
|
return None
|
||
|
|
else:
|
||
|
|
error_text = await response.text()
|
||
|
|
logger.error(f"Ошибка API txt2img: {response.status} - {error_text}")
|
||
|
|
return None
|
||
|
|
except aiohttp.ClientTimeout:
|
||
|
|
logger.error("Таймаут запроса txt2img (10 минут)")
|
||
|
|
return None
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка txt2img: {e}")
|
||
|
|
return None
|
||
|
|
|
||
|
|
async def img2img(
|
||
|
|
self,
|
||
|
|
init_image_bytes: bytes,
|
||
|
|
prompt: str,
|
||
|
|
negative_prompt: str = "",
|
||
|
|
width: int = 512,
|
||
|
|
height: int = 512,
|
||
|
|
steps: int = 20,
|
||
|
|
cfg_scale: float = 7.0,
|
||
|
|
sampler: str = "Euler a",
|
||
|
|
scheduler: str = "automatic",
|
||
|
|
denoising_strength: float = 0.75,
|
||
|
|
seed: int = -1,
|
||
|
|
model: Optional[str] = None,
|
||
|
|
lora: Optional[str] = None,
|
||
|
|
lora_strength: float = 0.8,
|
||
|
|
) -> Optional[tuple[bytes, dict]]:
|
||
|
|
"""
|
||
|
|
Генерация изображения на основе изображения.
|
||
|
|
Возвращает кортеж (изображение в bytes, info_dict) или None при ошибке.
|
||
|
|
"""
|
||
|
|
init_image_base64 = base64.b64encode(init_image_bytes).decode("utf-8")
|
||
|
|
|
||
|
|
payload = {
|
||
|
|
"init_images": [init_image_base64],
|
||
|
|
"prompt": prompt,
|
||
|
|
"negative_prompt": negative_prompt,
|
||
|
|
"steps": steps,
|
||
|
|
"width": width,
|
||
|
|
"height": height,
|
||
|
|
"cfg_scale": cfg_scale,
|
||
|
|
"sampler_name": sampler,
|
||
|
|
"scheduler": scheduler,
|
||
|
|
"denoising_strength": denoising_strength,
|
||
|
|
"seed": seed,
|
||
|
|
"save_images": False,
|
||
|
|
"send_images": True,
|
||
|
|
}
|
||
|
|
|
||
|
|
if model:
|
||
|
|
await self._switch_model(model)
|
||
|
|
|
||
|
|
final_prompt = prompt
|
||
|
|
if lora:
|
||
|
|
final_prompt = f"<lora:{lora}:{lora_strength}> {prompt}"
|
||
|
|
payload["prompt"] = final_prompt
|
||
|
|
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.post(
|
||
|
|
f"{self.api_url}/sdapi/v1/img2img",
|
||
|
|
json=payload,
|
||
|
|
timeout=aiohttp.ClientTimeout(total=600)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
data = await response.json()
|
||
|
|
if data.get("images"):
|
||
|
|
image_bytes = base64.b64decode(data["images"][0])
|
||
|
|
info = {
|
||
|
|
"prompt": final_prompt,
|
||
|
|
"negative_prompt": negative_prompt,
|
||
|
|
"width": width,
|
||
|
|
"height": height,
|
||
|
|
"steps": steps,
|
||
|
|
"cfg_scale": cfg_scale,
|
||
|
|
"sampler": sampler,
|
||
|
|
"scheduler": scheduler,
|
||
|
|
"seed": data.get("parameters", {}).get("seed", seed),
|
||
|
|
"denoising_strength": denoising_strength,
|
||
|
|
"model": model or await self.get_current_model(),
|
||
|
|
"lora": lora,
|
||
|
|
}
|
||
|
|
return image_bytes, info
|
||
|
|
else:
|
||
|
|
logger.error("API не вернул изображение")
|
||
|
|
return None
|
||
|
|
else:
|
||
|
|
error_text = await response.text()
|
||
|
|
logger.error(f"Ошибка API img2img: {response.status} - {error_text}")
|
||
|
|
return None
|
||
|
|
except aiohttp.ClientTimeout:
|
||
|
|
logger.error("Таймаут запроса img2img (10 минут)")
|
||
|
|
return None
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка img2img: {e}")
|
||
|
|
return None
|
||
|
|
|
||
|
|
async def _switch_model(self, model_name: str) -> bool:
|
||
|
|
"""Переключить модель."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.post(
|
||
|
|
f"{self.api_url}/sdapi/v1/options",
|
||
|
|
json={"sd_model_checkpoint": model_name},
|
||
|
|
timeout=aiohttp.ClientTimeout(total=120)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
logger.info(f"Модель переключена на: {model_name}")
|
||
|
|
return True
|
||
|
|
else:
|
||
|
|
logger.error(f"Ошибка переключения модели: {response.status}")
|
||
|
|
return False
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка переключения модели: {e}")
|
||
|
|
return False
|
||
|
|
|
||
|
|
async def get_progress(self, skip_headers: bool = False) -> Optional[dict]:
|
||
|
|
"""Получить текущий прогресс генерации."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.get(
|
||
|
|
f"{self.api_url}/sdapi/v1/progress",
|
||
|
|
params={"skip_current_image": skip_headers},
|
||
|
|
timeout=aiohttp.ClientTimeout(total=10)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
return await response.json()
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка получения прогресса: {e}")
|
||
|
|
return None
|
||
|
|
|
||
|
|
async def interrupt(self) -> bool:
|
||
|
|
"""Прервать текущую генерацию."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.post(
|
||
|
|
f"{self.api_url}/sdapi/v1/interrupt",
|
||
|
|
json={},
|
||
|
|
timeout=aiohttp.ClientTimeout(total=10)
|
||
|
|
) as response:
|
||
|
|
return response.status == 200
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка прерывания генерации: {e}")
|
||
|
|
return False
|
||
|
|
|
||
|
|
async def get_options(self) -> dict:
|
||
|
|
"""Получить текущие опции API."""
|
||
|
|
try:
|
||
|
|
async with aiohttp.ClientSession() as session:
|
||
|
|
async with session.get(
|
||
|
|
f"{self.api_url}/sdapi/v1/options",
|
||
|
|
timeout=aiohttp.ClientTimeout(total=10)
|
||
|
|
) as response:
|
||
|
|
if response.status == 200:
|
||
|
|
return await response.json()
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Ошибка получения опций: {e}")
|
||
|
|
return {}
|
||
|
|
|
||
|
|
|
||
|
|
# Глобальный экземпляр
|
||
|
|
sd_client = SDClient(settings.SD_API_URL)
|