@@ -3,6 +3,8 @@
 import numpy as np
 from sklearn.base import TransformerMixin
 
+from turftopic.encoders.contextual import Offsets
+
 Lengths = list[int]
 
 
@@ -62,7 +64,17 @@ def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.mean):
     return np.stack(pooled)
 
 
-class TokenLevel(TransformerMixin):
+def get_document_chunks(
+    raw_documents: list[str], offsets: list[Offsets]
+) -> list[str]:
+    chunks = []
+    for doc, _offs in zip(raw_documents, offsets):
+        for start_char, end_char in _offs:
+            chunks.append(doc[start_char:end_char])  # slice the span out of its own document
+    return chunks
+
+
+class LateModel(TransformerMixin):
     def __init__(
         self,
         model: TransformerMixin,
@@ -74,63 +86,18 @@ def __init__(
         self.pooling = pooling
 
     def transform(
-        self, raw_documents: list[str], embeddings: list[np.ndarray] = None
-    ):
-        if embeddings is None:
-            embeddings = self.model.encoder.encode_tokens(
-                raw_documents, batch_size=self.batch_size
-            )
-        flat_embeddings, lengths = flatten_repr(embeddings)
-        out_array = self.model.transform(
-            raw_documents, embeddings=flat_embeddings
-        )
-        if self.pooling is None:
-            return unflatten_repr(out_array, lengths)
-        else:
-            return pool_flat(out_array, lengths)
-
-    def fit_transform(
         self,
         raw_documents: list[str],
-        y=None,
         embeddings: list[np.ndarray] = None,
+        offsets: list[Offsets] = None,
     ):
-        if embeddings is None:
-            embeddings = self.model.encoder.encode_tokens(
+        if (embeddings is None) or (offsets is None):
+            embeddings, offsets = self.model.encoder.encode_tokens(
                 raw_documents, batch_size=self.batch_size
             )
         flat_embeddings, lengths = flatten_repr(embeddings)
-        out_array = self.model.fit_transform(
-            raw_documents, y, embeddings=flat_embeddings
-        )
-        if self.pooling is None:
-            return unflatten_repr(out_array, lengths)
-        else:
-            return pool_flat(out_array, lengths)
-
-
-class Windowed(TransformerMixin):
-    def __init__(
-        self,
-        model: TransformerMixin,
-        batch_size: int = 32,
-        pooling: Optional[Callable] = None,
-    ):
-        self.model = model
-        self.batch_size = batch_size
-        self.pooling = pooling
-
-    def transform(
-        self, raw_documents: list[str], embeddings: list[np.ndarray] = None
-    ):
-        if embeddings is None:
-            embeddings = self.model.encoder.encode_tokens(
-                raw_documents, batch_size=self.batch_size
-            )
-        flat_embeddings, lengths = flatten_repr(embeddings)
-        out_array = self.model.transform(
-            raw_documents, embeddings=flat_embeddings
-        )
+        chunks = get_document_chunks(raw_documents, offsets)
+        out_array = self.model.transform(chunks, embeddings=flat_embeddings)
         if self.pooling is None:
             return unflatten_repr(out_array, lengths)
         else:
@@ -141,14 +108,16 @@ def fit_transform(
         raw_documents: list[str],
         y=None,
         embeddings: list[np.ndarray] = None,
+        offsets: list[Offsets] = None,
     ):
-        if embeddings is None:
-            embeddings = self.model.encoder.encode_tokens(
+        if (embeddings is None) or (offsets is None):
+            embeddings, offsets = self.model.encoder.encode_tokens(
                 raw_documents, batch_size=self.batch_size
             )
         flat_embeddings, lengths = flatten_repr(embeddings)
+        chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.fit_transform(
-            raw_documents, y, embeddings=flat_embeddings
+            chunks, embeddings=flat_embeddings
         )
         if self.pooling is None:
             return unflatten_repr(out_array, lengths)
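
Note: below is a minimal, self-contained sketch of the chunk-extraction contract that LateModel now relies on. It assumes, as the call sites above imply, that Offsets is a per-document list of (start_char, end_char) character spans into the raw document string; the example documents and spans are invented for illustration.

# Assumed shape of Offsets, inferred from how get_document_chunks consumes it.
Offsets = list[tuple[int, int]]

def get_document_chunks(
    raw_documents: list[str], offsets: list[Offsets]
) -> list[str]:
    # Flatten every document into its character-span chunks, in document order.
    chunks = []
    for doc, _offs in zip(raw_documents, offsets):
        for start_char, end_char in _offs:
            chunks.append(doc[start_char:end_char])
    return chunks

docs = ["Topic models are neat.", "Late interaction pools token vectors."]
offs = [[(0, 12), (13, 22)], [(0, 16)]]
print(get_document_chunks(docs, offs))
# ['Topic models', 'are neat.', 'Late interaction']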
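Note: flatten_repr, unflatten_repr, and pool_flat are defined earlier in this file and are not shown in the diff; the sketch below is a plausible reconstruction from their call sites and from the pool_flat signature visible in the hunk header, not the actual turftopic implementation.

import numpy as np

Lengths = list[int]

def flatten_repr(reprs: list[np.ndarray]) -> tuple[np.ndarray, Lengths]:
    # Stack per-document (n_chunks_i, dim) arrays into one flat array and
    # remember how many rows each document contributed.
    lengths = [len(r) for r in reprs]
    return np.concatenate(reprs, axis=0), lengths

def unflatten_repr(flat_repr: np.ndarray, lengths: Lengths) -> list[np.ndarray]:
    # Inverse of flatten_repr: split the flat array back into per-document arrays.
    return np.split(flat_repr, np.cumsum(lengths)[:-1])

def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.mean):
    # Aggregate each document's rows into a single vector, one row per document.
    pooled = [agg(rows, axis=0) for rows in unflatten_repr(flat_repr, lengths)]
    return np.stack(pooled)

flat = np.arange(12, dtype=float).reshape(4, 3)  # 4 chunk vectors of dim 3
print(pool_flat(flat, lengths=[1, 3]).shape)     # (2, 3): one pooled vector per doc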