Skip to content

Commit c5caabb

Browse files
committed
allow resumable queries based on runtime conditions
Signed-off-by: George Lemon <georgelemon@protonmail.com>
1 parent 047539c commit c5caabb

2 files changed

Lines changed: 99 additions & 28 deletions

File tree

src/ozark/query.nim

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,34 @@ template withColumnsCheck(model: NimNode, cols: openArray[string], body) =
4242
discard
4343
body
4444

45-
proc ozarkSelectResult(sql: static[string]): NimNode {.compileTime.} = newLit(sql)
46-
proc ozarkWhereResult(sql: static[string], val: varargs[string]): NimNode {.compileTime.} = newLit(sql)
47-
proc ozarkWhereInResult(sql: static[string], vals: varargs[string]): NimNode {.compileTime.} = newLit(sql)
45+
proc ozarkSelectResult*(sql: static[string]): NimNode {.compileTime.} = newLit(sql)
46+
proc ozarkWhereResult*(sql: static[string], val: varargs[string]): NimNode {.compileTime.} = newLit(sql)
47+
proc ozarkWhereInResult*(sql: static[string], vals: varargs[string]): NimNode {.compileTime.} = newLit(sql)
4848
proc ozarkRawSQLResult(sql: static[string], vals: varargs[string]): NimNode {.compileTime.} = newLit(sql)
49-
proc ozarkInsertResult(sql: static[string], values: seq[string]): NimNode {.compileTime.} = newLit(sql)
50-
proc ozarkUpdateResult(sql: static[string], values: seq[string]): NimNode {.compileTime.} = newLit(sql)
51-
proc ozarkLimitResult(sql: static[string], count: int): NimNode {.compileTime.} = newLit(sql)
52-
proc ozarkOrderByResult(sql: static[string], col: string, desc: bool): NimNode {.compileTime.} = newLit(sql)
53-
proc ozarkCreateTableResult(sql: static[string]): NimNode {.compileTime.} = newLit(sql)
54-
proc ozarkRemoveResult(sql: static[string]): NimNode {.compileTime.} = newLit(sql)
55-
proc ozarkHoldModel[T](t: T) {.compileTime.} =
49+
proc ozarkInsertResult*(sql: static[string], values: seq[string]): NimNode {.compileTime.} = newLit(sql)
50+
proc ozarkUpdateResult*(sql: static[string], values: seq[string]): NimNode {.compileTime.} = newLit(sql)
51+
proc ozarkLimitResult*(sql: static[string]): NimNode {.compileTime.} = newLit(sql)
52+
proc ozarkOrderByResult*(sql: static[string], vals: varargs[string]): NimNode {.compileTime.} = newLit(sql)
53+
proc ozarkCreateTableResult*(sql: static[string]): NimNode {.compileTime.} = newLit(sql)
54+
proc ozarkRemoveResult*(sql: static[string]): NimNode {.compileTime.} = newLit(sql)
55+
proc ozarkHoldModel*[T](t: T) {.compileTime.} =
5656
var x: T
5757

58+
proc ozarkHoldModel*[T: typedesc](t: T) {.compileTime.} =
59+
var x: T
60+
61+
macro extractSQL*(sql: NimNode): untyped =
62+
## Extracts the SQL `NimNode` to string for use in code generation
63+
result = newLit(sql.repr)
64+
65+
macro fromSQL*(sql: untyped): untyped =
66+
## Macro to parse a resumed SQL string back into a NimNode for further manipulation
67+
## This is used in the `where` macros to allow chaining multiple clauses based on runtime computations.
68+
let sqlStrNode = sql.getImpl()
69+
var parseStmtNode = parseStmt(sqlStrNode[^1].strVal)
70+
parseStmtNode[0][1][0][1].insert(0, newEmptyNode())
71+
result = parseStmtNode[0]
72+
5873
template withColumnCheck(model: NimNode, col: string, body) =
5974
if col == "*":
6075
body # allow all columns, no need to check for existence
@@ -363,6 +378,11 @@ proc writeOrWhereStatement(op: static string,
363378
sql[1][^1][2][1].add(val)
364379
result = sql
365380

381+
template getSqlImpl() {.dirty.} =
382+
var sql = sql
383+
if sql.kind == nnkSym:
384+
sql = sql.getImpl()[2] # handle the case where the macro is called with a symbol instead of a block (e.g. `where x = 5` instead of `where(select(...), x = 5)`)
385+
366386
# WHERE clause public macros
367387
macro where*(sql: untyped, col: static string, val: untyped): untyped =
368388
## Define `WHERE` clause
@@ -499,7 +519,7 @@ macro getAll*(sql: untyped): untyped =
499519
"ozarkLimitResult", "ozarkOrderByResult",
500520
"ozarkSelectResult"
501521
]:
502-
error("The argument to `getAll` must be the result of a `where` macro.")
522+
error("The argument to `getAll` must be the result of a `where` macro.", sql)
503523
if sql[1][^1][0].strVal == "ozarkSelectResult":
504524
result = sql.parseSqlQuery("instantRows")
505525
else:
@@ -605,24 +625,46 @@ macro exists*(tableName: untyped) =
605625

