|
| 1 | +// Copyright (c) SimpleIdServer. All rights reserved. |
| 2 | +// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. |
| 3 | +using Cassandra; |
| 4 | +using Cassandra.Data; |
| 5 | +using Microsoft.EntityFrameworkCore; |
| 6 | +using Microsoft.EntityFrameworkCore.Infrastructure; |
| 7 | +using Microsoft.EntityFrameworkCore.Metadata; |
| 8 | +using Microsoft.EntityFrameworkCore.Storage; |
| 9 | +using Microsoft.EntityFrameworkCore.Update; |
| 10 | +using System.Collections.Generic; |
| 11 | +using System.Data; |
| 12 | +using System.Linq; |
| 13 | +using System.Net; |
| 14 | +using System.Reflection; |
| 15 | +using System.Threading.Tasks; |
| 16 | + |
| 17 | +namespace EFCore.Cassandra.Bulk |
| 18 | +{ |
| 19 | + internal class SqlBulkOperation |
| 20 | + { |
| 21 | + public static void Insert<T>(DbContext dbContext, IList<T> entities) |
| 22 | + { |
| 23 | + var result = BuildBatch(dbContext, entities); |
| 24 | + result.session.Execute(result.batchStatement); |
| 25 | + } |
| 26 | + |
| 27 | + public static async Task InsertAsync<T>(DbContext dbContext, IList<T> entities) |
| 28 | + { |
| 29 | + var result = BuildBatch(dbContext, entities); |
| 30 | + await result.session.ExecuteAsync(result.batchStatement); |
| 31 | + } |
| 32 | + |
| 33 | + private static (BatchStatement batchStatement, ISession session) BuildBatch<T>(DbContext dbContext, IList<T> entities) |
| 34 | + { |
| 35 | + var service = dbContext.GetService<ICommandBatchPreparer>(); |
| 36 | + var sqlGenerationHelper = dbContext.GetService<ISqlGenerationHelper>(); |
| 37 | + var database = dbContext.Database.GetDbConnection() as CqlConnection; |
| 38 | + if (database.State != ConnectionState.Open) |
| 39 | + { |
| 40 | + database.Open(); |
| 41 | + } |
| 42 | + |
| 43 | + var prop = typeof(CqlConnection).GetField("ManagedConnection", BindingFlags.NonPublic | BindingFlags.Instance); |
| 44 | + var session = (ISession)prop.GetValue(database); |
| 45 | + var batch = new BatchStatement(); |
| 46 | + foreach (var entity in entities) |
| 47 | + { |
| 48 | + var name = entity.GetType().FullName; |
| 49 | + var entityType = dbContext.Model.FindEntityType(name); |
| 50 | + var propertyNames = new List<string>(); |
| 51 | + var propertyValues = new List<object>(); |
| 52 | + foreach (var property in entityType.GetProperties()) |
| 53 | + { |
| 54 | + propertyNames.Add(sqlGenerationHelper.DelimitIdentifier(property.GetColumnName())); |
| 55 | + var propValue = property.PropertyInfo.GetValue(entity); |
| 56 | + propValue = GetValue(property, propValue); |
| 57 | + propertyValues.Add(propValue); |
| 58 | + } |
| 59 | + |
| 60 | + var tableName = entityType.GetTableName(); |
| 61 | + var schema = entityType.GetSchema(); |
| 62 | + var cqlQuery = $"INSERT INTO \"{schema}\".\"{tableName}\" ({string.Join(',', propertyNames)}) VALUES ({string.Join(',', Enumerable.Repeat(1, propertyNames.Count()).Select(_ => "?"))})"; |
| 63 | + var smt = session.Prepare(cqlQuery); |
| 64 | + batch.Add(smt.Bind(propertyValues.ToArray())); |
| 65 | + } |
| 66 | + |
| 67 | + return (batch, session); |
| 68 | + } |
| 69 | + |
| 70 | + private static object GetValue(IProperty property, object value) |
| 71 | + { |
| 72 | + if(property.ClrType == typeof(IPAddress)) |
| 73 | + { |
| 74 | + if (value == null) |
| 75 | + { |
| 76 | + return new byte[0]; |
| 77 | + } |
| 78 | + |
| 79 | + return ((IPAddress)value).GetAddressBytes(); |
| 80 | + } |
| 81 | + |
| 82 | + return value; |
| 83 | + } |
| 84 | + } |
| 85 | +} |
0 commit comments