Skip to content

Commit e03ede6

Browse files
Ticket #12 : Support bulk insert
1 parent 402ac0d commit e03ede6

4 files changed

Lines changed: 121 additions & 3 deletions

File tree

samples/EFCore.Cassandra.Samples/Program.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.
33
using Cassandra;
44
using EFCore.Cassandra.Samples.Models;
5+
using Microsoft.EntityFrameworkCore;
56
using System;
67
using System.Collections.Generic;
78
using System.Linq;
@@ -18,6 +19,10 @@ static void Main(string[] args)
1819
{
1920
using (var dbContext = new FakeDbContext())
2021
{
22+
Console.WriteLine("Bulk insert");
23+
var applicants = Enumerable.Repeat(1, 1).Select(_ => BuildApplicant()).ToList();
24+
dbContext.BulkInsert(applicants);
25+
2126
Console.WriteLine("Add applicant");
2227
var timeUuid = TimeUuid.NewId();
2328
dbContext.Applicants.Add(BuildApplicant());

src/EFCore.Cassandra.Benchmarks/InsertData.cs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.
33
using BenchmarkDotNet.Attributes;
44
using EFCore.Cassandra.Benchmarks.Models;
5+
using Microsoft.EntityFrameworkCore;
56
using System;
67
using System.Collections.Generic;
78
using System.Linq;
@@ -16,7 +17,7 @@ public class InsertData : IDisposable
1617
public InsertData()
1718
{
1819
_dbContext = new FakeDbContext();
19-
_applicants = Enumerable.Range(1, 5).Select(_ => BuildApplicant()).ToArray();
20+
_applicants = Enumerable.Range(1, 20).Select(_ => BuildApplicant()).ToArray();
2021
}
2122

2223
public void Dispose()
@@ -25,7 +26,7 @@ public void Dispose()
2526
}
2627

2728
[Benchmark]
28-
public void Add100Applicants()
29+
public void AddApplicants()
2930
{
3031
foreach (var applicant in _applicants)
3132
{
@@ -36,12 +37,18 @@ public void Add100Applicants()
3637
}
3738

3839
[Benchmark]
39-
public void AddRange100Applicants()
40+
public void AddRangeApplicants()
4041
{
4142
_dbContext.Applicants.AddRange(_applicants);
4243
_dbContext.SaveChanges();
4344
}
4445

46+
[Benchmark]
47+
public void BulkInsertApplicants()
48+
{
49+
_dbContext.BulkInsert(_applicants.ToList());
50+
}
51+
4552
private static Applicant BuildApplicant()
4653
{
4754
return new Applicant
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Copyright (c) SimpleIdServer. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.
3+
using EFCore.Cassandra.Bulk;
4+
using System.Collections.Generic;
5+
using System.Threading.Tasks;
6+
7+
namespace Microsoft.EntityFrameworkCore
8+
{
9+
public static class DbContextBulkExtensions
10+
{
11+
public static void BulkInsert<T>(this DbContext dbContext, List<T> entities)
12+
{
13+
SqlBulkOperation.Insert(dbContext, entities);
14+
}
15+
16+
public static Task BulkInsertAsync<T>(this DbContext dbContext, List<T> entities)
17+
{
18+
return SqlBulkOperation.InsertAsync(dbContext, entities);
19+
}
20+
}
21+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright (c) SimpleIdServer. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.
3+
using Cassandra;
4+
using Cassandra.Data;
5+
using Microsoft.EntityFrameworkCore;
6+
using Microsoft.EntityFrameworkCore.Infrastructure;
7+
using Microsoft.EntityFrameworkCore.Metadata;
8+
using Microsoft.EntityFrameworkCore.Storage;
9+
using Microsoft.EntityFrameworkCore.Update;
10+
using System.Collections.Generic;
11+
using System.Data;
12+
using System.Linq;
13+
using System.Net;
14+
using System.Reflection;
15+
using System.Threading.Tasks;
16+
17+
namespace EFCore.Cassandra.Bulk
18+
{
19+
internal class SqlBulkOperation
20+
{
21+
public static void Insert<T>(DbContext dbContext, IList<T> entities)
22+
{
23+
var result = BuildBatch(dbContext, entities);
24+
result.session.Execute(result.batchStatement);
25+
}
26+
27+
public static async Task InsertAsync<T>(DbContext dbContext, IList<T> entities)
28+
{
29+
var result = BuildBatch(dbContext, entities);
30+
await result.session.ExecuteAsync(result.batchStatement);
31+
}
32+
33+
private static (BatchStatement batchStatement, ISession session) BuildBatch<T>(DbContext dbContext, IList<T> entities)
34+
{
35+
var service = dbContext.GetService<ICommandBatchPreparer>();
36+
var sqlGenerationHelper = dbContext.GetService<ISqlGenerationHelper>();
37+
var database = dbContext.Database.GetDbConnection() as CqlConnection;
38+
if (database.State != ConnectionState.Open)
39+
{
40+
database.Open();
41+
}
42+
43+
var prop = typeof(CqlConnection).GetField("ManagedConnection", BindingFlags.NonPublic | BindingFlags.Instance);
44+
var session = (ISession)prop.GetValue(database);
45+
var batch = new BatchStatement();
46+
foreach (var entity in entities)
47+
{
48+
var name = entity.GetType().FullName;
49+
var entityType = dbContext.Model.FindEntityType(name);
50+
var propertyNames = new List<string>();
51+
var propertyValues = new List<object>();
52+
foreach (var property in entityType.GetProperties())
53+
{
54+
propertyNames.Add(sqlGenerationHelper.DelimitIdentifier(property.GetColumnName()));
55+
var propValue = property.PropertyInfo.GetValue(entity);
56+
propValue = GetValue(property, propValue);
57+
propertyValues.Add(propValue);
58+
}
59+
60+
var tableName = entityType.GetTableName();
61+
var schema = entityType.GetSchema();
62+
var cqlQuery = $"INSERT INTO \"{schema}\".\"{tableName}\" ({string.Join(',', propertyNames)}) VALUES ({string.Join(',', Enumerable.Repeat(1, propertyNames.Count()).Select(_ => "?"))})";
63+
var smt = session.Prepare(cqlQuery);
64+
batch.Add(smt.Bind(propertyValues.ToArray()));
65+
}
66+
67+
return (batch, session);
68+
}
69+
70+
private static object GetValue(IProperty property, object value)
71+
{
72+
if(property.ClrType == typeof(IPAddress))
73+
{
74+
if (value == null)
75+
{
76+
return new byte[0];
77+
}
78+
79+
return ((IPAddress)value).GetAddressBytes();
80+
}
81+
82+
return value;
83+
}
84+
}
85+
}

0 commit comments

Comments
 (0)