|
| 1 | +import json |
| 2 | +import tempfile |
1 | 3 | from collections import OrderedDict |
| 4 | +from pathlib import Path |
2 | 5 | from typing import Union |
3 | 6 |
|
| 7 | +import joblib |
4 | 8 | import numpy as np |
| 9 | +from huggingface_hub import HfApi |
5 | 10 | from sentence_transformers import SentenceTransformer |
6 | 11 | from sklearn.base import BaseEstimator, TransformerMixin |
7 | 12 |
|
8 | 13 | from turftopic.base import Encoder |
9 | 14 | from turftopic.encoders.multimodal import MultimodalEncoder |
| 15 | +from turftopic.serialization import create_readme, get_package_versions |
10 | 16 |
|
11 | 17 | Seeds = tuple[list[str], list[str]] |
12 | 18 |
|
@@ -105,3 +111,39 @@ def transform(self, raw_documents=None, embeddings=None): |
105 | 111 | Prevalance of each concept in each document. |
106 | 112 | """ |
107 | 113 | return self.fit_transform(raw_documents, embeddings=embeddings) |
| 114 | + |
| 115 | + def to_disk(self, out_dir: Union[Path, str]): |
| 116 | + """Persists model to directory on your machine. |
| 117 | +
|
| 118 | + Parameters |
| 119 | + ---------- |
| 120 | + out_dir: Path | str |
| 121 | + Directory to save the model to. |
| 122 | + """ |
| 123 | + out_dir = Path(out_dir) |
| 124 | + out_dir.mkdir(exist_ok=True) |
| 125 | + package_versions = get_package_versions() |
| 126 | + with out_dir.joinpath("package_versions.json").open("w") as ver_file: |
| 127 | + ver_file.write(json.dumps(package_versions)) |
| 128 | + joblib.dump(self, out_dir.joinpath("model.joblib")) |
| 129 | + |
| 130 | + def push_to_hub(self, repo_id: str): |
| 131 | + """Uploads model to HuggingFace Hub |
| 132 | +
|
| 133 | + Parameters |
| 134 | + ---------- |
| 135 | + repo_id: str |
| 136 | + Repository to upload the model to. |
| 137 | + """ |
| 138 | + api = HfApi() |
| 139 | + api.create_repo(repo_id, exist_ok=True) |
| 140 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 141 | + readme_path = Path(tmp_dir).joinpath("README.md") |
| 142 | + with readme_path.open("w") as readme_file: |
| 143 | + readme_file.write(create_readme(self, repo_id)) |
| 144 | + self.to_disk(tmp_dir) |
| 145 | + api.upload_folder( |
| 146 | + folder_path=tmp_dir, |
| 147 | + repo_id=repo_id, |
| 148 | + repo_type="model", |
| 149 | + ) |
0 commit comments