Initial commit
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
+341
@@ -0,0 +1,341 @@
|
||||
"""Модуль взаимодействия с Stable Diffusion WebUI API (Automatic1111)."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDClient:
|
||||
"""Клиент для работы с Stable Diffusion WebUI API."""
|
||||
|
||||
def __init__(self, api_url: str):
|
||||
self.api_url = api_url.rstrip("/")
|
||||
|
||||
async def check_connection(self) -> bool:
|
||||
"""Проверка соединения с API."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_url}/sdapi/v1/options",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка подключения к SD API: {e}")
|
||||
return False
|
||||
|
||||
async def get_models(self) -> list[str]:
|
||||
"""Получить список доступных моделей."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_url}/sdapi/v1/sd-models",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return [model.get("title", model.get("model_name", "")) for model in data]
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения списка моделей: {e}")
|
||||
return []
|
||||
|
||||
async def get_current_model(self) -> str:
|
||||
"""Получить текущую загруженную модель."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_url}/sdapi/v1/options",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return data.get("sd_model_checkpoint", "Unknown")
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения текущей модели: {e}")
|
||||
return "Unknown"
|
||||
|
||||
async def get_samplers(self) -> list[str]:
|
||||
"""Получить список доступных сэмплеров."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_url}/sdapi/v1/samplers",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return [sampler.get("name", "") for sampler in data]
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения списка сэмплеров: {e}")
|
||||
return ["Euler a", "Euler", "DPM++ 2M Karras", "DPM++ SDE Karras", "DDIM"]
|
||||
|
||||
async def get_schedulers(self) -> list[str]:
|
||||
"""Получить список доступных шедулеров."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_url}/sdapi/v1/schedulers",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return [scheduler.get("name", "") for scheduler in data]
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения списка шедулеров: {e}")
|
||||
return ["automatic", "normal", "karras", "exponential", "SGM uniform", "simple", "DDIM"]
|
||||
|
||||
async def get_loras(self) -> list[str]:
|
||||
"""Получить список доступных LoRA."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_url}/sdapi/v1/loras",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return [lora.get("name", "") for lora in data]
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения списка LoRA: {e}")
|
||||
return []
|
||||
|
||||
async def txt2img(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
steps: int = 20,
|
||||
cfg_scale: float = 7.0,
|
||||
sampler: str = "Euler a",
|
||||
scheduler: str = "automatic",
|
||||
seed: int = -1,
|
||||
model: Optional[str] = None,
|
||||
lora: Optional[str] = None,
|
||||
lora_strength: float = 0.8,
|
||||
) -> Optional[tuple[bytes, dict]]:
|
||||
"""
|
||||
Генерация изображения из текста.
|
||||
Возвращает кортеж (изображение в bytes, info_dict) или None при ошибке.
|
||||
"""
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"steps": steps,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"cfg_scale": cfg_scale,
|
||||
"sampler_name": sampler,
|
||||
"scheduler": scheduler,
|
||||
"seed": seed,
|
||||
"save_images": False,
|
||||
"send_images": True,
|
||||
}
|
||||
|
||||
# Если указана модель, переключаем её
|
||||
if model:
|
||||
await self._switch_model(model)
|
||||
|
||||
# Формируем промпт с LoRA
|
||||
final_prompt = prompt
|
||||
if lora:
|
||||
final_prompt = f"<lora:{lora}:{lora_strength}> {prompt}"
|
||||
payload["prompt"] = final_prompt
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.api_url}/sdapi/v1/txt2img",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=600) # 10 минут на генерацию
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data.get("images"):
|
||||
image_bytes = base64.b64decode(data["images"][0])
|
||||
info = {
|
||||
"prompt": final_prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"sampler": sampler,
|
||||
"scheduler": scheduler,
|
||||
"seed": data.get("parameters", {}).get("seed", seed),
|
||||
"model": model or await self.get_current_model(),
|
||||
"lora": lora,
|
||||
}
|
||||
return image_bytes, info
|
||||
else:
|
||||
logger.error("API не вернул изображение")
|
||||
return None
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Ошибка API txt2img: {response.status} - {error_text}")
|
||||
return None
|
||||
except aiohttp.ClientTimeout:
|
||||
logger.error("Таймаут запроса txt2img (10 минут)")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка txt2img: {e}")
|
||||
return None
|
||||
|
||||
async def img2img(
|
||||
self,
|
||||
init_image_bytes: bytes,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
steps: int = 20,
|
||||
cfg_scale: float = 7.0,
|
||||
sampler: str = "Euler a",
|
||||
scheduler: str = "automatic",
|
||||
denoising_strength: float = 0.75,
|
||||
seed: int = -1,
|
||||
model: Optional[str] = None,
|
||||
lora: Optional[str] = None,
|
||||
lora_strength: float = 0.8,
|
||||
) -> Optional[tuple[bytes, dict]]:
|
||||
"""
|
||||
Генерация изображения на основе изображения.
|
||||
Возвращает кортеж (изображение в bytes, info_dict) или None при ошибке.
|
||||
"""
|
||||
init_image_base64 = base64.b64encode(init_image_bytes).decode("utf-8")
|
||||
|
||||
payload = {
|
||||
"init_images": [init_image_base64],
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"steps": steps,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"cfg_scale": cfg_scale,
|
||||
"sampler_name": sampler,
|
||||
"scheduler": scheduler,
|
||||
"denoising_strength": denoising_strength,
|
||||
"seed": seed,
|
||||
"save_images": False,
|
||||
"send_images": True,
|
||||
}
|
||||
|
||||
if model:
|
||||
await self._switch_model(model)
|
||||
|
||||
final_prompt = prompt
|
||||
if lora:
|
||||
final_prompt = f"<lora:{lora}:{lora_strength}> {prompt}"
|
||||
payload["prompt"] = final_prompt
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.api_url}/sdapi/v1/img2img",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=600)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data.get("images"):
|
||||
image_bytes = base64.b64decode(data["images"][0])
|
||||
info = {
|
||||
"prompt": final_prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"sampler": sampler,
|
||||
"scheduler": scheduler,
|
||||
"seed": data.get("parameters", {}).get("seed", seed),
|
||||
"denoising_strength": denoising_strength,
|
||||
"model": model or await self.get_current_model(),
|
||||
"lora": lora,
|
||||
}
|
||||
return image_bytes, info
|
||||
else:
|
||||
logger.error("API не вернул изображение")
|
||||
return None
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Ошибка API img2img: {response.status} - {error_text}")
|
||||
return None
|
||||
except aiohttp.ClientTimeout:
|
||||
logger.error("Таймаут запроса img2img (10 минут)")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка img2img: {e}")
|
||||
return None
|
||||
|
||||
async def _switch_model(self, model_name: str) -> bool:
|
||||
"""Переключить модель."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.api_url}/sdapi/v1/options",
|
||||
json={"sd_model_checkpoint": model_name},
|
||||
timeout=aiohttp.ClientTimeout(total=120)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info(f"Модель переключена на: {model_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Ошибка переключения модели: {response.status}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка переключения модели: {e}")
|
||||
return False
|
||||
|
||||
async def get_progress(self, skip_headers: bool = False) -> Optional[dict]:
|
||||
"""Получить текущий прогресс генерации."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_url}/sdapi/v1/progress",
|
||||
params={"skip_current_image": skip_headers},
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения прогресса: {e}")
|
||||
return None
|
||||
|
||||
async def interrupt(self) -> bool:
|
||||
"""Прервать текущую генерацию."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.api_url}/sdapi/v1/interrupt",
|
||||
json={},
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка прерывания генерации: {e}")
|
||||
return False
|
||||
|
||||
async def get_options(self) -> dict:
|
||||
"""Получить текущие опции API."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_url}/sdapi/v1/options",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения опций: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# Глобальный экземпляр
|
||||
sd_client = SDClient(settings.SD_API_URL)
|
||||
Reference in New Issue
Block a user