Skip to content

Commit 34b3839

Browse files
author
Perry
committed
added model download logic to utils (huggingface call)
1 parent 1c5adc3 commit 34b3839

1 file changed

Lines changed: 50 additions & 0 deletions

File tree

panel_segmentation/utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,58 @@
1717
import geopandas
1818
from pyproj import Transformer
1919
from skimage.transform import hough_line, hough_line_peaks
20+
from huggingface_hub import hf_hub_download
2021

2122

23+
def downloadModel(filename,
24+
repo_id="kperrynrel/panel-segmentation-models"):
25+
"""
26+
Download a model file from a Hugging Face repository and save it
27+
to the panel_segmentation/models folder.
28+
29+
Parameters
30+
----------
31+
filename : str
32+
The name of the model file to download from the repository.
33+
Example: 'panel_detection_model.pth'
34+
repo_id : str
35+
The Hugging Face repository ID in the format 'owner/repo-name'.
36+
Default is 'kperrynrel/panel-segmentation-models', which is where
37+
the models are located.
38+
39+
Returns
40+
-------
41+
model_path : str
42+
The full file path of the downloaded model.
43+
"""
44+
# Check input variable types
45+
if not isinstance(repo_id, str):
46+
raise TypeError("repo_id variable must be of type string.")
47+
if not isinstance(filename, str):
48+
raise TypeError("filename variable must be of type string.")
49+
# models folder is panel_segmentation/models relative to utils.py
50+
models_folder = os.path.join(
51+
os.path.dirname(os.path.abspath(__file__)), "models"
52+
)
53+
# Build the destination path
54+
model_path = os.path.join(models_folder, filename)
55+
# Skip download if the file already exists locally
56+
if os.path.exists(model_path):
57+
print(f"Model already exists at {model_path}, skipping download.")
58+
return model_path
59+
try:
60+
downloaded_path = hf_hub_download(
61+
repo_id=repo_id,
62+
filename=filename,
63+
local_dir=models_folder
64+
)
65+
except Exception as e:
66+
raise ValueError(
67+
f"Failed to download '{filename}' from '{repo_id}': {e}"
68+
)
69+
print(f"Model downloaded to {downloaded_path}")
70+
return downloaded_path
71+
2272
def generateSatelliteImage(latitude, longitude,
2373
file_name_save, google_maps_api_key,
2474
zoom_level=18):

0 commit comments

Comments
 (0)