Skip to content

Commit 05de799

Browse files
Added persistence and HF Hub uploads to CVP
1 parent 0fd01a4 commit 05de799

1 file changed

Lines changed: 42 additions & 0 deletions

File tree

turftopic/models/cvp.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
import json
2+
import tempfile
13
from collections import OrderedDict
4+
from pathlib import Path
25
from typing import Union
36

7+
import joblib
48
import numpy as np
9+
from huggingface_hub import HfApi
510
from sentence_transformers import SentenceTransformer
611
from sklearn.base import BaseEstimator, TransformerMixin
712

813
from turftopic.base import Encoder
914
from turftopic.encoders.multimodal import MultimodalEncoder
15+
from turftopic.serialization import create_readme, get_package_versions
1016

1117
Seeds = tuple[list[str], list[str]]
1218

@@ -105,3 +111,39 @@ def transform(self, raw_documents=None, embeddings=None):
105111
Prevalance of each concept in each document.
106112
"""
107113
return self.fit_transform(raw_documents, embeddings=embeddings)
114+
115+
def to_disk(self, out_dir: Union[Path, str]):
116+
"""Persists model to directory on your machine.
117+
118+
Parameters
119+
----------
120+
out_dir: Path | str
121+
Directory to save the model to.
122+
"""
123+
out_dir = Path(out_dir)
124+
out_dir.mkdir(exist_ok=True)
125+
package_versions = get_package_versions()
126+
with out_dir.joinpath("package_versions.json").open("w") as ver_file:
127+
ver_file.write(json.dumps(package_versions))
128+
joblib.dump(self, out_dir.joinpath("model.joblib"))
129+
130+
def push_to_hub(self, repo_id: str):
131+
"""Uploads model to HuggingFace Hub
132+
133+
Parameters
134+
----------
135+
repo_id: str
136+
Repository to upload the model to.
137+
"""
138+
api = HfApi()
139+
api.create_repo(repo_id, exist_ok=True)
140+
with tempfile.TemporaryDirectory() as tmp_dir:
141+
readme_path = Path(tmp_dir).joinpath("README.md")
142+
with readme_path.open("w") as readme_file:
143+
readme_file.write(create_readme(self, repo_id))
144+
self.to_disk(tmp_dir)
145+
api.upload_folder(
146+
folder_path=tmp_dir,
147+
repo_id=repo_id,
148+
repo_type="model",
149+
)

0 commit comments

Comments
 (0)