Skip to content

Commit 19e3e56

Browse files
authored
fix cuBERT_LOGITS multi label num_labels!=1 bug (#19)
1 parent 30a4927 commit 19e3e56

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

src/cuBERT/Bert.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace cuBERT {
1818
this->max_batch_size = max_batch_size;
1919
this->seq_length = seq_length;
2020
this->hidden_size = hidden_size;
21+
this->num_labels = num_labels;
2122

2223
this->stream = cuBERT::cuda_stream_create();
2324
this->cublas = cuBERT::blas_create();
@@ -48,7 +49,7 @@ namespace cuBERT {
4849

4950
this->_embedding_output = static_cast<T *>(cuBERT::malloc(sizeof(T) * max_batch_size * seq_length * hidden_size));
5051
this->_pooled_output = static_cast<T *>(cuBERT::malloc(sizeof(T) * max_batch_size * hidden_size));
51-
this->_logits = static_cast<T *>(cuBERT::malloc(sizeof(T) * max_batch_size));
52+
this->_logits = static_cast<T *>(cuBERT::malloc(sizeof(T) * max_batch_size * num_labels));
5253

5354
this->input_ids_buf = static_cast<int *>(cuBERT::malloc(sizeof(int) * max_batch_size * seq_length));
5455
this->input_mask_buf = static_cast<int8_t *>(cuBERT::malloc(sizeof(int8_t) * max_batch_size * seq_length));
@@ -131,7 +132,7 @@ namespace cuBERT {
131132
}
132133

133134
void *streamId = cuBERT::blas_get_stream(cublas);
134-
cuBERT::memcpyAsync(logits, _logits, sizeof(T) * batch_size, 2, streamId);
135+
cuBERT::memcpyAsync(logits, _logits, sizeof(T) * batch_size * num_labels, 2, streamId);
135136
cuBERT::cuda_stream_synchronize(streamId);
136137

137138
if (!buffer_filled) {

src/cuBERT/Bert.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ namespace cuBERT {
4141
size_t max_batch_size;
4242
size_t seq_length;
4343
size_t hidden_size;
44+
size_t num_labels;
4445

4546
BertEmbeddings<T> *bert_embeddings;
4647
Transformer<T> *transformer;

0 commit comments

Comments
 (0)