99logger = setup_logger ()
1010logger .setLevel ("ERROR" )
1111logging .disable (logging .CRITICAL )
12-
13- # Скрываем предупреждения PyTorch
1412warnings .filterwarnings ("ignore" , category = UserWarning )
15- """Работаем с моделями Detectron2"""
1613
14+ """Работаем с моделями Detectron2"""
1715
1816class Detectron2Loader :
19- def __init__ (self , device = None ):
20-
21- base_path = os .path .dirname (__file__ )
22- self .model_path = lambda name : os .path .join (base_path , ".." , "model" , name )
23-
24- if device is None :
25- self .device = "cuda" if torch .cuda .is_available () else "cpu"
26- elif isinstance (device , torch .device ):
27- self .device = device .type # Преобразуем torch.device в строку
28- else :
29- self .device = str (device )
30- # Инициализация конфигураций
31- self .configs = {
32- "R101" : self ._init_r101_config (),
33- "X101" : self ._init_x101_config (),
34- "Cascade_R50" : self ._init_cascade_r50_config (),
35- "Cascade_X152" : self ._init_cascade_x152_config (),
36- }
37-
38- self .config_paths = {
39- "R101" : self .model_path ("faster_rcnn_R_101_FPN_3x.yaml" ),
40- "X101" : self .model_path ("faster_rcnn_X_101_32x8d_FPN_3x.yaml" ),
41- "Cascade_R50" : self .model_path ("cascade_mask_rcnn_R_50_FPN_3x.yaml" ),
42- "Cascade_X152" : self .model_path (
43- "cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"
44- ),
45- }
46-
47- self .model_paths = {
48- "R101" : self .model_path ("/faster_rcnn_R_101_FPN_3x.pth" ),
49- "X101" : self .model_path ("faster_rcnn_X_101_32x8d_FPN_3x.pth" ),
50- "Cascade_R50" : self .model_path ("cascade_mask_rcnn_R_50_FPN_3x.pth" ),
51- "Cascade_X152" : self .model_path (
52- "cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.pth"
53- ),
17+ MODEL_MAPPING = {
18+ "R101" : {
19+ "config_file" : "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml" ,
20+ "weights_file" : "faster_rcnn_R_101_FPN_3x.pth" ,
21+ "config_path" : "faster_rcnn_R_101_FPN_3x.yaml"
22+ },
23+ "X101" : {
24+ "config_file" : "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml" ,
25+ "weights_file" : "faster_rcnn_X_101_32x8d_FPN_3x.pth" ,
26+ "config_path" : "faster_rcnn_X_101_32x8d_FPN_3x.yaml"
27+ },
28+ "Cascade_R50" : {
29+ "config_file" : "Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml" ,
30+ "weights_file" : "cascade_mask_rcnn_R_50_FPN_3x.pth" ,
31+ "config_path" : "cascade_mask_rcnn_R_50_FPN_3x.yaml"
32+ },
33+ "Cascade_X152" : {
34+ "config_file" : "Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml" ,
35+ "weights_file" : "cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.pth" ,
36+ "config_path" : "cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"
5437 }
38+ }
5539
56- # Сохраняем конфигурации в файлы
57- self ._save_configs ()
40+ def __init__ (self , device = None ):
41+ self ._base_path = os .path .join (os .path .dirname (__file__ ), ".." , "model" )
42+ self .device = self ._get_device (device )
43+ self .configs = {}
44+ self ._init_models ()
5845
59- def _init_r101_config (self ):
60- cfg = get_cfg ()
61- cfg .merge_from_file (
62- model_zoo .get_config_file (
63- "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
64- )
65- )
66- cfg .OUTPUT_DIR = self .model_path ("" )
67- cfg .MODEL .WEIGHTS = os .path .join (cfg .OUTPUT_DIR , "faster_rcnn_R_101_FPN_3x.pth" )
68- cfg .MODEL .ROI_HEADS .NUM_CLASSES = 1
69- cfg .MODEL .DEVICE = self .device
70- return cfg
46+ def _get_device (self , device ):
47+ if device is None :
48+ return "cuda" if torch .cuda .is_available () else "cpu"
49+ return device .type if isinstance (device , torch .device ) else str (device )
7150
72- def _init_x101_config (self ):
73- cfg = get_cfg ()
74- cfg .merge_from_file (
75- model_zoo .get_config_file (
76- "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
77- )
78- )
79- cfg .OUTPUT_DIR = self .model_path ("" )
80- cfg .MODEL .WEIGHTS = os .path .join (
81- cfg .OUTPUT_DIR , "faster_rcnn_X_101_32x8d_FPN_3x.pth"
82- )
83- cfg .MODEL .ROI_HEADS .NUM_CLASSES = 1
84- cfg .MODEL .DEVICE = self .device
85- return cfg
51+ def _model_path (self , name : str ) -> str :
52+ return os .path .join (self ._base_path , name )
8653
87- def _init_cascade_r50_config (self ):
54+ def _init_model_config (self , model_name ):
8855 cfg = get_cfg ()
89- cfg .merge_from_file (
90- model_zoo .get_config_file ("Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml" )
91- )
92- cfg .OUTPUT_DIR = self .model_path ("" )
93- cfg .MODEL .WEIGHTS = os .path .join (
94- cfg .OUTPUT_DIR , "cascade_mask_rcnn_R_50_FPN_3x.pth"
95- )
56+ model_data = self .MODEL_MAPPING [model_name ]
57+
58+ cfg .merge_from_file (model_zoo .get_config_file (model_data ["config_file" ]))
59+ cfg .OUTPUT_DIR = self ._base_path
60+ cfg .MODEL .WEIGHTS = self ._model_path (model_data ["weights_file" ])
9661 cfg .MODEL .ROI_HEADS .NUM_CLASSES = 1
9762 cfg .MODEL .DEVICE = self .device
63+
9864 return cfg
9965
100- def _init_cascade_x152_config (self ):
101- cfg = get_cfg ()
102- cfg .merge_from_file (
103- model_zoo .get_config_file (
104- "Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"
105- )
106- )
107- cfg .OUTPUT_DIR = self .model_path ("" )
108- cfg .MODEL .WEIGHTS = os .path .join (
109- cfg .OUTPUT_DIR , "cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.pth"
110- )
111- cfg .MODEL .ROI_HEADS .NUM_CLASSES = 1
112- cfg .MODEL .DEVICE = self .device
113- return cfg
66+ def _init_models (self ):
67+ self .configs = {
68+ name : self ._init_model_config (name )
69+ for name in self .MODEL_MAPPING
70+ }
71+
72+ self .config_paths = {
73+ name : self ._model_path (self .MODEL_MAPPING [name ]["config_path" ])
74+ for name in self .MODEL_MAPPING
75+ }
76+
77+ self .model_paths = {
78+ name : self ._model_path (self .MODEL_MAPPING [name ]["weights_file" ])
79+ for name in self .MODEL_MAPPING
80+ }
81+
82+ self ._save_configs ()
11483
11584 def _save_configs (self ):
116- """Сохраняет конфигурации в файлы"""
11785 for model_name , cfg in self .configs .items ():
11886 with open (self .config_paths [model_name ], "w" ) as f :
11987 f .write (cfg .dump ())
12088
12189 def get_config (self , model_name : str ):
122- """Возвращает конфигурацию модели"""
12390 return self .configs .get (model_name )
12491
12592 def get_config_path (self , model_name : str ):
126- """Возвращает путь к файлу конфигурации"""
12793 return self .config_paths .get (model_name )
12894
12995 def get_model_path (self , model_name : str ):
130- """Возвращает путь к весам модели"""
131- return self .model_paths .get (model_name )
96+ return self .model_paths .get (model_name )
0 commit comments