Skip to content

Commit 88ab7f7

Browse files
committed
Optimized code. Improved import of yolo and detectron2 models. Improved work with language models.
1 parent 1612a13 commit 88ab7f7

File tree

9 files changed

+170
-254
lines changed

9 files changed

+170
-254
lines changed

particleanalyzer/core/Detectron2Loader.py

Lines changed: 58 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -9,123 +9,88 @@
99
logger = setup_logger()
1010
logger.setLevel("ERROR")
1111
logging.disable(logging.CRITICAL)
12-
13-
# Скрываем предупреждения PyTorch
1412
warnings.filterwarnings("ignore", category=UserWarning)
15-
"""Работаем с моделями Detectron2"""
1613

14+
"""Работаем с моделями Detectron2"""
1715

1816
class Detectron2Loader:
19-
def __init__(self, device=None):
20-
21-
base_path = os.path.dirname(__file__)
22-
self.model_path = lambda name: os.path.join(base_path, "..", "model", name)
23-
24-
if device is None:
25-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
26-
elif isinstance(device, torch.device):
27-
self.device = device.type # Преобразуем torch.device в строку
28-
else:
29-
self.device = str(device)
30-
# Инициализация конфигураций
31-
self.configs = {
32-
"R101": self._init_r101_config(),
33-
"X101": self._init_x101_config(),
34-
"Cascade_R50": self._init_cascade_r50_config(),
35-
"Cascade_X152": self._init_cascade_x152_config(),
36-
}
37-
38-
self.config_paths = {
39-
"R101": self.model_path("faster_rcnn_R_101_FPN_3x.yaml"),
40-
"X101": self.model_path("faster_rcnn_X_101_32x8d_FPN_3x.yaml"),
41-
"Cascade_R50": self.model_path("cascade_mask_rcnn_R_50_FPN_3x.yaml"),
42-
"Cascade_X152": self.model_path(
43-
"cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"
44-
),
45-
}
46-
47-
self.model_paths = {
48-
"R101": self.model_path("/faster_rcnn_R_101_FPN_3x.pth"),
49-
"X101": self.model_path("faster_rcnn_X_101_32x8d_FPN_3x.pth"),
50-
"Cascade_R50": self.model_path("cascade_mask_rcnn_R_50_FPN_3x.pth"),
51-
"Cascade_X152": self.model_path(
52-
"cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.pth"
53-
),
17+
MODEL_MAPPING = {
18+
"R101": {
19+
"config_file": "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
20+
"weights_file": "faster_rcnn_R_101_FPN_3x.pth",
21+
"config_path": "faster_rcnn_R_101_FPN_3x.yaml"
22+
},
23+
"X101": {
24+
"config_file": "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
25+
"weights_file": "faster_rcnn_X_101_32x8d_FPN_3x.pth",
26+
"config_path": "faster_rcnn_X_101_32x8d_FPN_3x.yaml"
27+
},
28+
"Cascade_R50": {
29+
"config_file": "Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml",
30+
"weights_file": "cascade_mask_rcnn_R_50_FPN_3x.pth",
31+
"config_path": "cascade_mask_rcnn_R_50_FPN_3x.yaml"
32+
},
33+
"Cascade_X152": {
34+
"config_file": "Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml",
35+
"weights_file": "cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.pth",
36+
"config_path": "cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"
5437
}
38+
}
5539

56-
# Сохраняем конфигурации в файлы
57-
self._save_configs()
40+
def __init__(self, device=None):
41+
self._base_path = os.path.join(os.path.dirname(__file__), "..", "model")
42+
self.device = self._get_device(device)
43+
self.configs = {}
44+
self._init_models()
5845

59-
def _init_r101_config(self):
60-
cfg = get_cfg()
61-
cfg.merge_from_file(
62-
model_zoo.get_config_file(
63-
"COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
64-
)
65-
)
66-
cfg.OUTPUT_DIR = self.model_path("")
67-
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "faster_rcnn_R_101_FPN_3x.pth")
68-
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
69-
cfg.MODEL.DEVICE = self.device
70-
return cfg
46+
def _get_device(self, device):
47+
if device is None:
48+
return "cuda" if torch.cuda.is_available() else "cpu"
49+
return device.type if isinstance(device, torch.device) else str(device)
7150

