Skip to content

Commit 124fc01

Browse files
Added basic implementation of CVP
1 parent e8e86a5 commit 124fc01

1 file changed

Lines changed: 59 additions & 0 deletions

File tree

Lines changed: 59 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,59 @@
1+
from collections import OrderedDict
2+
from typing import Union
3+
4+
import numpy as np
5+
from sentence_transformers import SentenceTransformer
6+
from sklearn.base import BaseEstimator, TransformerMixin
7+
8+
from turftopic.base import Encoder
9+
from turftopic.encoders.multimodal import MultimodalEncoder
10+
11+
12+
class ConceptVectorProjection(BaseEstimator, TransformerMixin):
    """Score documents along concept vectors built from seed phrases.

    Every concept is described by a list of positive and a list of negative
    seeds. Its concept vector is the mean embedding of the positive seeds
    minus the mean embedding of the negative seeds, scaled to unit length.
    Documents are scored by the dot product of their embedding with each
    concept vector.

    All concept vectors are computed in the constructor, so the object is
    ready to transform right after it is created.

    Parameters
    ----------
    seeds
        Either a single ``(positive, negative)`` tuple of seed lists, which
        is registered under the name ``"default"``, or a collection of
        ``(name, (positive, negative))`` entries (anything ``OrderedDict``
        accepts) describing several named concepts.
    encoder
        A model name, which is loaded with ``SentenceTransformer``, or any
        other object, which is used as-is and has its ``encode`` method
        called on the seeds and documents.

    Attributes
    ----------
    encoder_
        The encoder object actually used for encoding.
    classes_ : np.ndarray
        Concept names in the order they were given.
    concept_matrix_ : np.ndarray
        Stacked unit-length concept vectors, one row per entry of
        ``classes_``.

    Raises
    ------
    ValueError
        If a concept has an empty seed list, or if the positive and negative
        seeds of a concept have the same mean embedding.
    """

    def __init__(
        self,
        seeds: (
            tuple[list[str], list[str]]
            | list[tuple[str, tuple[list[str], list[str]]]]
        ),
        encoder: Union[
            Encoder, str, MultimodalEncoder
        ] = "sentence-transformers/all-MiniLM-L6-v2",
    ):
        self.seeds = seeds
        if self._is_single_pair(seeds):
            self._seeds = OrderedDict([("default", seeds)])
        else:
            self._seeds = OrderedDict(seeds)
        self.encoder = encoder
        if isinstance(encoder, str):
            self.encoder_ = SentenceTransformer(encoder)
        else:
            self.encoder_ = encoder
        self.classes_ = np.array(list(self._seeds))
        self.concept_matrix_ = np.stack(
            [
                self._concept_vector(name, positive, negative)
                for name, (positive, negative) in self._seeds.items()
            ]
        )

    @staticmethod
    def _is_single_pair(seeds) -> bool:
        """Tell a bare ``(positive, negative)`` pair from named entries."""
        if not isinstance(seeds, tuple) or len(seeds) != 2:
            return False
        first = seeds[0]
        # A named entry looks like ("name", (positive, negative)): its
        # members are not all strings. A seed list contains only strings.
        # Checking every member (not just the first) keeps a tuple holding
        # exactly two named entries from being mistaken for a single pair.
        return not isinstance(first, str) and all(
            isinstance(seed, str) for seed in first
        )

    def _concept_vector(self, name, positive, negative) -> np.ndarray:
        """Return the unit-length concept vector for one pair of seed lists."""
        if len(positive) == 0 or len(negative) == 0:
            raise ValueError(
                f"Concept '{name}' needs at least one positive and one "
                "negative seed."
            )
        positive_emb = self.encoder_.encode(positive)
        negative_emb = self.encoder_.encode(negative)
        cv = np.mean(positive_emb, axis=0) - np.mean(negative_emb, axis=0)
        norm = np.linalg.norm(cv)
        if norm == 0:
            # Dividing by a zero norm would silently fill the row with NaN.
            raise ValueError(
                f"Positive and negative seeds of concept '{name}' have the "
                "same mean embedding, no direction can be derived from them."
            )
        return cv / norm

    def get_feature_names_out(self, input_features=None):
        """Return the concept names, one per output column.

        Parameters
        ----------
        input_features
            Ignored. Accepted so the method can be called the way
            scikit-learn calls ``get_feature_names_out``.
        """
        return self.classes_

    def fit(self, raw_documents=None, y=None, embeddings=None):
        """Do nothing and return ``self``.

        The concept vectors are already computed in the constructor; the
        method only exists so the object offers a ``fit`` entry point.
        All arguments are ignored.
        """
        return self

    def fit_transform(self, raw_documents=None, y=None, embeddings=None):
        """Project documents onto the concept vectors.

        Parameters
        ----------
        raw_documents
            Documents to encode with ``encoder_``. Only used when
            ``embeddings`` is not given.
        y
            Ignored.
        embeddings
            Precomputed document embeddings. When given, ``raw_documents``
            is not encoded.

        Returns
        -------
        The matrix product of the embeddings with the transposed
        ``concept_matrix_``, i.e. one column per concept.

        Raises
        ------
        ValueError
            If both ``raw_documents`` and ``embeddings`` are ``None``.
        """
        if (raw_documents is None) and (embeddings is None):
            raise ValueError(
                "Either embeddings or raw_documents has to be passed, both are None."
            )
        if embeddings is None:
            embeddings = self.encoder_.encode(raw_documents)
        return embeddings @ self.concept_matrix_.T

    def transform(self, raw_documents=None, embeddings=None):
        """Project documents onto the concept vectors.

        Delegates to ``fit_transform`` with the same ``raw_documents`` and
        ``embeddings``.
        """
        return self.fit_transform(raw_documents, embeddings=embeddings)

0 commit comments

Comments
 (0)