@@ -53,40 +53,20 @@ def _encode_tokens(
5353 Start and end character of each token in each document.
5454 """
5555 self .has_used_token_level = True
56- token_embeddings = []
57- offsets = []
58- for start_index in trange (
59- 0 ,
60- len (texts ),
61- batch_size ,
62- desc = "Encoding batches..." ,
63- ):
64- batch = texts [start_index : start_index + batch_size ]
65- features = self .tokenize (batch )
66- with torch .no_grad ():
67- output_features = self .forward (features )
68- n_tokens = output_features ["attention_mask" ].sum (axis = 1 )
69- # Find first nonzero elements in each document
70- # The document could be padded from the left, so we have to watch out for this.
71- start_token = torch .argmax (
72- (output_features ["attention_mask" ] > 0 ).to (torch .long ), axis = 1
73- )
74- end_token = start_token + n_tokens
75- for i_doc in range (len (batch )):
76- _token_embeddings = (
77- output_features ["token_embeddings" ][
78- i_doc , start_token [i_doc ] : end_token [i_doc ], :
79- ]
80- .float ()
81- .numpy (force = True )
82- )
83- _n = _token_embeddings .shape [0 ]
84- # We extract the character offsets and prune it at the maximum context length
85- _offsets = self .tokenizer (
86- batch [i_doc ], return_offsets_mapping = True , verbose = False
87- )["offset_mapping" ][:_n ]
88- token_embeddings .append (_token_embeddings )
89- offsets .append (_offsets )
56+ token_embeddings = self .encode (
57+ texts , output_value = "token_embeddings" , batch_size = batch_size
58+ )
59+ offsets = self .tokenizer (
60+ texts , return_offsets_mapping = True , verbose = False
61+ )["offset_mapping" ]
62+ offsets = [
63+ offs [: len (embs )] for offs , embs in zip (offsets , token_embeddings )
64+ ]
65+ token_embeddings = [
66+ embs .numpy (force = True )
67+ for embs in token_embeddings
68+ if torch .is_tensor (embs )
69+ ]
9070 return token_embeddings , offsets
9171
9272 def encode_tokens (
0 commit comments