Skip to content

Commit 9486f46

Browse files
committed
refactor: format code with ruff
1 parent 9735fe8 commit 9486f46

40 files changed

Lines changed: 338 additions & 228 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ repos:
2121
- id: trailing-whitespace
2222
files: \.(py|sh|rst|yml|yaml)$
2323
- repo: https://github.com/astral-sh/ruff-pre-commit
24-
rev: v0.11.7
24+
rev: v0.15.2
2525
hooks:
2626
- id: ruff
2727
args: [ --fix ]

anylabeling/app.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
import yaml
1616
from PyQt6 import QtCore, QtWidgets
1717

18+
from anylabeling import config as anylabeling_config
1819
from anylabeling.app_info import __appname__
1920
from anylabeling.config import get_config
20-
from anylabeling import config as anylabeling_config
21-
from anylabeling.views.mainwindow import MainWindow
22-
from anylabeling.views.labeling.logger import logger
23-
from anylabeling.views.labeling.utils import new_icon
2421
from anylabeling.resources import resources
2522
from anylabeling.styles import AppTheme
23+
from anylabeling.views.labeling.logger import logger
24+
from anylabeling.views.labeling.utils import new_icon
25+
from anylabeling.views.mainwindow import MainWindow
2626

2727
__all__ = ["resources"]
2828

@@ -175,7 +175,9 @@ def main():
175175

176176
# Enable scaling for high dpi screens
177177
# High DPI scaling is enabled by default in Qt 6
178-
QtCore.QCoreApplication.setAttribute(QtCore.Qt.ApplicationAttribute.AA_ShareOpenGLContexts)
178+
QtCore.QCoreApplication.setAttribute(
179+
QtCore.Qt.ApplicationAttribute.AA_ShareOpenGLContexts
180+
)
179181

180182
app = QtWidgets.QApplication(sys.argv)
181183
app.processEvents()

anylabeling/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from .views.labeling.logger import logger
1414

15-
1615
# Save current config file
1716
current_config_file = None
1817

anylabeling/resources/resources.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# -*- coding: utf-8 -*-
2-
31
# Resource object code
42
#
53
# Created by: The Resource Compiler for PyQt6 (Qt v5.15.2)
@@ -27893,18 +27891,25 @@
2789327891
\x00\x00\x01\x96\x97\x00\x2b\x5f\
2789427892
"
2789527893

27896-
qt_version = [int(v) for v in QtCore.qVersion().split('.')]
27894+
qt_version = [int(v) for v in QtCore.qVersion().split(".")]
2789727895
if qt_version < [5, 8, 0]:
2789827896
rcc_version = 1
2789927897
qt_resource_struct = qt_resource_struct_v1
2790027898
else:
2790127899
rcc_version = 2
2790227900
qt_resource_struct = qt_resource_struct_v2
2790327901

27902+
2790427903
def qInitResources():
27905-
QtCore.qRegisterResourceData(rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data)
27904+
QtCore.qRegisterResourceData(
27905+
rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data
27906+
)
27907+
2790627908

2790727909
def qCleanupResources():
27908-
QtCore.qUnregisterResourceData(rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data)
27910+
QtCore.qUnregisterResourceData(
27911+
rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data
27912+
)
27913+
2790927914

2791027915
qInitResources()
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from .model import Model
2-
from .registry import ModelRegistry
3-
4-
# Import models to ensure they register themselves
5-
from . import yolov5
6-
from . import yolov8
7-
from . import segment_anything
1+
# Import models to ensure they register themselves via @ModelRegistry.register
2+
from . import segment_anything as segment_anything # noqa: F401
3+
from . import yolov5 as yolov5 # noqa: F401
4+
from . import yolov8 as yolov8 # noqa: F401
5+
from .model import Model as Model
6+
from .registry import ModelRegistry as ModelRegistry