72-
def _init_x101_config(self):
73-
cfg = get_cfg()
74-
cfg.merge_from_file(
75-
model_zoo.get_config_file(
76-
"COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
77-
)
78-
)
79-
cfg.OUTPUT_DIR = self.model_path("")
80-
cfg.MODEL.WEIGHTS = os.path.join(
81-
cfg.OUTPUT_DIR, "faster_rcnn_X_101_32x8d_FPN_3x.pth"
82-
)
83-
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
84-
cfg.MODEL.DEVICE = self.device
85-
return cfg
51+
def _model_path(self, name: str) -> str:
52+
return os.path.join(self._base_path, name)
8653

87-
def _init_cascade_r50_config(self):
54+
def _init_model_config(self, model_name):
8855
cfg = get_cfg()
89-
cfg.merge_from_file(
90-
model_zoo.get_config_file("Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml")
91-
)
92-
cfg.OUTPUT_DIR = self.model_path("")
93-
cfg.MODEL.WEIGHTS = os.path.join(
94-
cfg.OUTPUT_DIR, "cascade_mask_rcnn_R_50_FPN_3x.pth"
95-
)
56+
model_data = self.MODEL_MAPPING[model_name]
57+
58+
cfg.merge_from_file(model_zoo.get_config_file(model_data["config_file"]))
59+
cfg.OUTPUT_DIR = self._base_path
60+
cfg.MODEL.WEIGHTS = self._model_path(model_data["weights_file"])
9661
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
9762
cfg.MODEL.DEVICE = self.device
63+
9864
return cfg
9965

100-
def _init_cascade_x152_config(self):
101-
cfg = get_cfg()
102-
cfg.merge_from_file(
103-
model_zoo.get_config_file(
104-
"Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"
105-
)
106-
)
107-
cfg.OUTPUT_DIR = self.model_path("")
108-
cfg.MODEL.WEIGHTS = os.path.join(
109-
cfg.OUTPUT_DIR, "cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.pth"
110-
)
111-
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
112-
cfg.MODEL.DEVICE = self.device
113-
return cfg
66+
def _init_models(self):
67+
self.configs = {
68+
name: self._init_model_config(name)
69+
for name in self.MODEL_MAPPING
70+
}
71+
72+
self.config_paths = {
73+
name: self._model_path(self.MODEL_MAPPING[name]["config_path"])
74+
for name in self.MODEL_MAPPING
75+
}
76+
77+
self.model_paths = {
78+
name: self._model_path(self.MODEL_MAPPING[name]["weights_file"])
79+
for name in self.MODEL_MAPPING
80+
}
81+
82+
self._save_configs()
11483

11584
def _save_configs(self):
116-
"""Сохраняет конфигурации в файлы"""
11785
for model_name, cfg in self.configs.items():
11886
with open(self.config_paths[model_name], "w") as f:
11987
f.write(cfg.dump())
12088

12189
def get_config(self, model_name: str):
122-
"""Возвращает конфигурацию модели"""
12390
return self.configs.get(model_name)
12491

12592
def get_config_path(self, model_name: str):
126-
"""Возвращает путь к файлу конфигурации"""
12793
return self.config_paths.get(model_name)
12894

12995
def get_model_path(self, model_name: str):
130-
"""Возвращает путь к весам модели"""
131-
return self.model_paths.get(model_name)
96+
return self.model_paths.get(model_name)

particleanalyzer/core/LLMAnalysis.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import json
2-
from typing import Dict, List, Tuple, Literal, Optional
2+
from typing import Dict, List, Tuple, Literal
33
import pandas as pd
44
import numpy as np
55
from openai import OpenAI
66
from huggingface_hub import InferenceClient
77
from particleanalyzer.core.language_context import LanguageContext
8-
from particleanalyzer.core.languages import translations
98

109

