Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions spec/pg/replication_spec.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
require "../spec_helper"

struct TestMessageHandler
include PG::Replication::Handler

def received(msg : PG::Replication::Begin)
pp begin: msg
end

def received(msg : PG::Replication::Message)
pp message: msg
end

def received(msg : PG::Replication::Commit)
pp commit: msg
end

def received(msg : PG::Replication::Origin)
end

def received(msg : PG::Replication::Relation)
pp relation: msg
end

def received(msg : PG::Replication::Type)
end

def received(msg : PG::Replication::Insert)
pp insert: msg
end

def received(msg : PG::Replication::Update)
pp update: msg
end

def received(msg : PG::Replication::Delete)
pp delete: msg
end

def received(msg : PG::Replication::Truncate)
end

def received(msg : PG::Replication::StreamStart)
end

def received(msg : PG::Replication::StreamStop)
end

def received(msg : PG::Replication::StreamCommit)
end

def received(msg : PG::Replication::StreamAbort)
end

def received(msg : PG::Replication::BeginPrepare)
end

def received(msg : PG::Replication::Prepare)
end

def received(msg : PG::Replication::CommitPrepared)
end

def received(msg : PG::Replication::RollbackPrepared)
end

def received(msg : PG::Replication::StreamPrepare)
end

def received(msg : PG::Replication::TupleData)
end
end

describe PG::Replication do
it "consumes the WAL" do
publication_name = "test_publication_#{Random::Secure.hex}"
slot_name = "test_slot_#{Random::Secure.hex}"
table_name = "test_table_#{Random::Secure.hex}"
handler = TestMessageHandler.new
PG_DB.exec "CREATE PUBLICATION #{publication_name} FOR ALL TABLES"
PG_DB.exec "SELECT pg_create_logical_replication_slot($1, 'pgoutput')", slot_name
subscriber = PG.connect_replication(DB_URL, handler: handler, publication_name: publication_name, slot_name: slot_name)
sleep 100.milliseconds
PG_DB.exec "DROP TABLE IF EXISTS #{table_name}"
PG_DB.exec <<-SQL
CREATE TABLE #{table_name}(
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
string TEXT,
number INT8
)
SQL
id = PG_DB.query_one <<-SQL, as: UUID
INSERT INTO #{table_name} (string, number)
VALUES ('foo', 1)
RETURNING id
SQL

PG_DB.exec "UPDATE #{table_name} SET string = 'bar' WHERE id = $1", id
PG_DB.exec "UPDATE #{table_name} SET number = 2 WHERE id = $1", id

sleep 100.milliseconds

pp handler
ensure
subscriber.try &.close
PG_DB.exec "SELECT pg_drop_replication_slot($1)", slot_name
PG_DB.exec "DROP PUBLICATION IF EXISTS #{publication_name}"
PG_DB.exec "DROP TABLE IF EXISTS #{table_name}"
end
end
4 changes: 4 additions & 0 deletions src/pg.cr
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ module PG
ListenConnection.new(url, channels, blocking, &blk)
end

def self.connect_replication(url, *, handler, publication_name, slot_name)
Replication::Connection.new(url, handler, publication_name: publication_name, slot_name: slot_name)
end

class ListenConnection
@conn : PG::Connection

Expand Down
10 changes: 9 additions & 1 deletion src/pg/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ module PG
super(options)

begin
@connection.connect
@connection.connect(replication: @connection.conninfo.replication)
rescue ex
raise DB::ConnectionRefused.new(cause: ex)
end
Expand Down Expand Up @@ -95,6 +95,14 @@ module PG
end
end

protected def listen_replication(publication_name : String, slot_name : String, blocking : Bool = false, &block : Replication::Frame ->)
if blocking
@connection.start_replication_frame_loop(publication_name, slot_name, &block)
else
spawn { @connection.start_replication_frame_loop(publication_name, slot_name, &block) }
end
end

def version
vers = connection.server_parameters["server_version"].partition(' ').first.split('.').map(&.to_i)
{major: vers[0], minor: vers[1], patch: vers[2]? || 0}
Expand Down
Loading
Loading