anylabeling/services/auto_labeling/lru_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Thread-safe LRU cache implementation."""
22

3-
from collections import OrderedDict
43
import threading
4+
from collections import OrderedDict
55

66

77
class LRUCache:

anylabeling/services/auto_labeling/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import logging
22
import os
3-
import yaml
43
import socket
54
import ssl
65
from abc import abstractmethod
76

7+
import yaml
88
from PyQt6.QtCore import QCoreApplication, QFile, QObject
99
from PyQt6.QtGui import QImage
1010

11-
from .types import AutoLabelingResult
1211
from anylabeling.views.labeling.label_file import LabelFile, LabelFileError
1312

13+
from .types import AutoLabelingResult
14+
1415
# Prevent issue when downloading models behind a proxy
1516
os.environ["no_proxy"] = "*"
1617

@@ -44,7 +45,7 @@ def __init__(self, model_config, on_message) -> None:
4445
"Model", "Config file not found: {model_config}"
4546
).format(model_config=model_config)
4647
)
47-
with open(model_config, "r") as f:
48+
with open(model_config) as f:
4849
self.config = yaml.safe_load(f)
4950
elif isinstance(model_config, dict):
5051
self.config = model_config
@@ -120,14 +121,14 @@ def load_image_from_filename(filename):
120121
try:
121122
label_file = LabelFile(label_file)
122123
except LabelFileError as e:
123-
logging.error("Error reading {}: {}".format(label_file, e))
124+
logging.error(f"Error reading {label_file}: {e}")
124125
return None, None
125126
image_data = label_file.image_data
126127
else:
127128
image_data = LabelFile.load_image_file(filename)
128129
image = QImage.fromData(image_data)
129130
if image.isNull():
130-
logging.error("Error reading {}".format(filename))
131+
logging.error(f"Error reading {filename}")
131132
return image
132133

133134
def on_next_files_changed(self, next_files):

anylabeling/services/auto_labeling/model_manager.py

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,27 @@
1-
import os
21
import copy
3-
import time
4-
import shutil
5-
import pathlib
2+
import importlib.resources as pkg_resources
63
import logging
4+
import os
5+
import pathlib
6+
import shutil
7+
import ssl
78
import tempfile
9+
import time
10+
import urllib.request
811
import zipfile
9-
import importlib.resources as pkg_resources
1012
from threading import Lock
11-
import urllib.request
1213

1314
import yaml
14-
from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot
15-
from PyQt6.QtCore import QCoreApplication
15+
from huggingface_hub import snapshot_download
16+
from PyQt6.QtCore import QCoreApplication, QObject, QThread, pyqtSignal, pyqtSlot
1617

18+
from anylabeling.config import get_config, save_config
1719
from anylabeling.configs import auto_labeling as auto_labeling_configs
1820
from anylabeling.services.auto_labeling.types import AutoLabelingResult
1921
from anylabeling.utils import GenericWorker
2022

21-
from anylabeling.config import get_config, save_config
2223
from .registry import ModelRegistry
2324

24-
import ssl
25-
from huggingface_hub import snapshot_download
26-
2725
ssl._create_default_https_context = (
2826
ssl._create_unverified_context
2927
) # Prevent issue when downloading models behind a proxy
@@ -110,7 +108,7 @@ def load_model_configs(self):
110108
model_config = copy.deepcopy(model)
111109
config_file = model.get("config_file", None)
112110
if config_file:
113-
with open(config_file, "r") as f:
111+
with open(config_file) as f:
114112
model_config = yaml.safe_load(f)
115113
model_config["config_file"] = os.path.normpath(
116114
os.path.abspath(config_file)
@@ -183,7 +181,7 @@ def load_custom_model(self, config_file):
183181

184182
# Check config file content
185183
model_config = {}
186-
with open(config_file, "r") as f:
184+
with open(config_file) as f:
187185
model_config = yaml.safe_load(f)
188186
model_config["config_file"] = os.path.abspath(config_file)
189187
if not model_config:
@@ -324,7 +322,7 @@ def _progress(count, block_size, total_size):
324322
tmp_extract_dir = os.path.join(tmp_dir, "extract")
325323
with zipfile.ZipFile(zip_model_path, "r") as zip_ref:
326324
zip_ref.extractall(tmp_extract_dir)
327-
# Find model folder (containing config.yaml)
325+
# Find model folder (containing config.yaml)
328326
model_folder = None
329327
for root, _, files in os.walk(tmp_extract_dir):
330328
if "config.yaml" in files:
@@ -333,16 +331,16 @@ def _progress(count, block_size, total_size):
333331
if model_folder is None:
334332
raise ValueError(self.tr("Could not find config.yaml in zip file."))
335333
return model_folder
336-
334+
337335
def download_hf(self, tmp_dir, download_url, model_config):
338-
repo_id = download_url.replace('https://huggingface.co/', '').strip('/')
336+
repo_id = download_url.replace("https://huggingface.co/", "").strip("/")
339337
# Only take the first two segments: namespace/repo_name
340-
repo_id = "/".join(repo_id.split('/')[:2])
341-
338+
repo_id = "/".join(repo_id.split("/")[:2])
339+
342340
tmp_extract_dir = os.path.join(tmp_dir, "extract")
343-
local_dir = snapshot_download(
341+
snapshot_download(
344342
repo_id=repo_id,
345-
local_dir=tmp_extract_dir # where to store everything
343+
local_dir=tmp_extract_dir, # where to store everything
346344
)
347345
with open(tmp_extract_dir + "/config.yaml", "w") as f:
348346
yaml.dump(model_config, f, default_flow_style=False)
@@ -355,7 +353,7 @@ def _download_and_extract_model(self, model_config):
355353
# Check if model is already downloaded
356354
if not os.path.exists(config_file):
357355
raise ValueError(self.tr("Error in loading config file."))
358-
with open(config_file, "r") as f:
356+
with open(config_file) as f:
359357
model_config = yaml.safe_load(f)
360358
if model_config.get("has_downloaded", False):
361359
return
@@ -364,19 +362,18 @@ def _download_and_extract_model(self, model_config):
364362
download_url = model_config.get("download_url", None)
365363
if not download_url:
366364
raise ValueError(self.tr("Missing download_url in config file."))
367-
365+
368366
tmp_dir = tempfile.mkdtemp()
369367
model_folder = None
370-
if download_url.endswith('.zip'):
368+
if download_url.endswith(".zip"):
371369
model_folder = self.download_zip(tmp_dir, download_url)
372-
elif download_url.startswith('https://huggingface.co'):
370+
elif download_url.startswith("https://huggingface.co"):
373371
model_folder = self.download_hf(tmp_dir, download_url, model_config)
374372

375373
if model_folder is None:
376374
shutil.rmtree(tmp_dir)
377375
raise ValueError(self.tr("Could not download model."))
378376

379-
380377
# Move model folder to correct location
381378
shutil.rmtree(extract_dir)
382379
shutil.move(model_folder, extract_dir)
@@ -385,7 +382,7 @@ def _download_and_extract_model(self, model_config):
385382
shutil.rmtree(tmp_dir)
386383

387384
# Update config file
388-
with open(config_file, "r") as f:
385+
with open(config_file) as f:
389386
model_config = yaml.safe_load(f)
390387
model_config["has_downloaded"] = True
391388
model_config["config_file"] = config_file
@@ -422,7 +419,7 @@ def _load_model(self, model_id):
422419
model_config["model"] = model_class(
423420
model_config, on_message=self.new_model_status.emit
424421
)
425-
422+
426423
# Specific logic for interactive models (like SAM) vs detection models
427424
# Ideally this should be a property of the model class (capabilities)
428425
if model_type == "segment_anything":
@@ -433,18 +430,8 @@ def _load_model(self, model_id):
433430
self.auto_segmentation_model_unselected.emit()
434431

435432
except Exception as e: # noqa
436-
self.new_model_status.emit(
437-
self.tr(
438-
"Error in loading model: {error_message}".format(
439-
error_message=str(e)
440-
)
441-
)
442-
)
443-
print(
444-
"Error in loading model: {error_message}".format(
445-
error_message=str(e)
446-
)
447-
)
433+
self.new_model_status.emit(self.tr(f"Error in loading model: {str(e)}"))
434+
print(f"Error in loading model: {str(e)}")
448435
return
449436

450437
self.loaded_model_config = model_config
@@ -530,7 +517,7 @@ def predict_shapes_threading(self, image, filename=None):
530517
):
531518
if hasattr(self.loaded_model_config["model"], "unload"):
532519
self.loaded_model_config["model"].unload()
533-
520+
534521
# Wait for the thread to finish
535522
self.model_execution_thread.quit()
536523
if not self.model_execution_thread.wait(1000):

anylabeling/services/auto_labeling/registry.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Dict, Type, Any
21
import logging
32

3+
44
class ModelRegistry:
55
"""
66
Singleton registry to manage auto-labeling model classes.
77
"""
8-
_registry: Dict[str, Type] = {}
8+
9+
_registry: dict[str, type] = {}
910

1011
@classmethod
1112
def register(cls, name: str):
@@ -15,15 +16,19 @@ def register(cls, name: str):
1516
@ModelRegistry.register("yolov8")
1617
class YOLOv8(Model): ...
1718
"""
19+
1820
def decorator(model_class):
1921
if name in cls._registry:
20-
logging.warning(f"Model type '{name}' is already registered. Overwriting.")
22+
logging.warning(
23+
f"Model type '{name}' is already registered. Overwriting."
24+
)
2125
cls._registry[name] = model_class
2226
return model_class
27+
2328
return decorator
2429

2530
@classmethod
26-
def get(cls, name: str) -> Type:
31+
def get(cls, name: str) -> type:
2732
"""Get a model class by type name."""
2833
return cls._registry.get(name)
2934

anylabeling/services/auto_labeling/sam2_coreml.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
2+
23
import cv2
34
import numpy as np
4-
from pathlib import Path
55
from PIL import Image
66

77

88
class SegmentAnything2CoreML:
99
def __init__(self, model_path: str) -> None:
1010
import coremltools as ct # macOS-only; imported lazily to avoid failure on Windows/Linux
11+
1112
print("using CoreML", model_path)
1213
image_decoder_path = os.path.join(
1314
model_path, "SAM2_1LargeImageEncoderFLOAT16.mlpackage"

0 commit comments

Comments
 (0)