1- import os
21import copy
3- import time
4- import shutil
5- import pathlib
2+ import importlib .resources as pkg_resources
63import logging
4+ import os
5+ import pathlib
6+ import shutil
7+ import ssl
78import tempfile
9+ import time
10+ import urllib .request
811import zipfile
9- import importlib .resources as pkg_resources
1012from threading import Lock
11- import urllib .request
1213
1314import yaml
14- from PyQt6 . QtCore import QObject , QThread , pyqtSignal , pyqtSlot
15- from PyQt6 .QtCore import QCoreApplication
15+ from huggingface_hub import snapshot_download
16+ from PyQt6 .QtCore import QCoreApplication , QObject , QThread , pyqtSignal , pyqtSlot
1617
18+ from anylabeling .config import get_config , save_config
1719from anylabeling .configs import auto_labeling as auto_labeling_configs
1820from anylabeling .services .auto_labeling .types import AutoLabelingResult
1921from anylabeling .utils import GenericWorker
2022
21- from anylabeling .config import get_config , save_config
2223from .registry import ModelRegistry
2324
24- import ssl
25- from huggingface_hub import snapshot_download
26-
2725ssl ._create_default_https_context = (
2826 ssl ._create_unverified_context
2927) # Prevent issue when downloading models behind a proxy
@@ -110,7 +108,7 @@ def load_model_configs(self):
110108 model_config = copy .deepcopy (model )
111109 config_file = model .get ("config_file" , None )
112110 if config_file :
113- with open (config_file , "r" ) as f :
111+ with open (config_file ) as f :
114112 model_config = yaml .safe_load (f )
115113 model_config ["config_file" ] = os .path .normpath (
116114 os .path .abspath (config_file )
@@ -183,7 +181,7 @@ def load_custom_model(self, config_file):
183181
184182 # Check config file content
185183 model_config = {}
186- with open (config_file , "r" ) as f :
184+ with open (config_file ) as f :
187185 model_config = yaml .safe_load (f )
188186 model_config ["config_file" ] = os .path .abspath (config_file )
189187 if not model_config :
@@ -324,7 +322,7 @@ def _progress(count, block_size, total_size):
324322 tmp_extract_dir = os .path .join (tmp_dir , "extract" )
325323 with zipfile .ZipFile (zip_model_path , "r" ) as zip_ref :
326324 zip_ref .extractall (tmp_extract_dir )
327- # Find model folder (containing config.yaml)
325+ # Find model folder (containing config.yaml)
328326 model_folder = None
329327 for root , _ , files in os .walk (tmp_extract_dir ):
330328 if "config.yaml" in files :
@@ -333,16 +331,16 @@ def _progress(count, block_size, total_size):
333331 if model_folder is None :
334332 raise ValueError (self .tr ("Could not find config.yaml in zip file." ))
335333 return model_folder
336-
334+
337335 def download_hf (self , tmp_dir , download_url , model_config ):
338- repo_id = download_url .replace (' https://huggingface.co/' , '' ).strip ('/' )
336+ repo_id = download_url .replace (" https://huggingface.co/" , "" ).strip ("/" )
339337 # Only take the first two segments: namespace/repo_name
340- repo_id = "/" .join (repo_id .split ('/' )[:2 ])
341-
338+ repo_id = "/" .join (repo_id .split ("/" )[:2 ])
339+
342340 tmp_extract_dir = os .path .join (tmp_dir , "extract" )
343- local_dir = snapshot_download (
341+ snapshot_download (
344342 repo_id = repo_id ,
345- local_dir = tmp_extract_dir # where to store everything
343+ local_dir = tmp_extract_dir , # where to store everything
346344 )
347345 with open (tmp_extract_dir + "/config.yaml" , "w" ) as f :
348346 yaml .dump (model_config , f , default_flow_style = False )
@@ -355,7 +353,7 @@ def _download_and_extract_model(self, model_config):
355353 # Check if model is already downloaded
356354 if not os .path .exists (config_file ):
357355 raise ValueError (self .tr ("Error in loading config file." ))
358- with open (config_file , "r" ) as f :
356+ with open (config_file ) as f :
359357 model_config = yaml .safe_load (f )
360358 if model_config .get ("has_downloaded" , False ):
361359 return
@@ -364,19 +362,18 @@ def _download_and_extract_model(self, model_config):
364362 download_url = model_config .get ("download_url" , None )
365363 if not download_url :
366364 raise ValueError (self .tr ("Missing download_url in config file." ))
367-
365+
368366 tmp_dir = tempfile .mkdtemp ()
369367 model_folder = None
370- if download_url .endswith (' .zip' ):
368+ if download_url .endswith (" .zip" ):
371369 model_folder = self .download_zip (tmp_dir , download_url )
372- elif download_url .startswith (' https://huggingface.co' ):
370+ elif download_url .startswith (" https://huggingface.co" ):
373371 model_folder = self .download_hf (tmp_dir , download_url , model_config )
374372
375373 if model_folder is None :
376374 shutil .rmtree (tmp_dir )
377375 raise ValueError (self .tr ("Could not download model." ))
378376
379-
380377 # Move model folder to correct location
381378 shutil .rmtree (extract_dir )
382379 shutil .move (model_folder , extract_dir )
@@ -385,7 +382,7 @@ def _download_and_extract_model(self, model_config):
385382 shutil .rmtree (tmp_dir )
386383
387384 # Update config file
388- with open (config_file , "r" ) as f :
385+ with open (config_file ) as f :
389386 model_config = yaml .safe_load (f )
390387 model_config ["has_downloaded" ] = True
391388 model_config ["config_file" ] = config_file
@@ -422,7 +419,7 @@ def _load_model(self, model_id):
422419 model_config ["model" ] = model_class (
423420 model_config , on_message = self .new_model_status .emit
424421 )
425-
422+
426423 # Specific logic for interactive models (like SAM) vs detection models
427424 # Ideally this should be a property of the model class (capabilities)
428425 if model_type == "segment_anything" :
@@ -433,18 +430,8 @@ def _load_model(self, model_id):
433430 self .auto_segmentation_model_unselected .emit ()
434431
435432 except Exception as e : # noqa
436- self .new_model_status .emit (
437- self .tr (
438- "Error in loading model: {error_message}" .format (
439- error_message = str (e )
440- )
441- )
442- )
443- print (
444- "Error in loading model: {error_message}" .format (
445- error_message = str (e )
446- )
447- )
433+ self .new_model_status .emit (self .tr (f"Error in loading model: { str (e )} " ))
434+ print (f"Error in loading model: { str (e )} " )
448435 return
449436
450437 self .loaded_model_config = model_config
@@ -530,7 +517,7 @@ def predict_shapes_threading(self, image, filename=None):
530517 ):
531518 if hasattr (self .loaded_model_config ["model" ], "unload" ):
532519 self .loaded_model_config ["model" ].unload ()
533-
520+
534521 # Wait for the thread to finish
535522 self .model_execution_thread .quit ()
536523 if not self .model_execution_thread .wait (1000 ):
0 commit comments