Нужно написать код для прунинга yolov9 используя torch_pruning

Задание закрыто
Стоимость:15 000 рублей
Срок выполнения:3 дня
Варианты оплаты:Планируется использовать Безопасную сделку
Дата публикации:2024-11-11 07:29
Был(а) на сайте:2024-11-12 15:41

Нужно написать код для прунинга yolov9 используя torch_pruning

 
Задача

Нужно написать код для прунинга yolov9 от ultralyticsиспользуя torch_pruning. Нужно ускорить модель в 2 раза с помощью прунинга, с потерями не более 10%. Так же важно чтоб код работал с моделями E, C, M. Версия ultralytics==8.3.18, у torch_pruning можно выбрать любую версию

Оставлять заявки могут только авторизованные пользователи.

Общие комментарии:

чатгпт не поможет, рассчитывайте только на свои силы

2024-11-11 07:46

Смогу выполнить работу сегодня до 20:00 .
Stanislav Ricci
Специализация: Программирование и IT
  • 15 000 руб3 дня
import torch
from ultralytics import YOLO
import torch_pruning as tp

def prune_yolov8_model(model_name='yolov8n.pt', data='coco128.yaml', pruning_rate=0.5, epochs=10):
# Загружаем модель
model = YOLO(model_name)
net = model.model

# Создаем пример входных данных
example_inputs = torch.randn(1, 3, 640, 640)

# Строим граф зависимостей
DG = tp.DependencyGraph().build_dependency(net, example_inputs)

# Определяем стратегию прунинга
strategy = tp.strategy.L1Strategy() # Можно использовать другие стратегии

# Собираем все слои Conv2d для прунинга
prunable_modules = []
for m in net.modules():
if isinstance(m, torch.nn.Conv2d):
prunable_modules.append(m)

# Выполняем прунинг
for m in prunable_modules:
# Пропускаем глубинные свертки
if m.groups == m.in_channels and m.in_channels == m.out_channels:
continue

# Получаем индексы каналов для прунинга
pruning_index = strategy(m.weight, amount=pruning_rate)
# Строим план прунинга
plan = DG.get_pruning_plan(m, tp.prune_conv, pruning_index)
# Выполняем план
plan.exec()

# Тонкая настройка (fine-tuning) прунинговой модели
model.train(data=data, epochs=epochs)

# Оцениваем производительность модели
metrics = model.val()

return model, metrics

if __name__ == '__main__':
# Применяем прунинг и оцениваем модель
pruned_model, metrics = prune_yolov8_model(
model_name='yolov8m.pt', # Можно заменить на 'yolov8e.pt', 'yolov8c.pt', 'yolov8m.pt'
data='coco128.yaml',
pruning_rate=0.5, # Прунинг 50% параметров для ускорения в 2 раза
epochs=10 # Количество эпох для тонкой настройки
)
print(metrics)

Примеры моих работ


Оставлять заявки могут только авторизованные пользователи.