Skip to content

Commit 3ececcc

Browse files
Renamed to LateSentenceTransformer
1 parent 6db8592 commit 3ececcc

2 files changed

Lines changed: 24 additions & 55 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def is_contextual(encoder):
1818
Lengths = list[int]
1919

2020

21-
class LateTransformer(SentenceTransformer):
21+
class LateSentenceTransformer(SentenceTransformer):
2222
def encode(
2323
self, sentences: Union[str, list[str], np.ndarray], *args, **kwargs
2424
):

turftopic/late.py

Lines changed: 23 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import numpy as np
44
from sklearn.base import TransformerMixin
55

6+
from turftopic.encoders.contextual import Offsets
7+
68
Lengths = list[int]
79

810

@@ -62,7 +64,17 @@ def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.mean):
6264
return np.stack(pooled)
6365

6466

65-
class TokenLevel(TransformerMixin):
67+
def get_document_chunks(
    raw_documents: list[str], offsets: list[Offsets]
) -> list[str]:
    """Slice each raw document into chunk substrings using character offsets.

    Parameters
    ----------
    raw_documents
        One raw text string per document.
    offsets
        For each document, an iterable of ``(start_char, end_char)`` pairs
        delimiting that document's chunks.

    Returns
    -------
    list[str]
        All chunk substrings, flattened across documents in input order.
    """
    chunks = []
    for doc, doc_offsets in zip(raw_documents, offsets):
        for start_char, end_char in doc_offsets:
            # Slice the current document *string*, not the document list.
            # The original `raw_documents[start_char, end_char]` indexed the
            # list with a tuple (comma instead of a slice colon), which raises
            # TypeError on the first call.
            chunks.append(doc[start_char:end_char])
    return chunks
75+
76+
77+
class LateModel(TransformerMixin):
6678
def __init__(
6779
self,
6880
model: TransformerMixin,
@@ -74,63 +86,18 @@ def __init__(
7486
self.pooling = pooling
7587

7688
def transform(
77-
self, raw_documents: list[str], embeddings: list[np.ndarray] = None
78-
):
79-
if embeddings is None:
80-
embeddings = self.model.encoder.encode_tokens(
81-
raw_documents, batch_size=self.batch_size
82-
)
83-
flat_embeddings, lengths = flatten_repr(embeddings)
84-
out_array = self.model.transform(
85-
raw_documents, embeddings=flat_embeddings
86-
)
87-
if self.pooling is None:
88-
return unflatten_repr(out_array, lengths)
89-
else:
90-
return pool_flat(out_array, lengths)
91-
92-
def fit_transform(
9389
self,
9490
raw_documents: list[str],
95-
y=None,
9691
embeddings: list[np.ndarray] = None,
92+
offsets: list[Offsets] = None,
9793
):
98-
if embeddings is None:
99-
embeddings = self.model.encoder.encode_tokens(
94+
if (embeddings is None) or (offsets is None):
95+
embeddings, offsets = self.model.encoder.encode_tokens(
10096
raw_documents, batch_size=self.batch_size
10197
)
10298
flat_embeddings, lengths = flatten_repr(embeddings)
103-
out_array = self.model.fit_transform(
104-
raw_documents, y, embeddings=flat_embeddings
105-
)
106-
if self.pooling is None:
107-
return unflatten_repr(out_array, lengths)
108-
else:
109-
return pool_flat(out_array, lengths)
110-
111-
112-
class Windowed(TransformerMixin):
113-
def __init__(
114-
self,
115-
model: TransformerMixin,
116-
batch_size: int = 32,
117-
pooling: Optional[Callable] = None,
118-
):
119-
self.model = model
120-
self.batch_size = batch_size
121-
self.pooling = pooling
122-
123-
def transform(
124-
self, raw_documents: list[str], embeddings: list[np.ndarray] = None
125-
):
126-
if embeddings is None:
127-
embeddings = self.model.encoder.encode_tokens(
128-
raw_documents, batch_size=self.batch_size
129-
)
130-
flat_embeddings, lengths = flatten_repr(embeddings)
131-
out_array = self.model.transform(
132-
raw_documents, embeddings=flat_embeddings
133-
)
99+
chunks = get_document_chunks(raw_documents, offsets)
100+
out_array = self.model.transform(chunks, embeddings=flat_embeddings)
134101
if self.pooling is None:
135102
return unflatten_repr(out_array, lengths)
136103
else:
@@ -141,14 +108,16 @@ def fit_transform(
141108
raw_documents: list[str],
142109
y=None,
143110
embeddings: list[np.ndarray] = None,
111+
offsets: list[Offsets] = None,
144112
):
145-
if embeddings is None:
146-
embeddings = self.model.encoder.encode_tokens(
113+
if (embeddings is None) or (offsets is None):
114+
embeddings, offsets = self.model.encoder.encode_tokens(
147115
raw_documents, batch_size=self.batch_size
148116
)
149117
flat_embeddings, lengths = flatten_repr(embeddings)
118+
chunks = get_document_chunks(raw_documents, offsets)
150119
out_array = self.model.fit_transform(
151-
raw_documents, y, embeddings=flat_embeddings
120+
chunks, embeddings=flat_embeddings
152121
)
153122
if self.pooling is None:
154123
return unflatten_repr(out_array, lengths)

0 commit comments

Comments
 (0)