Перейти к содержанию

Обучение SDXL

Если вы используете скрипты https://github.com/kohya-ss/sd-scripts напрямую, то, для обучения SDXL, вам необходимо переключиться на ветку "sdxl" и обновить зависимости. Эта операция может привести к проблемам совместимости, так что, желательно, делать отдельную установку для обучения SDXL и использовать отдельную venv-среду. Скрипты для тренировки SDXL имеют в имени файла префикс sdxl_.

Подробнее про обучение SDXL через kohya-ss можно почитать тут: https://github.com/kohya-ss/sd-scripts/tree/sdxl#about-sdxl-training

Для GUI https://github.com/bmaltais/kohya_ss и https://github.com/derrian-distro/ так же вышли обновления, позволяющее делать файнтьюны для SDXL. Кроме полноценного файнтьюна и обучения лор, для bmaltais/kohya_ss так же доступны пресеты для обучения LoRA/LoHa/LoKr, в том числе и для SDXL, требующие больше VRAM.

Требования по VRAM для тренировки SDXL

TL;DR

Обучение полновесных SDXL-чекпоинтов через будку недоступно для обывателя.

Тренировать лоры можно и на 8 гигах. Качественный результат и адекватная скорость в сделку не входят.

Приведённые ниже тесты актуальны на 09.02.2024, тестировалось на kohya-ss/sd-scripts версии 0.8.3.

Тренировка ниже 20 гигов достигается за счет градиент-чекпоинтинга, ниже 12 - треш типа загрузки базовой модели в 8 битах. Насколько сильно всрет - неизвестно.

Gradient checkpointing позволяет легко бустануть батчсайз без сильного роста жора памяти и необходимости в аккумуляции, но снижает скорость тренировки.

Warning

Все значения - только потребление питона, если делается на винде с ускоряемым браузером, ютубчиком, парой мониторов и т.д. - можно гиг-два еще накидывать.

Тип обучения batch size = 1 batch size = 2 batch size = 3 batch size = 4 batch size = 8
Dreambooth 38.5 44.0 - - -
Dreambooth, gradient checkpointing 34.4 35.2 - - 40.5
LoRA dim=32 17.4 - - 47.7 -
LoRA dim=32, full bf16 16.7 - - - -
LoRA dim=32, gradient checkpointing 9.4 13.1 - - -
LoRA dim=32, full bf16, gradient checkpointing 9.1 - - - -
LoRA dim=32, fp8base, gradient checkpointing 6.4 - - - -
LoRA dim=128 20.5 - 41.7 - -
LoRA dim=128, full bf16 18.3 - - 47.8 -
LoRA dim=128, gradient checkpointing 12 - - - -
LoRA dim=128, full bf16, gradient checkpointing 10.3 - - 15.0 -
LoRA dim=128, fp8base, gradient checkpointing 9.9 - - - -
LoCon dim=64+32 18.6 30.4 - - -
LoCon dim=64+32, gradient checkpointing 10.6 - - - -
LoCon dim=64+32, full bf16, gradient checkpointing 9.7 - - - -
LoCon dim=64+32, fp8base, gradient checkpointing 7.2 - - - -