TGTGInsighttelegram intelligenceLIVE / telegram public index
← AI[ex]Time
AI[ex]Time avatar

TGINSIGHT POST

Post #94

@AIexTime

AI[ex]Time

Views2,530Post view count
PostedOct 710/07/2024, 05:52 PM
Post content

Post content

Kaggle соревнование lmsys chatbot arena, часть 2. Технические подходы. В продолжение разбора соревнования обещал написать вторую часть с техническим обзором. Время пришло. Задачу можно было решать двумя вариантами: добавить голову и учить модель на задачу классификации или же оставить предсказание следующего токена и напрямую предсказывать токен-метку. Разницы особой нет, но во втором случае можно использовать много разных фреймворков с оптимизациями обучения и инференса по типу unsloth. Бейзлайн выглядит так: 1. Берем llama3-8B или gemma2-9B 2. Учим лору, вставляя адаптеры во все линейные слои 3. Инференсим квантизованную модель в int4/8 без мерджа весов адаптеров Улучшить решение можно было несколькими способами: 1. Pseudo-labeling. берем какой-нибудь lmsys-1M-dataset, составляем пары ответов на один промпт и размечаем llama3.1_405B. Были попытки и с нуля генерировать синтетические данные, но докидывали они значительно меньше, все-таки распределение данных в таком случае сильно отличается от целевого. 2. External Datasets. Просто докидываем больше данных в post pre-train. Важно, что не в финальный fine-tune, тк на последнем шаге лучше использовать только данные из соревнования. Много интересных датасетов можно найти в RLHFlow. Авторы так же в свое время писали неплохую статью про RLHF. 3. Ensembling. Пришлось пробовать много разных моделей: MistralNemo, Llama3/3.1, Phi, Yi, Qwen, Gemma и тд. Лучше всего заработала gemma2-it, причем с большим отрывом по сравнению с другими моделями. На втором месте Llama3 (интересно, что 3.1 не докидывала). Удивительно, но модели от Mistral вообще не могли справиться с задачей. Если добавить всякие оптимизации во время инференса (dynamic batch size, dataset length sorting), где-то пожертвовать длиной контекста, то можно было уместить на 2xT4 инференс gemma + llama за 9 часов. Gemma работала значительно дольше, в частности, из-за огромного словаря. 4. Inference tricks. Всякие мелкие, но важные детали. Например, если мы используем ансамбль, то в одну модель лучше отправлять question-responseA-responseB, а в другую ответы поменять местами, чтобы добавить больше разнообразия. Важно также выставить truncation left side, чтобы жертвовать токенами из начала — они меньше влияет на предикт модели. Кто-то лез совсем в детали и выключал logit soft-capping в gemma, писали, что докидывает пару тысячных на лб — типичный кегл 😋 Кстати, если я не ошибаюсь, это первое соревнование, в котором завели инференс 33B моделей: vllm + квантизация AWQ + Tensor Parallel. 5. И напоследок прием, который зарешал больше всех — Distillation. Парень с таким подходом и взял первое место. Логика следующая: 1. Бьем весь трейн на 5 фолдов. 2. Тренируем на фолдах Llama3-70B и Qwen2-72B и размечаем весь датасет их предиктами. 3. Опять же на фолдах дистиллируем предикты больших моделей в gemma2, используя самый простой KL loss. Учим только LoRA адаптеры и в итоге получаем 5 моделей. 4. Усредняем веса всех адаптеров и получаем с помощью такого model merging финальную модель. 5. На все про все — А100 80G * 8 + ZeRO2 Часть 1 про лик в соревновании