1110
class LLMAnalysis:
@@ -17,13 +16,18 @@ def __init__(
1716
self.provider = provider
1817
self.api_key = api_key
1918

20-
if provider == "openrouter":
19+
if self.api_key.startswith("hf_"):
20+
provider == "huggingface"
21+
self.client = InferenceClient(provider="fireworks-ai", api_key=api_key)
22+
self.model_list = ["deepseek-ai/DeepSeek-V3"]
23+
elif self.api_key.startswith("sk-or-"):
2124
self.client = OpenAI(
2225
base_url="https://openrouter.ai/api/v1",
2326
api_key=api_key,
2427
)
25-
elif provider == "huggingface":
26-
self.client = InferenceClient(provider=huggingface_model, api_key=api_key)
28+
self.model_list = ["deepseek/deepseek-chat:free", "deepseek/deepseek-chat-v3-0324", "google/gemini-2.0-flash-001",
29+
"openai/gpt-4o-mini"]
30+
provider == "openrouter"
2731
else:
2832
raise ValueError("Неизвестный провайдер. Доступные варианты: 'openrouter', 'huggingface'")
2933

particleanalyzer/core/ModelManager.py

Lines changed: 43 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
try:
77
from detectron2.engine import DefaultPredictor
88
from .Detectron2Loader import Detectron2Loader
9-
109
DETECTRON2_AVAILABLE = True
1110
except ImportError:
1211
DETECTRON2_AVAILABLE = False
@@ -15,101 +14,36 @@
1514
class ModelManager:
1615
def __init__(self, device=None):
1716
self.device = device
18-
# URL сервера
1917
self.SERVER_URL = "https://rybakov-k.ru/model/"
20-
21-
# Путь к моделям
18+
19+
# Инициализация путей
2220
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
2321
self.MODELS_DIR = os.path.join(base_path, "model")
2422
os.makedirs(self.MODELS_DIR, exist_ok=True)
2523

26-
# Проверяем и загружаем модели
27-
self._ensure_models_available()
28-
29-
# Инициализируем YOLO
30-
self.yolo_loader = YOLOLoader()
31-
32-
# Инициализируем detectron_loader только если detectron2 доступен
24+
yolo_files = list(YOLOLoader.MODEL_MAPPING.values())
25+
detectron_files = []
3326
if DETECTRON2_AVAILABLE:
34-
self.detectron_loader = Detectron2Loader(device=self.device)
35-
else:
36-
self.detectron_loader = None
37-
38-
self.model_types = {
39-
"Yolo11 (dataset 1)": "yolo",
40-
"Yolo12 (dataset 1)": "yolo",
41-
"Yolo11 (dataset 2)": "yolo",
42-
"Yolo12 (dataset 2)": "yolo",
43-
"R101": "detectron",
44-
"X101": "detectron",
45-
"Cascade_R50": "detectron",
46-
"Cascade_X152": "detectron",
47-
}
48-
49-
def get_model(self, model_name: str):
50-
"""Возвращает модель по имени"""
51-
model_type = self.model_types.get(model_name)
52-
53-
if model_type == "yolo":
54-
return self.yolo_loader.get_model(model_name)
55-
elif model_type == "detectron":
56-
return self.detectron_loader.get_config(model_name)
57-
else:
58-
raise ValueError(f"Unknown model type: {model_name}")
27+
detectron_files = [
28+
*[v["weights_file"] for v in Detectron2Loader.MODEL_MAPPING.values()]
29+
]
5930

60-
def get_predictor(self, model_name: str):
61-
"""Для Detectron возвращает готовый predictor"""
62-
if model_name in self.detectron_loader.configs:
63-
cfg = self.detectron_loader.get_config(model_name)
64-
return DefaultPredictor(cfg)
65-
return None
66-
67-
def get_model_path(self, model_name: str) -> str:
68-
"""Возвращает путь к модели по её имени"""
69-
if model_name in self.yolo_loader.models:
70-
return self.yolo_loader.get_model_path(model_name)
71-
elif model_name in self.detectron_loader.configs:
72-
return self.detectron_loader.get_model_path(model_name)
73-
else:
74-
raise ValueError(f"Model {model_name} not found")
75-
76-
def get_config_path(self, model_name: str) -> str:
77-
"""Возвращает путь к конфигу (только для Detectron)"""
78-
if model_name in self.detectron_loader.configs:
79-
return self.detectron_loader.get_config_path(model_name)
80-
raise ValueError(
81-
f"Config for {model_name} not available (YOLO models don't use config files)"
82-
)
83-
84-
def _ensure_models_available(self):
85-
"""Проверяет и загружает только необходимые модели"""
86-
# Базовые файлы для YOLO (всегда нужны)
87-
required_files = [
88-
"Yolo11_d1.pt",
89-
"Yolo11_d2.pt",
90-
"Yolo12_d1.pt",
91-
"Yolo12_d2.pt",
92-
]
31+
# Проверяем и загружаем модели
32+
self._ensure_models_available(yolo_files + detectron_files)
9333

94-
# Добавляем модели Detectron2 только если доступен
95-
if DETECTRON2_AVAILABLE:
96-
required_files.extend(
97-
[
98-
"faster_rcnn_R_101_FPN_3x.pth",
99-
"faster_rcnn_X_101_32x8d_FPN_3x.pth",
100-
"cascade_mask_rcnn_R_50_FPN_3x.pth",
101-
"cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.pth",
102-
]
103-
)
34+
# Инициализация загрузчиков
35+
self.yolo_loader = YOLOLoader()
36+
self.detectron_loader = Detectron2Loader(device=self.device) if DETECTRON2_AVAILABLE else None
10437

38+
def _ensure_models_available(self, required_files):
39+
"""Проверяет и загружает необходимые файлы моделей"""
10540
for filename in required_files:
10641
file_path = os.path.join(self.MODELS_DIR, filename)
10742
if not os.path.exists(file_path):
10843
self._download_file(filename)
10944

11045
def _download_file(self, filename):
11146
"""Скачивает файл с сервера"""
112-
# Если это модель Detectron2 и библиотека недоступна - пропускаем
11347
if filename.endswith((".pth", ".yaml")) and not DETECTRON2_AVAILABLE:
11448
print(f"Skipping {filename} (Detectron2 not available)")
11549
return False
@@ -138,3 +72,32 @@ def _download_file(self, filename):
13872
if os.path.exists(save_path):
13973
os.remove(save_path)
14074
return False
75+
76+
def get_model(self, model_name: str):
77+
"""Возвращает модель по имени"""
78+
if model_name in self.yolo_loader.MODEL_MAPPING:
79+
return self.yolo_loader.get_model(model_name)
80+
elif DETECTRON2_AVAILABLE and model_name in self.detectron_loader.MODEL_CONFIGS:
81+
return self.detectron_loader.get_config(model_name)
82+
raise ValueError(f"Unknown model: {model_name}")
83+
84+
def get_predictor(self, model_name: str):
85+
"""Для Detectron возвращает готовый predictor"""
86+
if DETECTRON2_AVAILABLE and model_name in self.detectron_loader.MODEL_CONFIGS:
87+
cfg = self.detectron_loader.get_config(model_name)
88+
return DefaultPredictor(cfg)
89+
return None
90+
91+
def get_model_path(self, model_name: str) -> str:
92+
"""Возвращает путь к модели по её имени"""
93+
if model_name in self.yolo_loader.MODEL_MAPPING:
94+
return self.yolo_loader.get_model_path(model_name)
95+
elif DETECTRON2_AVAILABLE and model_name in self.detectron_loader.MODEL_CONFIGS:
96+
return self.detectron_loader.get_model_path(model_name)
97+
raise ValueError(f"Model {model_name} not found")
98+
99+
def get_config_path(self, model_name: str) -> str:
100+
"""Возвращает путь к конфигу (только для Detectron)"""
101+
if DETECTRON2_AVAILABLE and model_name in self.detectron_loader.MODEL_CONFIGS:
102+
return self.detectron_loader.get_config_path(model_name)
103+
raise ValueError(f"Config for {model_name} not available")

0 commit comments

Comments
 (0)