606626
macro limit*(sql: untyped, count: untyped): untyped =
607627
## Placeholder for a `LIMIT` clause in SQL queries.
608-
if sql.kind != nnkCall or sql[0].strVal notin ["ozarkWhereResult", "ozarkRawSQLResult"]:
609-
error("The argument to `get` must be the result of a `where` macro.")
610-
result = newCall(
611-
bindSym"ozarkLimitResult",
612-
newLit(sql[1].strVal & " LIMIT ?"),
613-
count
614-
)
628+
if sql.kind != nnkBlockExpr or sql[1][1][0].strVal notin [
629+
"ozarkWhereResult", "ozarkRawSQLResult", "ozarkOrderByResult", "ozarkSelectResult"]:
630+
error("The argument to `limit` must be the result of a `select`, `where`", sql)
631+
let len = sql[1][^1][2][1].len + 1
632+
# sql[1][^1][1].strVal = sql[1][^1][1].strVal & " AND " & col & " " & op & " $" & $(len)
633+
sql[1][^1][^1][1].add(count) # add to the current varargs list
634+
sql[1][1] = newCall(
635+
bindSym"ozarkLimitResult",
636+
newLit(sql[1][1][1].strVal & " LIMIT $" & $(len))
637+
)
638+
result = sql
615639

616-
macro orderBy*(sql: untyped, col: static string, desc: static bool = false): untyped =
640+
type
641+
Order* = enum
642+
Desc, Asc
643+
644+
macro orderDescBy*(sql: untyped, cols: static openArray[string]): untyped =
617645
## Placeholder for an `ORDER BY` clause in SQL queries.
618-
if sql.kind != nnkCall or sql[0].strVal notin ["ozarkWhereResult"]:
619-
error("The argument to `orderBy` must be the result of a `where` macro.")
620-
withColumnCheck(sql[1][0][1], col):
621-
result =
622-
newCall(
623-
bindSym"ozarkOrderByResult",
624-
newLit(sql[1].strVal & " ORDER BY " & col & (if desc: " DESC" else: ""))
625-
)
646+
if sql[1][1].kind != nnkCall or sql[1][1][0].strVal notin ["ozarkWhereResult", "ozarkSelectResult"]:
647+
error("The argument to `orderDescBy` must be the result of a `where` macro.")
648+
withColumnsCheck(sql[1][0][1], cols):
649+
let blockIdent = genSym(nskLabel, "ozarkBlockOrderBy")
650+
var vals: seq[NimNode]
651+
var newCallNode = newCall(bindSym"ozarkOrderByResult")
652+
let totalParam =
653+
if sql[1][^1].len > 2 and sql[1][^1][2].kind == nnkHiddenStdConv:
654+
# if there are already parameters (e.g. from a WHERE IN clause), we need to calculate
655+
# the new parameter index based on the existing parameters
656+
newCallNode.add(sql[1][^1][2]) # add the existing parameters to the new call node
657+
sql[1][^1][2][1].len # the number of params inside a nnkBracket node
658+
else:
659+
0
660+
var idx: seq[int]
661+
for i, col in cols:
662+
newCallNode.add(newLit(col))
663+
idx.add(i + totalParam + 1) # +1 because SQL parameters are 1-indexed
664+
665+
newCallNode[1] = newLit(sql[1][1][1].strVal & " ORDER BY " & idx.mapIt("$" & $it).join(", "))
666+
sql[1][1] = newCallNode
667+
result = sql
626668

627669
macro rawSQL*(models: ptr ModelsTable, sql: static string, values: varargs[untyped]): untyped =
628670
## Allows raw SQL queries without losing safety of

tests/test2.nim

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,44 @@ suite "IN queries":
150150
.whereNotIn("name", "John Doe").get()
151151
check res.isEmpty == true
152152

153+
suite "Resumable Queries":
154+
test "chain where clauses based on runtime conditions":
155+
withDBPool do:
156+
var baseQuery = Models.table(Users).select("name").extractSQL()
157+
let filterByName = true
158+
if filterByName:
159+
baseQuery = baseQuery.fromSQL().where("name", "John Doe").extractSQL()
160+
161+
baseQuery = baseQuery.fromSQL().whereIn("email", "johndoe@example.com").extractSQL()
162+
let res = baseQuery.fromSQL().getAll()
163+
check res.isEmpty == false
164+
check res.get(0).name == "John Doe"
165+
166+
test "chain where clauses with whereNot based on runtime conditions":
167+
withDBPool do:
168+
var baseQuery = Models.table(Users).select("name").extractSQL()
169+
var emailAddress: string
170+
let filterByName = true
171+
if filterByName:
172+
emailAddress = "johndoe@example.com"
173+
baseQuery = baseQuery.fromSQL().whereNot("name", "Ghost").extractSQL()
174+
else:
175+
emailAddress = "none@example.com"
176+
177+
baseQuery = baseQuery.fromSQL().whereIn("email", emailAddress).extractSQL()
178+
let res = baseQuery.fromSQL().getAll()
179+
check res.isEmpty == false
180+
check res.get(0).name == "John Doe"
181+
153182
suite "RAW queries":
154183
test "raw where query":
155184
withDBPool do:
156185
let res = Models.rawSQL("SELECT name FROM users WHERE name = $1", "Alice")
157186
.getWith(Users)
158187
assert res.isEmpty
159-
160188
{.pop.}
161189

190+
162191
test "close embedded postgres":
163192
greskew.stop()
164193
greskew.dispose()

0 commit comments

Comments
 (0)