From 763bdf398b70f2aef8bc8f2af78326a85b616278 Mon Sep 17 00:00:00 2001 From: "tejas.sp" <241722411+tejassp-db@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:32:50 +0530 Subject: [PATCH 1/3] Add SET TAGS profiling scripts and connector instrumentation Profiling scripts to compare direct ALTER SET TAGS vs reading from information_schema before writing. Includes scripts for column tags, table tags, information_schema reads, cleanup, and chart generation. Credentials are loaded from examples/credentials.env (gitignored). Copy credentials.env.example and fill in workspace details. Connector instrumentation adds [PROFILE] log lines in the Thrift backend retry loop and urllib3 retry policy to capture per-attempt timing, statement IDs, retry decisions, and SQL text. Co-authored-by: Isaac --- .gitignore | 3 + examples/PROFILING_README.md | 167 +++++ examples/cleanup_column_tags.py | 81 +++ examples/credentials.env.example | 9 + examples/load_credentials.py | 31 + examples/plot_comparison.py | 217 ++++++ examples/profile_column_tags.py | 674 ++++++++++++++++++ .../profile_read_then_write_table_tags.py | 545 ++++++++++++++ examples/profile_read_then_write_tags.py | 546 ++++++++++++++ examples/profile_table_tags.py | 573 +++++++++++++++ examples/test_connection.py | 48 ++ src/databricks/sql/auth/retry.py | 32 +- src/databricks/sql/backend/thrift_backend.py | 40 +- 13 files changed, 2962 insertions(+), 4 deletions(-) create mode 100644 examples/PROFILING_README.md create mode 100644 examples/cleanup_column_tags.py create mode 100644 examples/credentials.env.example create mode 100644 examples/load_credentials.py create mode 100644 examples/plot_comparison.py create mode 100644 examples/profile_column_tags.py create mode 100644 examples/profile_read_then_write_table_tags.py create mode 100644 examples/profile_read_then_write_tags.py create mode 100644 examples/profile_table_tags.py create mode 100644 examples/test_connection.py diff --git a/.gitignore b/.gitignore index 
2ae38dbc6..9d02d5775 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ +# Profiling credentials +examples/credentials.env + # Created by https://www.toptal.com/developers/gitignore/api/python,macos # Edit at https://www.toptal.com/developers/gitignore?templates=python,macos diff --git a/examples/PROFILING_README.md b/examples/PROFILING_README.md new file mode 100644 index 000000000..15226c9e9 --- /dev/null +++ b/examples/PROFILING_README.md @@ -0,0 +1,167 @@ +# SET TAGS Profiling Scripts + +## Context + +These scripts measure the performance of two approaches for managing tags on Databricks tables and columns: + +**Approach A (direct ALTER)**: Call `ALTER TABLE SET TAGS` or `ALTER TABLE ALTER COLUMN SET TAGS` directly, overwriting any existing tags without reading them first. SET TAGS is idempotent — setting the same key overwrites the value. + +**Approach B (read then write)**: First query `system.information_schema.column_tags` or `system.information_schema.table_tags` to read existing tags, compute a diff, then issue ALTERs only for changes. + +The goal is to determine whether the information_schema read step is worth the cost, or whether direct ALTER is faster even though it may redundantly set unchanged tags. + +## Prerequisites + +- Python 3.x with `databricks-sql-connector` installed (`pip install -e .` from repo root) +- A Databricks SQL warehouse +- 64 tables (`table1` through `table64`) with 128 STRING columns each. Create them by running `profile_column_tags.py` without `--skip-setup`. +- Credentials in `examples/credentials.env` (gitignored). Copy and edit: + +```bash +cp examples/credentials.env.example examples/credentials.env +# Edit credentials.env with your workspace details +``` + +The file format is: +``` +SERVER_HOSTNAME=your-workspace.cloud.databricks.com +HTTP_PATH=/sql/1.0/warehouses/your_warehouse_id +ACCESS_TOKEN=your_token +CATALOG=your_catalog +SCHEMA=your_schema +``` + +All scripts read from this file via `load_credentials.py`. 
To switch workspaces, just edit `credentials.env`. + +## Scripts + +### profile_column_tags.py — Direct ALTER column tags + +Sets tags on columns directly via ALTER statements. No information_schema reads. + +```bash +# Create tables + validate +python examples/profile_column_tags.py --columns 1 --tags 1 --threads 1 --iterations 1 --validate + +# Full experiment: 100 columns, 9 tags each, 8 threads, 3 iterations +python examples/profile_column_tags.py --columns 100 --tags 9 --threads 8 --iterations 3 --skip-setup +``` + +Arguments: +- `--columns`: Number of columns to tag per table +- `--tags`: Number of tags per ALTER command +- `--threads`: Concurrent connections +- `--iterations`: Times to repeat the full 64-table sweep +- `--skip-setup`: Skip table creation +- `--validate`: Force 1 iteration for quick validation + +Output: `examples/results/column_tags/` + +### profile_table_tags.py — Direct ALTER table tags + +Sets tags on tables directly via ALTER statements. One ALTER per table. + +```bash +python examples/profile_table_tags.py --tags 1 --threads 8 --iterations 3 +``` + +Arguments: +- `--tags`: Number of tags per ALTER command +- `--threads`: Concurrent connections +- `--iterations`: Times to repeat the full 64-table sweep +- `--validate`: Force 1 iteration + +Output: `examples/results/table_tags/` + +### profile_read_then_write_tags.py — information_schema column_tags SELECT + +Queries `system.information_schema.column_tags` for each table. No ALTER — measures the read cost only. + +```bash +python examples/profile_read_then_write_tags.py --threads 1 --iterations 3 +``` + +Arguments: +- `--threads`: Concurrent connections +- `--iterations`: Times to repeat the full 64-table sweep +- `--validate`: Force 1 iteration + +Output: `examples/results/read_then_write/` + +### profile_read_then_write_table_tags.py — information_schema table_tags SELECT + +Queries `system.information_schema.table_tags` for each table. No ALTER — measures the read cost only. 
+ +```bash +python examples/profile_read_then_write_table_tags.py --threads 1 --iterations 3 +``` + +Arguments: +- `--threads`: Concurrent connections +- `--iterations`: Times to repeat the full 64-table sweep +- `--validate`: Force 1 iteration + +Output: `examples/results/read_then_write_table_tags/` + +### cleanup_column_tags.py — Remove all tags + +Removes all column tags and table tags from all 64 tables using 32 threads. Run this to reset state between experiments. + +```bash +python examples/cleanup_column_tags.py +``` + +### plot_comparison.py — Generate charts + +Reads all report files and generates comparison charts as PNGs. + +```bash +pip install matplotlib +python examples/plot_comparison.py +``` + +Output: +- `examples/results/comparison_column_tags.png` — column tags: ALTER vs info_schema +- `examples/results/comparison_table_tags.png` — table tags: ALTER vs info_schema + +Each PNG has 4 charts: wall-clock time, throughput, P50 latency, P99 latency, all plotted against thread count. 
+ ## Running the definitive experiment + ```bash +# Step 1: Create tables (once) +python examples/profile_column_tags.py --columns 1 --tags 1 --threads 1 --iterations 1 --validate + +# Step 2: Run info_schema reads across thread counts +# Stop early if latency is already unacceptable +for n in 1 2 4 8 16 32 64; do + python examples/profile_read_then_write_tags.py --threads $n --iterations 3 +done + +# Step 3: Run direct ALTERs across thread counts +for n in 1 2 4 8 16 32 64; do + python examples/profile_column_tags.py --columns 100 --tags 9 --threads $n --iterations 3 --skip-setup +done + +# Step 4: Generate charts +python examples/plot_comparison.py +``` + +## Connector instrumentation + +The scripts capture retry behavior via `[PROFILE]` log lines added to the connector: +- `src/databricks/sql/backend/thrift_backend.py` — logs per-attempt timing, success, statement IDs, and retry sleeps in `make_request()` +- `src/databricks/sql/auth/retry.py` — logs urllib3-level retry decisions (`should_retry`) and sleep durations with HTTP status codes, Thrift method names, and SQL text + +These are written to `*_retries.log` files alongside each report. Use `grep -F "[PROFILE]"` (fixed-string match; without `-F` the brackets are treated as a regex character class) to filter. + +## Output structure + +Each script run produces three files: +- `*_report.md` — Markdown report with latency percentiles, throughput, error analysis, retry analysis +- `*_data.jsonl` — Raw per-operation data (one JSON line per ALTER or SELECT) +- `*_retries.log` — Full connector debug logs with `[PROFILE]` instrumentation + +## Key finding + +On Azure workspaces, `system.information_schema.column_tags` queries can take 60-110 seconds under concurrency due to server-side queuing (visible as repeated `GetOperationStatus` polling in logs). Direct ALTER SET TAGS consistently completes in ~500ms regardless of concurrency. The information_schema read alone is slower than performing all the writes it was meant to optimize. 
\ No newline at end of file diff --git a/examples/cleanup_column_tags.py b/examples/cleanup_column_tags.py new file mode 100644 index 000000000..e816d0a40 --- /dev/null +++ b/examples/cleanup_column_tags.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +"""Remove all column tags and table tags from all 64 tables using 32 threads.""" + +import sys +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed + +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql +from load_credentials import load_credentials + +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] + +NUM_TABLES = 64 +NUM_THREADS = 32 + + +def cleanup_table(table_name): + table_fqn = f"`{CATALOG}`.`{SCHEMA}`.{table_name}" + total_removed = 0 + + with sql.connect( + server_hostname=SERVER_HOSTNAME, + http_path=HTTP_PATH, + access_token=ACCESS_TOKEN, + _tls_no_verify=True, + ) as conn: + with conn.cursor() as cursor: + # --- Clean up column tags --- + cursor.execute( + f"SELECT column_name, tag_name FROM system.information_schema.column_tags " + f"WHERE catalog_name = '{CATALOG}' AND schema_name = '{SCHEMA}' AND table_name = '{table_name}'" + ) + col_rows = cursor.fetchall() + + if col_rows: + col_tags = defaultdict(list) + for row in col_rows: + col_tags[row[0]].append(row[1]) + + for col, tags in col_tags.items(): + tag_list = ", ".join(f"'{tag}'" for tag in tags) + cursor.execute(f"ALTER TABLE {table_fqn} ALTER COLUMN {col} UNSET TAGS ({tag_list})") + + total_removed += len(col_rows) + + # --- Clean up table tags --- + cursor.execute( + f"SELECT tag_name FROM system.information_schema.table_tags " + f"WHERE catalog_name = '{CATALOG}' AND schema_name = '{SCHEMA}' AND table_name = '{table_name}'" + ) + tbl_rows = 
cursor.fetchall() + + if tbl_rows: + tag_list = ", ".join(f"'{row[0]}'" for row in tbl_rows) + cursor.execute(f"ALTER TABLE {table_fqn} UNSET TAGS ({tag_list})") + total_removed += len(tbl_rows) + + print(f"{table_name}: removed {len(col_rows)} column tags, {len(tbl_rows)} table tags") + return total_removed + + +total_removed = 0 +with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + futures = { + executor.submit(cleanup_table, f"table{t}"): t + for t in range(1, NUM_TABLES + 1) + } + for f in as_completed(futures): + total_removed += f.result() + +print(f"\nDone. Removed {total_removed} total tags (column + table).") diff --git a/examples/credentials.env.example b/examples/credentials.env.example new file mode 100644 index 000000000..406417236 --- /dev/null +++ b/examples/credentials.env.example @@ -0,0 +1,9 @@ +# SET TAGS Profiling — Workspace Credentials +# Copy this file to credentials.env and fill in your values. +# credentials.env is gitignored and will not be committed. + +SERVER_HOSTNAME=your-workspace.cloud.databricks.com +HTTP_PATH=/sql/1.0/warehouses/your_warehouse_id +ACCESS_TOKEN=your_access_token +CATALOG=your_catalog +SCHEMA=your_schema diff --git a/examples/load_credentials.py b/examples/load_credentials.py new file mode 100644 index 000000000..ef919cd8a --- /dev/null +++ b/examples/load_credentials.py @@ -0,0 +1,31 @@ +"""Load credentials from examples/credentials.env""" + +import os + + +def load_credentials(env_path=None): + """Read credentials.env and return a dict of key=value pairs.""" + if env_path is None: + env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "credentials.env") + + if not os.path.exists(env_path): + raise FileNotFoundError( + f"Credentials file not found: {env_path}\n" + f"Copy examples/credentials.env.example to examples/credentials.env and fill in your values." 
+ ) + + creds = {} + with open(env_path) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + key, _, value = line.partition("=") + creds[key.strip()] = value.strip() + + required = ["SERVER_HOSTNAME", "HTTP_PATH", "ACCESS_TOKEN", "CATALOG", "SCHEMA"] + missing = [k for k in required if k not in creds] + if missing: + raise ValueError(f"Missing required credentials: {', '.join(missing)}") + + return creds diff --git a/examples/plot_comparison.py b/examples/plot_comparison.py new file mode 100644 index 000000000..2f22cfd82 --- /dev/null +++ b/examples/plot_comparison.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +""" +Plot all profiling results: info_schema SELECTs vs direct ALTERs. + +Auto-discovers all report MD files across all result directories. +Generates SEPARATE PNGs for column tags and table tags. +Each has 4 charts: wall-clock, throughput, P50, P99. + +Usage: + python examples/plot_comparison.py +""" + +import re +import os +import sys +from collections import defaultdict + +sys.stdout.reconfigure(line_buffering=True) + +try: + import matplotlib.pyplot as plt +except ImportError: + print("Install matplotlib: pip install matplotlib") + sys.exit(1) + +RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results") + + +def parse_report(filepath): + """Extract key metrics from a report MD file.""" + metrics = {} + with open(filepath) as f: + content = f.read() + + m = re.search(r"\*\*Total wall-clock\*\*:\s*([\d.]+)s", content) + if m: + metrics["wall_clock_s"] = float(m.group(1)) + + m = re.search(r"\*\*(ALTERs/sec|SELECTs/sec|Operations/sec)\*\*:\s*([\d.]+)", content) + if m: + metrics["throughput"] = float(m.group(2)) + + for pct in ["p50", "p90", "p95", "p99"]: + m = re.search(rf"\|\s*{pct}\s*\|\s*([\d.]+)\s*\|", content) + if m: + metrics[pct] = float(m.group(1)) + + m = re.search(r"\|\s*max\s*\|\s*([\d.]+)\s*\|", content) + if m: + metrics["max"] = float(m.group(1)) + + m = 
re.search(r"\|\s*count\s*\|\s*([\d.]+)\s*\|", content) + if m: + metrics["count"] = int(float(m.group(1))) + + m = re.search(r"\*\*Threads\*\*:\s*(\d+)", content) + if m: + metrics["threads"] = int(m.group(1)) + + m = re.search(r"\*\*Iterations\*\*:\s*(\d+)", content) + if m: + metrics["iterations"] = int(m.group(1)) + + m = re.search(r"\*\*Columns tagged per table\*\*:\s*(\d+)", content) + if m: + metrics["columns"] = int(m.group(1)) + + m = re.search(r"\*\*Tags per ALTER\*\*:\s*(\d+)", content) + if m: + metrics["tags"] = int(m.group(1)) + + return metrics + + +def classify_report(dirpath, filename): + """Classify a report: (category, type) where category is 'column' or 'table'.""" + dirpath_lower = dirpath.lower() + filename_lower = filename.lower() + + if "read_then_write_table_tags" in dirpath_lower or filename_lower.startswith("rwtt_"): + return "table", "info_schema" + elif "read_then_write" in dirpath_lower or filename_lower.startswith("rw_"): + return "column", "info_schema" + elif "table_tags" in dirpath_lower or filename_lower.startswith("tt_"): + return "table", "alter" + elif "column_tags" in dirpath_lower or filename_lower.startswith("c"): + return "column", "alter" + else: + return "unknown", "unknown" + + +def discover_reports(): + """Walk all result directories and collect report data, split by category.""" + # {category: {series_label: {threads: metrics}}} + categories = defaultdict(lambda: defaultdict(dict)) + + for dirpath, _, filenames in os.walk(RESULTS_DIR): + for fname in sorted(filenames): + if not fname.endswith("_report.md"): + continue + + filepath = os.path.join(dirpath, fname) + metrics = parse_report(filepath) + + if "threads" not in metrics: + continue + + category, report_type = classify_report(dirpath, fname) + if category == "unknown": + continue + + threads = metrics["threads"] + + if report_type == "alter" and category == "column": + cols = metrics.get("columns", "?") + tags = metrics.get("tags", "?") + label = f"ALTER column 
tags (c={cols}, t={tags})" + elif report_type == "alter" and category == "table": + tags = metrics.get("tags", "?") + label = f"ALTER table tags (t={tags})" + elif report_type == "info_schema" and category == "column": + label = "info_schema column_tags SELECT" + elif report_type == "info_schema" and category == "table": + label = "info_schema table_tags SELECT" + else: + continue + + # Keep the one with more iterations + existing = categories[category][label].get(threads) + if existing and metrics.get("iterations", 0) <= existing.get("iterations", 0): + continue + + categories[category][label][threads] = metrics + print(f" [{category}] {label} threads={threads}: " + f"wall={metrics.get('wall_clock_s', '?')}s, " + f"p50={metrics.get('p50', '?')}ms, " + f"throughput={metrics.get('throughput', '?')} ops/s " + f"[{fname}]") + + return categories + + +def plot_category(category_name, series, output_path): + """Generate a 2x2 chart PNG for one category (column or table).""" + if not series: + print(f" No data for {category_name}, skipping.") + return + + # Color/style assignment + colors_info = ["#d62728", "#ff7f0e"] + colors_alter = ["#1f77b4", "#2ca02c", "#9467bd", "#17becf", "#8c564b"] + info_idx = 0 + alter_idx = 0 + style_map = {} + + for label in sorted(series.keys()): + if "info_schema" in label: + style_map[label] = {"color": colors_info[info_idx % len(colors_info)], "marker": "o", "linestyle": "--"} + info_idx += 1 + else: + style_map[label] = {"color": colors_alter[alter_idx % len(colors_alter)], "marker": "s", "linestyle": "-"} + alter_idx += 1 + + fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + + chart_configs = [ + (axes[0][0], "wall_clock_s", "Wall-Clock Time (seconds)", "Wall-Clock Time vs Thread Count"), + (axes[0][1], "throughput", "Operations / second", "Throughput vs Thread Count"), + (axes[1][0], "p50", "P50 Latency (ms)", "P50 Latency vs Thread Count"), + (axes[1][1], "p99", "P99 Latency (ms)", "P99 Latency vs Thread Count"), + ] + + for ax, 
metric_key, ylabel, title in chart_configs: + for label, thread_data in sorted(series.items()): + threads = sorted(thread_data.keys()) + values = [thread_data[t].get(metric_key) for t in threads] + if any(v is not None for v in values): + s = style_map[label] + ax.plot(threads, values, marker=s["marker"], linestyle=s["linestyle"], + color=s["color"], linewidth=2, label=label, markersize=8) + ax.set_xlabel("Thread Count") + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + title_label = "Column Tags" if category_name == "column" else "Table Tags" + plt.suptitle(f"SET TAGS Profiling: {title_label} — info_schema SELECT vs Direct ALTER", + fontsize=14, fontweight="bold") + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Chart saved to: {output_path}") + + +if __name__ == "__main__": + print("Discovering results...\n") + categories = discover_reports() + + total_series = sum(len(v) for v in categories.values()) + total_points = sum(len(td) for cat in categories.values() for td in cat.values()) + print(f"\nFound {total_series} series across {total_points} data points.\n") + + if "column" in categories: + print("Generating column tags chart...") + plot_category("column", categories["column"], + os.path.join(RESULTS_DIR, "comparison_column_tags.png")) + + if "table" in categories: + print("Generating table tags chart...") + plot_category("table", categories["table"], + os.path.join(RESULTS_DIR, "comparison_table_tags.png")) + + if not categories: + print("No results found. Run experiments first.") + else: + print("\nDrag the PNGs into Google Docs.") diff --git a/examples/profile_column_tags.py b/examples/profile_column_tags.py new file mode 100644 index 000000000..3f779f0ac --- /dev/null +++ b/examples/profile_column_tags.py @@ -0,0 +1,674 @@ +#!/usr/bin/env python3 +""" +Profile SET COLUMN TAGS performance on Databricks. 
+ +Usage: + # Quick validation (1 col x 1 tag x 1 thread x 1 iteration = 64 ALTERs, one per table) + python examples/profile_column_tags.py --columns 1 --tags 1 --threads 1 --iterations 1 --validate + + # Single experiment + python examples/profile_column_tags.py --columns 2 --tags 4 --threads 8 --iterations 10 + + # Full sweep + for c in 1 2 4; do + for t in 1 2 4; do + for n in 1 2 4 8 16; do + python examples/profile_column_tags.py --columns $c --tags $t --threads $n --iterations 10 + done + done + done +""" + +import argparse +import json +import logging +import os +import random +import re +import statistics +import string +import sys +import threading +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from queue import Empty, Queue + +# Force line-buffered stdout so output is visible when piped through grep +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql + +# ============================================================ +# CONFIGURATION — loaded from examples/credentials.env +# ============================================================ +from load_credentials import load_credentials +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] +# ============================================================ + +NUM_TABLES = 64 +MAX_COLUMNS = 128 # tables always created with this many columns +RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "column_tags") + + +# --------------------------------------------------------------------------- +# Logging setup +# --------------------------------------------------------------------------- + +class ProfileLogHandler(logging.Handler): + """Captures [PROFILE] log 
lines for retry analysis.""" + + def __init__(self): + super().__init__() + self.records: list = [] + + def emit(self, record): + msg = record.getMessage() + if "[PROFILE]" in msg: + self.records.append( + { + "timestamp": record.created, + "thread": record.threadName, + "message": msg, + } + ) + + +def setup_logging(log_path: str) -> ProfileLogHandler: + """Configure logging: file handler for all connector logs, profile handler for [PROFILE] lines.""" + profile_handler = ProfileLogHandler() + profile_handler.setLevel(logging.INFO) + + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s") + ) + + for logger_name in [ + "databricks.sql.backend.thrift_backend", + "databricks.sql.auth.retry", + "databricks.sql.client", + ]: + lgr = logging.getLogger(logger_name) + lgr.setLevel(logging.DEBUG) + lgr.addHandler(profile_handler) + lgr.addHandler(file_handler) + + return profile_handler + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def conn_params() -> dict: + return { + "server_hostname": SERVER_HOSTNAME, + "http_path": HTTP_PATH, + "access_token": ACCESS_TOKEN, + "_tls_no_verify": True, + } + + +def random_tag_value(length: int = 5) -> str: + return "".join(random.choices(string.ascii_lowercase, k=length)) + + +def build_alter_sql(table_fqn: str, column_name: str, num_tags: int) -> str: + tags = ", ".join(f"'key{i}' = '{random_tag_value()}'" for i in range(1, num_tags + 1)) + return f"ALTER TABLE {table_fqn} ALTER COLUMN {column_name} SET TAGS ({tags})" + + +def percentile(data: list, p: float) -> float: + """Return the p-th percentile (0-100) of data.""" + if not data: + return 0.0 + sorted_data = sorted(data) + k = (len(sorted_data) - 1) * (p / 100.0) + f = int(k) + c = f + 1 + if c >= 
len(sorted_data): + return sorted_data[f] + return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f]) + + +def latency_stats(latencies: list) -> dict: + """Compute full latency statistics for a list of ms values.""" + if not latencies: + return {k: 0.0 for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]} + return { + "count": len(latencies), + "min": min(latencies), + "max": max(latencies), + "mean": statistics.mean(latencies), + "stdev": statistics.stdev(latencies) if len(latencies) > 1 else 0.0, + "p50": percentile(latencies, 50), + "p90": percentile(latencies, 90), + "p95": percentile(latencies, 95), + "p99": percentile(latencies, 99), + } + + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- + +def setup_tables(): + """Create NUM_TABLES tables with MAX_COLUMNS STRING columns each.""" + print(f"Setting up {NUM_TABLES} tables with {MAX_COLUMNS} columns each...") + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + cursor.execute(f"USE CATALOG `{CATALOG}`") + cursor.execute(f"USE SCHEMA `{SCHEMA}`") + for t in range(1, NUM_TABLES + 1): + cols = ", ".join(f"column{c} STRING" for c in range(1, MAX_COLUMNS + 1)) + ddl = f"CREATE TABLE IF NOT EXISTS table{t} ({cols})" + print(f" Creating table{t}...", end=" ", flush=True) + cursor.execute(ddl) + print("done") + print("Setup complete.\n") + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + +def worker( + thread_id: int, + table_queue: Queue, + num_columns: int, + num_tags: int, + alter_results: list, + table_results: list, + results_lock: threading.Lock, +): + """Worker thread: pulls tables from queue, ALTERs all columns, records metrics.""" + local_alter_results = [] + local_table_results = [] + 
table_fqn_prefix = f"`{CATALOG}`.`{SCHEMA}`" + + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + # Warmup + cursor.execute("SELECT 1") + + while True: + try: + table_name = table_queue.get_nowait() + except Empty: + break + + table_fqn = f"{table_fqn_prefix}.{table_name}" + table_start = time.perf_counter() + table_errors = 0 + + for c in range(1, num_columns + 1): + column_name = f"column{c}" + alter_sql = build_alter_sql(table_fqn, column_name, num_tags) + + cmd_start = time.perf_counter() + success = True + error_type = None + error_message = None + error_context = None + + try: + cursor.execute(alter_sql) + except Exception as e: + success = False + error_type = type(e).__name__ + error_message = str(e)[:500] + error_context = getattr(e, "context", None) + table_errors += 1 + + cmd_end = time.perf_counter() + latency_ms = (cmd_end - cmd_start) * 1000 + + local_alter_results.append( + { + "table": table_name, + "column": column_name, + "thread_id": thread_id, + "latency_ms": round(latency_ms, 2), + "success": success, + "error_type": error_type, + "error_message": error_message, + "error_context": str(error_context) if error_context else None, + "timestamp": cmd_start, + } + ) + + table_end = time.perf_counter() + table_latency_ms = (table_end - table_start) * 1000 + + local_table_results.append( + { + "table": table_name, + "thread_id": thread_id, + "latency_ms": round(table_latency_ms, 2), + "num_alters": num_columns, + "num_errors": table_errors, + "alters_per_sec": round(num_columns / (table_latency_ms / 1000), 2) + if table_latency_ms > 0 + else 0, + } + ) + + with results_lock: + alter_results.extend(local_alter_results) + table_results.extend(local_table_results) + + +# --------------------------------------------------------------------------- +# Run one iteration +# --------------------------------------------------------------------------- + +def run_iteration( + iteration: int, + num_columns: int, + num_tags: 
int, + num_threads: int, +) -> tuple: + """Run a single iteration: distribute 20 tables across threads.""" + table_queue = Queue() + for t in range(1, NUM_TABLES + 1): + table_queue.put(f"table{t}") + + alter_results: list = [] + table_results: list = [] + results_lock = threading.Lock() + + iter_start = time.perf_counter() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + for tid in range(num_threads): + f = executor.submit( + worker, + tid, + table_queue, + num_columns, + num_tags, + alter_results, + table_results, + results_lock, + ) + futures.append(f) + + for f in as_completed(futures): + f.result() # raise any thread exceptions + + iter_end = time.perf_counter() + duration_s = iter_end - iter_start + + # Tag each result with iteration number + for r in alter_results: + r["iteration"] = iteration + for r in table_results: + r["iteration"] = iteration + + return alter_results, table_results, duration_s + + +# --------------------------------------------------------------------------- +# Report generation +# --------------------------------------------------------------------------- + +def generate_report( + args, + all_alter_results: list, + all_table_results: list, + iteration_durations: list, + profile_handler: ProfileLogHandler, + report_path: str, +): + """Generate the markdown report.""" + lines = [] + + def w(text=""): + lines.append(text) + + total_alters = len(all_alter_results) + total_duration = sum(iteration_durations) + successful = [r for r in all_alter_results if r["success"]] + failed = [r for r in all_alter_results if not r["success"]] + success_latencies = [r["latency_ms"] for r in successful] + all_latencies = [r["latency_ms"] for r in all_alter_results] + + # --- Header --- + w(f"# Profile: C={args.columns}, T={args.tags}, N={args.threads}, I={args.iterations}") + w() + w("## Configuration") + w(f"- **Server**: `{SERVER_HOSTNAME}`") + w(f"- **HTTP Path**: `{HTTP_PATH}`") + w(f"- **Catalog.Schema**: 
`{CATALOG}.{SCHEMA}`") + w(f"- **Tables**: {NUM_TABLES}") + w(f"- **Columns tagged per table**: {args.columns}") + w(f"- **Tags per ALTER**: {args.tags}") + w(f"- **Threads**: {args.threads}") + w(f"- **Iterations**: {args.iterations}") + w(f"- **Total ALTERs**: {total_alters}") + w(f"- **Date**: {datetime.now().isoformat()}") + w() + + # --- Overall ALTER Latency --- + w("## Per-ALTER Latency — All Iterations (ms)") + w() + stats = latency_stats(success_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k, v in stats.items(): + w(f"| {k} | {v:.2f} |") + w() + + # --- Throughput --- + w("## Throughput") + w() + w(f"- **Total ALTERs**: {total_alters}") + w(f"- **Successful**: {len(successful)}") + w(f"- **Failed**: {len(failed)}") + w(f"- **Total wall-clock**: {total_duration:.2f}s") + if total_duration > 0: + w(f"- **ALTERs/sec**: {total_alters / total_duration:.2f}") + w() + + # --- Cold Start vs Steady State --- + if args.iterations > 1: + w("## Cold Start vs Steady State") + w() + iter1 = [r["latency_ms"] for r in successful if r["iteration"] == 1] + iter_rest = [r["latency_ms"] for r in successful if r["iteration"] > 1] + w("| Phase | ALTERs | Mean (ms) | P50 (ms) | P99 (ms) |") + w("|-------|--------|-----------|----------|----------|") + s1 = latency_stats(iter1) + sr = latency_stats(iter_rest) + w(f"| Iteration 1 | {s1['count']:.0f} | {s1['mean']:.2f} | {s1['p50']:.2f} | {s1['p99']:.2f} |") + w(f"| Iterations 2-{args.iterations} | {sr['count']:.0f} | {sr['mean']:.2f} | {sr['p50']:.2f} | {sr['p99']:.2f} |") + w() + + # --- Per-Iteration Summary --- + w("## Per-Iteration Summary") + w() + w("| Iteration | ALTERs | Mean (ms) | P50 (ms) | P99 (ms) | Errors | Duration (s) | ALTERs/sec |") + w("|-----------|--------|-----------|----------|----------|--------|--------------|------------|") + for i in range(1, args.iterations + 1): + iter_lats = [r["latency_ms"] for r in successful if r["iteration"] == i] + iter_errs = len([r for r in failed if 
r["iteration"] == i]) + s = latency_stats(iter_lats) + dur = iteration_durations[i - 1] + alters_in_iter = len([r for r in all_alter_results if r["iteration"] == i]) + rps = alters_in_iter / dur if dur > 0 else 0 + w( + f"| {i} | {s['count']:.0f} | {s['mean']:.2f} | {s['p50']:.2f} | {s['p99']:.2f} " + f"| {iter_errs} | {dur:.2f} | {rps:.2f} |" + ) + w() + + # --- Per-ALTER Latency by Table --- + w("## Per-ALTER Latency by Table (ms)") + w() + w("| Table | Count | Min | Max | Mean | P50 | P90 | P95 | P99 |") + w("|-------|-------|-----|-----|------|-----|-----|-----|-----|") + tables_seen = sorted(set(r["table"] for r in successful), key=lambda x: int(x.replace("table", ""))) + for tbl in tables_seen: + tbl_lats = [r["latency_ms"] for r in successful if r["table"] == tbl] + s = latency_stats(tbl_lats) + w( + f"| {tbl} | {s['count']:.0f} | {s['min']:.2f} | {s['max']:.2f} " + f"| {s['mean']:.2f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p95']:.2f} | {s['p99']:.2f} |" + ) + w() + + # --- Per-ALTER Latency by Thread --- + w("## Per-ALTER Latency by Thread (ms)") + w() + w("| Thread | Count | Min | Max | Mean | P50 | P90 | P95 | P99 |") + w("|--------|-------|-----|-----|------|-----|-----|-----|-----|") + threads_seen = sorted(set(r["thread_id"] for r in successful)) + for tid in threads_seen: + thr_lats = [r["latency_ms"] for r in successful if r["thread_id"] == tid] + s = latency_stats(thr_lats) + w( + f"| {tid} | {s['count']:.0f} | {s['min']:.2f} | {s['max']:.2f} " + f"| {s['mean']:.2f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p95']:.2f} | {s['p99']:.2f} |" + ) + w() + + # --- Per-Table Latency (all columns in one table) --- + w("## Per-Table Latency — Time to Tag All Columns in One Table (ms)") + w() + w("| Table | Iteration | Thread | Latency (ms) | ALTERs/sec | Errors |") + w("|-------|-----------|--------|--------------|------------|--------|") + for r in sorted(all_table_results, key=lambda x: (x["iteration"], int(x["table"].replace("table", "")))): + w( + f"| 
{r['table']} | {r['iteration']} | {r['thread_id']} " + f"| {r['latency_ms']:.2f} | {r['alters_per_sec']:.2f} | {r['num_errors']} |" + ) + w() + + # --- Per-Table Aggregate Stats --- + w("## Per-Table Aggregate Stats (ms)") + w() + table_latencies = [r["latency_ms"] for r in all_table_results] + s = latency_stats(table_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k, v in s.items(): + w(f"| {k} | {v:.2f} |") + w() + + # --- Error Analysis --- + w("## Error Analysis") + w() + if not failed: + w("No errors encountered.") + else: + error_groups = defaultdict(list) + for r in failed: + error_groups[r["error_type"]].append(r) + w("| Error Type | Count | % of Total | Sample Message |") + w("|------------|-------|------------|----------------|") + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + pct = len(records) / total_alters * 100 + sample = records[0]["error_message"][:200] if records[0]["error_message"] else "N/A" + w(f"| {etype} | {len(records)} | {pct:.1f}% | {sample} |") + w() + + w("### Error Detail") + w() + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + w(f"**{etype}** ({len(records)} occurrences)") + w() + # Show up to 3 samples + for r in records[:3]: + w(f"- Table: {r['table']}, Column: {r['column']}, Iteration: {r['iteration']}") + w(f" Latency: {r['latency_ms']:.2f}ms") + w(f" Message: {r['error_message']}") + if r["error_context"]: + w(f" Context: {r['error_context']}") + if len(records) > 3: + w(f"- ... 
and {len(records) - 3} more") + w() + w() + + # --- Retry Analysis --- + ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) attempt (?P<attempt>\d+)/(?P<max>\d+)") + SUCCESS_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) succeeded on attempt (?P<attempt>\d+) in (?P<elapsed>[0-9.]+)s") + SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P<status>\d+), command=(?P<cmd>[^,]+),") + RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) retry sleep=(?P<sleep>[0-9.]+)s, attempt=(?P<attempt>\d+)/(?P<max>\d+)") + + parsed_events = [] + for r in profile_handler.records: + msg = r["message"] + for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE), + ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]: + m = regex.search(msg) + if m: + event = {"type": etype, "thread": r["thread"], "timestamp": r["timestamp"], "message": msg} + event.update(m.groupdict()) + if "attempt" in event: + event["attempt"] = int(event["attempt"]) + if "status" in event: + event["status"] = int(event["status"]) + parsed_events.append(event) + break + + # Filter to ExecuteStatement only (excludes OpenSession, CloseSession, GetOperationStatus) + exec_events = [e for e in parsed_events if e.get("cmd") == "ExecuteStatement"] + exec_retry_sleeps = [e for e in exec_events if e["type"] == "retry_sleep"] + exec_should_retry = [e for e in exec_events if e["type"] == "should_retry"] + exec_success_after_retry = [e for e in exec_events if e["type"] == "success" and e["attempt"] > 1] + exec_total_attempts = [e for e in exec_events if e["type"] == "attempt"] + exec_successes = [e for e in exec_events if e["type"] == "success"] + + w("## Statement Retry Analysis (ExecuteStatement only)") + w() + w("*Includes benchmarked ALTERs + one warmup SELECT 1 per worker thread.*") + w() + w(f"- **Total [PROFILE] events (all commands)**: {len(parsed_events)}") + w(f"- **ExecuteStatement attempts**: {len(exec_total_attempts)}") + w(f"- **ExecuteStatement successes**: {len(exec_successes)}") + w(f"- **ExecuteStatement retry sleeps**: {len(exec_retry_sleeps)}") + 
w(f"- **ExecuteStatement succeeded after retry (attempt > 1)**: {len(exec_success_after_retry)}") + w(f"- **should_retry evaluations**: {len(exec_should_retry)}") + w() + + if exec_retry_sleeps: + w("### Retry Events") + w() + w("| Timestamp | Thread | Attempt | Sleep (s) | Message |") + w("|-----------|--------|---------|-----------|---------|") + for e in exec_retry_sleeps[:50]: + ts = datetime.fromtimestamp(e["timestamp"]).strftime("%H:%M:%S.%f")[:-3] + w(f"| {ts} | {e['thread']} | {e['attempt']} | {e.get('sleep', '?')} | {e['message'][:150]} |") + if len(exec_retry_sleeps) > 50: + w(f"| ... | ... | ... | ... | {len(exec_retry_sleeps) - 50} more |") + w() + + if exec_should_retry: + w("### should_retry Decisions") + w() + status_counts = defaultdict(int) + for e in exec_should_retry: + status_counts[e["status"]] += 1 + w("| HTTP Status | Count |") + w("|-------------|-------|") + for status, count in sorted(status_counts.items()): + w(f"| {status} | {count} |") + w() + + # --- Footer --- + w("---") + w(f"*Generated by profile_column_tags.py on {datetime.now().isoformat()}*") + + report_text = "\n".join(lines) + with open(report_path, "w") as f: + f.write(report_text) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="Profile SET COLUMN TAGS performance") + parser.add_argument("--columns", type=int, required=True, help="Number of columns to tag per table (1, 2, 4)") + parser.add_argument("--tags", type=int, required=True, help="Number of tags per ALTER command (1, 2, 4)") + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads (1, 2, 4, 8, 16)") + parser.add_argument("--iterations", type=int, required=True, help="Number of times to repeat the full sweep") + parser.add_argument("--validate", action="store_true", help="Quick validation: override to 
1 iteration, print result") + parser.add_argument("--skip-setup", action="store_true", help="Skip table creation (tables already exist)") + args = parser.parse_args() + + if args.columns > MAX_COLUMNS: + print(f"Error: --columns {args.columns} exceeds MAX_COLUMNS={MAX_COLUMNS}") + sys.exit(1) + + if args.validate: + args.iterations = 1 + print("=== VALIDATION MODE: 1 iteration only ===\n") + + # File paths + os.makedirs(RESULTS_DIR, exist_ok=True) + prefix = f"c{args.columns}_t{args.tags}_n{args.threads}_i{args.iterations}" + report_path = os.path.join(RESULTS_DIR, f"{prefix}_report.md") + data_path = os.path.join(RESULTS_DIR, f"{prefix}_data.jsonl") + log_path = os.path.join(RESULTS_DIR, f"{prefix}_retries.log") + + # Logging + profile_handler = setup_logging(log_path) + + print(f"Profile: columns={args.columns}, tags={args.tags}, threads={args.threads}, iterations={args.iterations}") + print(f"ALTERs per iteration: {NUM_TABLES * args.columns}") + print(f"Total ALTERs: {NUM_TABLES * args.columns * args.iterations}") + print(f"Output: {report_path}") + print() + + # Setup + if not args.skip_setup: + setup_tables() + + # Run iterations + all_alter_results = [] + all_table_results = [] + iteration_durations = [] + + for i in range(1, args.iterations + 1): + print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) + alter_results, table_results, duration = run_iteration( + iteration=i, + num_columns=args.columns, + num_tags=args.tags, + num_threads=args.threads, + ) + all_alter_results.extend(alter_results) + all_table_results.extend(table_results) + iteration_durations.append(duration) + + alters_count = len(alter_results) + errors = len([r for r in alter_results if not r["success"]]) + rps = alters_count / duration if duration > 0 else 0 + print(f"done in {duration:.2f}s ({alters_count} ALTERs, {errors} errors, {rps:.1f} ALTERs/sec)") + + print() + + # Write raw data + with open(data_path, "w") as f: + for r in all_alter_results: + f.write(json.dumps(r) + 
"\n") + # separator + f.write("\n") + for r in all_table_results: + f.write(json.dumps(r) + "\n") + + # Generate report + generate_report(args, all_alter_results, all_table_results, iteration_durations, profile_handler, report_path) + + print(f"Report written to: {report_path}") + print(f"Raw data written to: {data_path}") + print(f"Retry log written to: {log_path}") + + # Print summary to stdout + success_lats = [r["latency_ms"] for r in all_alter_results if r["success"]] + if success_lats: + s = latency_stats(success_lats) + total_dur = sum(iteration_durations) + print() + print("=== Summary ===") + print(f" ALTERs: {len(all_alter_results)} ({len(success_lats)} ok, {len(all_alter_results) - len(success_lats)} failed)") + print(f" Latency: p50={s['p50']:.1f}ms p90={s['p90']:.1f}ms p95={s['p95']:.1f}ms p99={s['p99']:.1f}ms max={s['max']:.1f}ms") + print(f" Throughput: {len(all_alter_results) / total_dur:.1f} ALTERs/sec") + + +if __name__ == "__main__": + main() diff --git a/examples/profile_read_then_write_table_tags.py b/examples/profile_read_then_write_table_tags.py new file mode 100644 index 000000000..a8561f7a3 --- /dev/null +++ b/examples/profile_read_then_write_table_tags.py @@ -0,0 +1,545 @@ +#!/usr/bin/env python3 +""" +Profile information_schema.table_tags SELECT performance. + +For each table, this script SELECTs existing table tags from +system.information_schema.table_tags. No ALTER/write operations. 
+ +Usage: + python examples/profile_read_then_write_table_tags.py --threads 1 --iterations 1 --validate + python examples/profile_read_then_write_table_tags.py --threads 8 --iterations 10 +""" + +import argparse +import json +import logging +import os +import random +import re +import statistics +import string +import sys +import threading +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from queue import Empty, Queue + +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql + +# ============================================================ +# CONFIGURATION — loaded from examples/credentials.env +# ============================================================ +from load_credentials import load_credentials +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] +# ============================================================ + +NUM_TABLES = 64 +RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "read_then_write_table_tags") + +SELECT_TEMPLATE = """SELECT tag_name, tag_value +FROM system.information_schema.table_tags +WHERE catalog_name = '{catalog}' + AND schema_name = '{schema}' + AND table_name = '{table}'""" + + +# --------------------------------------------------------------------------- +# Logging setup +# --------------------------------------------------------------------------- + +class ProfileLogHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records: list = [] + + def emit(self, record): + msg = record.getMessage() + if "[PROFILE]" in msg: + self.records.append( + {"timestamp": record.created, "thread": record.threadName, "message": msg} + ) + + +def 
setup_logging(log_path: str) -> ProfileLogHandler: + profile_handler = ProfileLogHandler() + profile_handler.setLevel(logging.INFO) + + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s") + ) + + for logger_name in [ + "databricks.sql.backend.thrift_backend", + "databricks.sql.auth.retry", + "databricks.sql.client", + ]: + lgr = logging.getLogger(logger_name) + lgr.setLevel(logging.DEBUG) + lgr.addHandler(profile_handler) + lgr.addHandler(file_handler) + + return profile_handler + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def conn_params() -> dict: + return { + "server_hostname": SERVER_HOSTNAME, + "http_path": HTTP_PATH, + "access_token": ACCESS_TOKEN, + "_tls_no_verify": True, + } + + +def random_tag_value(length: int = 5) -> str: + return "".join(random.choices(string.ascii_lowercase, k=length)) + + +def percentile(data: list, p: float) -> float: + if not data: + return 0.0 + sorted_data = sorted(data) + k = (len(sorted_data) - 1) * (p / 100.0) + f = int(k) + c = f + 1 + if c >= len(sorted_data): + return sorted_data[f] + return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f]) + + +def latency_stats(latencies: list) -> dict: + if not latencies: + return {k: 0.0 for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]} + return { + "count": len(latencies), + "min": min(latencies), + "max": max(latencies), + "mean": statistics.mean(latencies), + "stdev": statistics.stdev(latencies) if len(latencies) > 1 else 0.0, + "p50": percentile(latencies, 50), + "p90": percentile(latencies, 90), + "p95": percentile(latencies, 95), + "p99": percentile(latencies, 99), + } + + +# --------------------------------------------------------------------------- +# Worker 
+# --------------------------------------------------------------------------- + +def worker( + thread_id: int, + table_queue: Queue, + results: list, + results_lock: threading.Lock, +): + local_results = [] + table_fqn_prefix = f"`{CATALOG}`.`{SCHEMA}`" + + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + # Warmup + cursor.execute("SELECT 1") + + while True: + try: + table_name = table_queue.get_nowait() + except Empty: + break + + table_fqn = f"{table_fqn_prefix}.{table_name}" + + # --- Step 1: Read table tags from information_schema --- + select_sql = SELECT_TEMPLATE.format( + catalog=CATALOG, schema=SCHEMA, table=table_name + ) + + op_start = time.perf_counter() + select_start = time.perf_counter() + select_success = True + select_error_type = None + select_error_message = None + select_rows = 0 + select_statement_id = None + + try: + cursor.execute(select_sql) + select_statement_id = str(cursor.active_command_id) if cursor.active_command_id else None + rows = cursor.fetchall() + select_rows = len(rows) + except Exception as e: + select_success = False + select_error_type = type(e).__name__ + select_error_message = str(e)[:500] + + select_end = time.perf_counter() + select_latency_ms = (select_end - select_start) * 1000 + + local_results.append( + { + "table": table_name, + "thread_id": thread_id, + "select_latency_ms": round(select_latency_ms, 2), + "select_success": select_success, + "select_error_type": select_error_type, + "select_error_message": select_error_message, + "select_rows": select_rows, + "select_statement_id": select_statement_id, + "timestamp": op_start, + } + ) + + with results_lock: + results.extend(local_results) + + +# --------------------------------------------------------------------------- +# Run one iteration +# --------------------------------------------------------------------------- + +def run_iteration(iteration: int, num_threads: int) -> tuple: + table_queue = Queue() + for t in range(1, 
NUM_TABLES + 1): + table_queue.put(f"table{t}") + + results: list = [] + results_lock = threading.Lock() + + iter_start = time.perf_counter() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(worker, tid, table_queue, results, results_lock) + for tid in range(num_threads) + ] + for f in as_completed(futures): + f.result() + + iter_end = time.perf_counter() + duration_s = iter_end - iter_start + + for r in results: + r["iteration"] = iteration + + return results, duration_s + + +# --------------------------------------------------------------------------- +# Report generation +# --------------------------------------------------------------------------- + +def generate_report( + args, + all_results: list, + iteration_durations: list, + profile_handler: ProfileLogHandler, + report_path: str, +): + lines = [] + + def w(text=""): + lines.append(text) + + total_ops = len(all_results) + total_duration = sum(iteration_durations) + + select_ok = [r for r in all_results if r["select_success"]] + select_fail = [r for r in all_results if not r["select_success"]] + select_latencies = [r["select_latency_ms"] for r in select_ok] + + # --- Header --- + w(f"# Information Schema Table Tags Profile: N={args.threads}, I={args.iterations}") + w() + w("## Configuration") + w(f"- **Server**: `{SERVER_HOSTNAME}`") + w(f"- **HTTP Path**: `{HTTP_PATH}`") + w(f"- **Catalog.Schema**: `{CATALOG}.{SCHEMA}`") + w(f"- **Tables**: {NUM_TABLES}") + w(f"- **Pattern**: SELECT from system.information_schema.table_tags per table") + w(f"- **Threads**: {args.threads}") + w(f"- **Iterations**: {args.iterations}") + w(f"- **Total SELECTs**: {total_ops}") + w(f"- **Date**: {datetime.now().isoformat()}") + w() + + # --- Latency --- + w("## SELECT Latency (ms)") + w() + ss = latency_stats(select_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]: + w(f"| {k} | 
{ss[k]:.2f} |") + w() + + # --- Throughput --- + w("## Throughput") + w() + w(f"- **Total SELECTs**: {total_ops}") + w(f"- **Successes**: {len(select_ok)} / {total_ops}") + w(f"- **Failures**: {len(select_fail)} / {total_ops}") + w(f"- **Total wall-clock**: {total_duration:.2f}s") + if total_duration > 0: + w(f"- **SELECTs/sec**: {total_ops / total_duration:.2f}") + w() + + # --- Cold Start vs Steady State --- + if args.iterations > 1: + w("## Cold Start vs Steady State") + w() + iter1 = [r["select_latency_ms"] for r in select_ok if r["iteration"] == 1] + iter_rest = [r["select_latency_ms"] for r in select_ok if r["iteration"] > 1] + s1 = latency_stats(iter1) + sr = latency_stats(iter_rest) + w("| Phase | Count | P50 (ms) | P90 (ms) | P99 (ms) |") + w("|-------|-------|----------|----------|----------|") + w(f"| Iteration 1 | {s1['count']:.0f} | {s1['p50']:.2f} | {s1['p90']:.2f} | {s1['p99']:.2f} |") + w(f"| Iterations 2-{args.iterations} | {sr['count']:.0f} | {sr['p50']:.2f} | {sr['p90']:.2f} | {sr['p99']:.2f} |") + w() + + # --- Per-Iteration Summary --- + w("## Per-Iteration Summary") + w() + w("| Iteration | SELECTs | P50 (ms) | P90 (ms) | P99 (ms) | Errors | Duration (s) | SELECTs/sec |") + w("|-----------|---------|----------|----------|----------|--------|--------------|-------------|") + for i in range(1, args.iterations + 1): + i_lats = [r["select_latency_ms"] for r in select_ok if r["iteration"] == i] + i_errs = len([r for r in all_results if r["iteration"] == i and not r["select_success"]]) + dur = iteration_durations[i - 1] + ops_in_iter = len([r for r in all_results if r["iteration"] == i]) + rps = ops_in_iter / dur if dur > 0 else 0 + s = latency_stats(i_lats) + w(f"| {i} | {s['count']:.0f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p99']:.2f} | {i_errs} | {dur:.2f} | {rps:.2f} |") + w() + + # --- All SELECTs with Statement IDs --- + w("## All SELECTs with Statement IDs") + w() + w("| Table | Iteration | Latency (ms) | Rows | Statement ID |") + 
w("|-------|-----------|-------------|------|--------------|") + sorted_results = sorted(all_results, key=lambda r: (r.get("iteration", 0), int(r["table"].replace("table", "")))) + for r in sorted_results: + w( + f"| {r['table']} | {r.get('iteration', '?')} " + f"| {r['select_latency_ms']:.2f} | {r['select_rows']} " + f"| {r.get('select_statement_id', 'N/A')} |" + ) + w() + + # --- By Thread --- + w("## Latency by Thread (ms)") + w() + w("| Thread | SELECTs | P50 | P90 | P99 |") + w("|--------|---------|-----|-----|-----|") + threads_seen = sorted(set(r["thread_id"] for r in select_ok)) + for tid in threads_seen: + t_lats = latency_stats([r["select_latency_ms"] for r in select_ok if r["thread_id"] == tid]) + w(f"| {tid} | {t_lats['count']:.0f} | {t_lats['p50']:.2f} | {t_lats['p90']:.2f} | {t_lats['p99']:.2f} |") + w() + + # --- Rows returned by SELECT --- + w("## Information Schema Rows Returned") + w() + row_counts = [r["select_rows"] for r in select_ok] + if row_counts: + w(f"- **Min rows**: {min(row_counts)}") + w(f"- **Max rows**: {max(row_counts)}") + w(f"- **Mean rows**: {statistics.mean(row_counts):.1f}") + w() + + # --- Error Analysis --- + w("## Error Analysis") + w() + all_errors = [] + for r in all_results: + if not r["select_success"]: + all_errors.append({"table": r["table"], "iteration": r.get("iteration", "?"), + "error_type": r["select_error_type"], "error_message": r["select_error_message"]}) + + if not all_errors: + w("No errors encountered.") + else: + error_groups = defaultdict(list) + for e in all_errors: + error_groups[e["error_type"]].append(e) + w("| Error Type | Count | % of Total | Sample Message |") + w("|------------|-------|------------|----------------|") + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + pct = len(records) / total_ops * 100 + sample = records[0]["error_message"][:200] if records[0]["error_message"] else "N/A" + w(f"| {etype} | {len(records)} | {pct:.1f}% | {sample} |") + w() + + w("### 
Error Detail") + w() + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + w(f"**{etype}** ({len(records)} occurrences)") + w() + for e in records[:3]: + w(f"- Table: {e['table']}, Iteration: {e['iteration']}") + w(f" Message: {e['error_message']}") + if len(records) > 3: + w(f"- ... and {len(records) - 3} more") + w() + w() + + # --- Retry Analysis --- + ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) attempt (?P<attempt>\d+)/(?P<max>\d+)") + SUCCESS_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) succeeded on attempt (?P<attempt>\d+) in (?P<elapsed>[0-9.]+)s") + SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P<status>\d+), command=(?P<cmd>[^,]+),") + RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) retry sleep=(?P<sleep>[0-9.]+)s, attempt=(?P<attempt>\d+)/(?P<max>\d+)") + + parsed_events = [] + for r in profile_handler.records: + msg = r["message"] + for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE), + ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]: + m = regex.search(msg) + if m: + event = {"type": etype, "thread": r["thread"], "timestamp": r["timestamp"], "message": msg} + event.update(m.groupdict()) + if "attempt" in event: + event["attempt"] = int(event["attempt"]) + if "status" in event: + event["status"] = int(event["status"]) + parsed_events.append(event) + break + + exec_events = [e for e in parsed_events if e.get("cmd") == "ExecuteStatement"] + exec_retry_sleeps = [e for e in exec_events if e["type"] == "retry_sleep"] + exec_should_retry = [e for e in exec_events if e["type"] == "should_retry"] + exec_success_after_retry = [e for e in exec_events if e["type"] == "success" and e["attempt"] > 1] + exec_total_attempts = [e for e in exec_events if e["type"] == "attempt"] + exec_successes = [e for e in exec_events if e["type"] == "success"] + + w("## Statement Retry Analysis (ExecuteStatement only)") + w() + w("*Includes information_schema SELECTs and one warmup SELECT 1 per thread.*") + w() + w(f"- **Total [PROFILE] events (all commands)**: 
{len(parsed_events)}") + w(f"- **ExecuteStatement attempts**: {len(exec_total_attempts)}") + w(f"- **ExecuteStatement successes**: {len(exec_successes)}") + w(f"- **ExecuteStatement retry sleeps**: {len(exec_retry_sleeps)}") + w(f"- **ExecuteStatement succeeded after retry (attempt > 1)**: {len(exec_success_after_retry)}") + w(f"- **should_retry evaluations**: {len(exec_should_retry)}") + w() + + if exec_retry_sleeps: + w("### Retry Events") + w() + w("| Timestamp | Thread | Attempt | Sleep (s) | Message |") + w("|-----------|--------|---------|-----------|---------|") + for e in exec_retry_sleeps[:50]: + ts = datetime.fromtimestamp(e["timestamp"]).strftime("%H:%M:%S.%f")[:-3] + w(f"| {ts} | {e['thread']} | {e['attempt']} | {e.get('sleep', '?')} | {e['message'][:150]} |") + if len(exec_retry_sleeps) > 50: + w(f"| ... | ... | ... | ... | {len(exec_retry_sleeps) - 50} more |") + w() + + if exec_should_retry: + w("### should_retry Decisions") + w() + status_counts = defaultdict(int) + for e in exec_should_retry: + status_counts[e["status"]] += 1 + w("| HTTP Status | Count |") + w("|-------------|-------|") + for status, count in sorted(status_counts.items()): + w(f"| {status} | {count} |") + w() + + # --- Footer --- + w("---") + w(f"*Generated by profile_read_then_write_table_tags.py on {datetime.now().isoformat()}*") + + report_text = "\n".join(lines) + with open(report_path, "w") as f: + f.write(report_text) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Profile read-from-information_schema then write-table-tag pattern" + ) + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") + parser.add_argument("--iterations", type=int, required=True, help="Number of times to repeat the full sweep") + parser.add_argument("--validate", 
action="store_true", help="Quick validation: override to 1 iteration") + args = parser.parse_args() + + if args.validate: + args.iterations = 1 + print("=== VALIDATION MODE: 1 iteration only ===\n") + + os.makedirs(RESULTS_DIR, exist_ok=True) + prefix = f"rwtt_n{args.threads}_i{args.iterations}" + report_path = os.path.join(RESULTS_DIR, f"{prefix}_report.md") + data_path = os.path.join(RESULTS_DIR, f"{prefix}_data.jsonl") + log_path = os.path.join(RESULTS_DIR, f"{prefix}_retries.log") + + profile_handler = setup_logging(log_path) + + print(f"Profile (information_schema.table_tags): threads={args.threads}, iterations={args.iterations}") + print(f"SELECTs per iteration: {NUM_TABLES} (1 per table)") + print(f"Total SELECTs: {NUM_TABLES * args.iterations}") + print(f"Output: {report_path}") + print() + + all_results = [] + iteration_durations = [] + + for i in range(1, args.iterations + 1): + print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) + results, duration = run_iteration(iteration=i, num_threads=args.threads) + all_results.extend(results) + iteration_durations.append(duration) + + errs = len([r for r in results if not r["select_success"]]) + rps = len(results) / duration if duration > 0 else 0 + print(f"done in {duration:.2f}s ({len(results)} SELECTs, {errs} errors, {rps:.1f} SELECTs/sec)") + + print() + + with open(data_path, "w") as f: + for r in all_results: + f.write(json.dumps(r) + "\n") + + generate_report(args, all_results, iteration_durations, profile_handler, report_path) + + print(f"Report written to: {report_path}") + print(f"Raw data written to: {data_path}") + print(f"Retry log written to: {log_path}") + + ok = [r for r in all_results if r["select_success"]] + if ok: + s = latency_stats([r["select_latency_ms"] for r in ok]) + total_dur = sum(iteration_durations) + print() + print("=== Summary ===") + print(f" SELECTs: {len(all_results)} ({len(ok)} ok, {len(all_results) - len(ok)} failed)") + print(f" Latency: p50={s['p50']:.1f}ms 
p90={s['p90']:.1f}ms p99={s['p99']:.1f}ms max={s['max']:.1f}ms") + print(f" Throughput: {len(all_results) / total_dur:.1f} SELECTs/sec") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/profile_read_then_write_tags.py b/examples/profile_read_then_write_tags.py new file mode 100644 index 000000000..a80f7a994 --- /dev/null +++ b/examples/profile_read_then_write_tags.py @@ -0,0 +1,546 @@ +#!/usr/bin/env python3 +""" +Profile information_schema.column_tags SELECT performance. + +For each table, this script SELECTs existing column tags from +system.information_schema.column_tags. No ALTER/write operations. + +Usage: + python examples/profile_read_then_write_tags.py --threads 1 --iterations 1 --validate + python examples/profile_read_then_write_tags.py --threads 8 --iterations 10 + python examples/profile_read_then_write_tags.py --threads 32 --iterations 10 +""" + +import argparse +import json +import logging +import os +import random +import re +import statistics +import string +import sys +import threading +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from queue import Empty, Queue + +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql + +# ============================================================ +# CONFIGURATION — loaded from examples/credentials.env +# ============================================================ +from load_credentials import load_credentials +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] +# ============================================================ + +NUM_TABLES = 64 +RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 
"results", "read_then_write") + +SELECT_TEMPLATE = """SELECT column_name, tag_name, tag_value +FROM system.information_schema.column_tags +WHERE catalog_name = '{catalog}' + AND schema_name = '{schema}' + AND table_name = '{table}'""" + + +# --------------------------------------------------------------------------- +# Logging setup +# --------------------------------------------------------------------------- + +class ProfileLogHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records: list = [] + + def emit(self, record): + msg = record.getMessage() + if "[PROFILE]" in msg: + self.records.append( + {"timestamp": record.created, "thread": record.threadName, "message": msg} + ) + + +def setup_logging(log_path: str) -> ProfileLogHandler: + profile_handler = ProfileLogHandler() + profile_handler.setLevel(logging.INFO) + + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s") + ) + + for logger_name in [ + "databricks.sql.backend.thrift_backend", + "databricks.sql.auth.retry", + "databricks.sql.client", + ]: + lgr = logging.getLogger(logger_name) + lgr.setLevel(logging.DEBUG) + lgr.addHandler(profile_handler) + lgr.addHandler(file_handler) + + return profile_handler + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def conn_params() -> dict: + return { + "server_hostname": SERVER_HOSTNAME, + "http_path": HTTP_PATH, + "access_token": ACCESS_TOKEN, + "_tls_no_verify": True, + } + + +def random_tag_value(length: int = 5) -> str: + return "".join(random.choices(string.ascii_lowercase, k=length)) + + +def percentile(data: list, p: float) -> float: + if not data: + return 0.0 + sorted_data = sorted(data) + k = (len(sorted_data) - 1) * (p / 100.0) + f = int(k) + c = f + 1 + 
if c >= len(sorted_data): + return sorted_data[f] + return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f]) + + +def latency_stats(latencies: list) -> dict: + if not latencies: + return {k: 0.0 for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]} + return { + "count": len(latencies), + "min": min(latencies), + "max": max(latencies), + "mean": statistics.mean(latencies), + "stdev": statistics.stdev(latencies) if len(latencies) > 1 else 0.0, + "p50": percentile(latencies, 50), + "p90": percentile(latencies, 90), + "p95": percentile(latencies, 95), + "p99": percentile(latencies, 99), + } + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + +def worker( + thread_id: int, + table_queue: Queue, + results: list, + results_lock: threading.Lock, +): + local_results = [] + table_fqn_prefix = f"`{CATALOG}`.`{SCHEMA}`" + + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + # Warmup + cursor.execute("SELECT 1") + + while True: + try: + table_name = table_queue.get_nowait() + except Empty: + break + + table_fqn = f"{table_fqn_prefix}.{table_name}" + + # --- Step 1: Read column tags from information_schema --- + select_sql = SELECT_TEMPLATE.format( + catalog=CATALOG, schema=SCHEMA, table=table_name + ) + + op_start = time.perf_counter() + select_start = time.perf_counter() + select_success = True + select_error_type = None + select_error_message = None + select_rows = 0 + select_statement_id = None + + try: + cursor.execute(select_sql) + select_statement_id = str(cursor.active_command_id) if cursor.active_command_id else None + rows = cursor.fetchall() + select_rows = len(rows) + except Exception as e: + select_success = False + select_error_type = type(e).__name__ + select_error_message = str(e)[:500] + + select_end = time.perf_counter() + select_latency_ms = (select_end - 
select_start) * 1000 + + local_results.append( + { + "table": table_name, + "thread_id": thread_id, + "select_latency_ms": round(select_latency_ms, 2), + "select_success": select_success, + "select_error_type": select_error_type, + "select_error_message": select_error_message, + "select_rows": select_rows, + "select_statement_id": select_statement_id, + "timestamp": op_start, + } + ) + + with results_lock: + results.extend(local_results) + + +# --------------------------------------------------------------------------- +# Run one iteration +# --------------------------------------------------------------------------- + +def run_iteration(iteration: int, num_threads: int) -> tuple: + table_queue = Queue() + for t in range(1, NUM_TABLES + 1): + table_queue.put(f"table{t}") + + results: list = [] + results_lock = threading.Lock() + + iter_start = time.perf_counter() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(worker, tid, table_queue, results, results_lock) + for tid in range(num_threads) + ] + for f in as_completed(futures): + f.result() + + iter_end = time.perf_counter() + duration_s = iter_end - iter_start + + for r in results: + r["iteration"] = iteration + + return results, duration_s + + +# --------------------------------------------------------------------------- +# Report generation +# --------------------------------------------------------------------------- + +def generate_report( + args, + all_results: list, + iteration_durations: list, + profile_handler: ProfileLogHandler, + report_path: str, +): + lines = [] + + def w(text=""): + lines.append(text) + + total_ops = len(all_results) + total_duration = sum(iteration_durations) + + select_ok = [r for r in all_results if r["select_success"]] + select_fail = [r for r in all_results if not r["select_success"]] + select_latencies = [r["select_latency_ms"] for r in select_ok] + + # --- Header --- + w(f"# Information Schema Column Tags Profile: 
N={args.threads}, I={args.iterations}") + w() + w("## Configuration") + w(f"- **Server**: `{SERVER_HOSTNAME}`") + w(f"- **HTTP Path**: `{HTTP_PATH}`") + w(f"- **Catalog.Schema**: `{CATALOG}.{SCHEMA}`") + w(f"- **Tables**: {NUM_TABLES}") + w(f"- **Pattern**: SELECT from system.information_schema.column_tags per table") + w(f"- **Threads**: {args.threads}") + w(f"- **Iterations**: {args.iterations}") + w(f"- **Total SELECTs**: {total_ops}") + w(f"- **Date**: {datetime.now().isoformat()}") + w() + + # --- Latency --- + w("## SELECT Latency (ms)") + w() + ss = latency_stats(select_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]: + w(f"| {k} | {ss[k]:.2f} |") + w() + + # --- Throughput --- + w("## Throughput") + w() + w(f"- **Total SELECTs**: {total_ops}") + w(f"- **Successes**: {len(select_ok)} / {total_ops}") + w(f"- **Failures**: {len(select_fail)} / {total_ops}") + w(f"- **Total wall-clock**: {total_duration:.2f}s") + if total_duration > 0: + w(f"- **SELECTs/sec**: {total_ops / total_duration:.2f}") + w() + + # --- Cold Start vs Steady State --- + if args.iterations > 1: + w("## Cold Start vs Steady State") + w() + iter1 = [r["select_latency_ms"] for r in select_ok if r["iteration"] == 1] + iter_rest = [r["select_latency_ms"] for r in select_ok if r["iteration"] > 1] + s1 = latency_stats(iter1) + sr = latency_stats(iter_rest) + w("| Phase | Count | P50 (ms) | P90 (ms) | P99 (ms) |") + w("|-------|-------|----------|----------|----------|") + w(f"| Iteration 1 | {s1['count']:.0f} | {s1['p50']:.2f} | {s1['p90']:.2f} | {s1['p99']:.2f} |") + w(f"| Iterations 2-{args.iterations} | {sr['count']:.0f} | {sr['p50']:.2f} | {sr['p90']:.2f} | {sr['p99']:.2f} |") + w() + + # --- Per-Iteration Summary --- + w("## Per-Iteration Summary") + w() + w("| Iteration | SELECTs | P50 (ms) | P90 (ms) | P99 (ms) | Errors | Duration (s) | SELECTs/sec |") + 
w("|-----------|---------|----------|----------|----------|--------|--------------|-------------|") + for i in range(1, args.iterations + 1): + i_lats = [r["select_latency_ms"] for r in select_ok if r["iteration"] == i] + i_errs = len([r for r in all_results if r["iteration"] == i and not r["select_success"]]) + dur = iteration_durations[i - 1] + ops_in_iter = len([r for r in all_results if r["iteration"] == i]) + rps = ops_in_iter / dur if dur > 0 else 0 + s = latency_stats(i_lats) + w(f"| {i} | {s['count']:.0f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p99']:.2f} | {i_errs} | {dur:.2f} | {rps:.2f} |") + w() + + # --- All Operations with Statement IDs --- + w("## All SELECTs with Statement IDs") + w() + w("| Table | Iteration | Latency (ms) | Rows | Statement ID |") + w("|-------|-----------|-------------|------|--------------|") + sorted_results = sorted(all_results, key=lambda r: (r.get("iteration", 0), int(r["table"].replace("table", "")))) + for r in sorted_results: + w( + f"| {r['table']} | {r.get('iteration', '?')} " + f"| {r['select_latency_ms']:.2f} | {r['select_rows']} " + f"| {r.get('select_statement_id', 'N/A')} |" + ) + w() + + # --- By Thread --- + w("## Latency by Thread (ms)") + w() + w("| Thread | SELECTs | P50 | P90 | P99 |") + w("|--------|---------|-----|-----|-----|") + threads_seen = sorted(set(r["thread_id"] for r in select_ok)) + for tid in threads_seen: + t_lats = latency_stats([r["select_latency_ms"] for r in select_ok if r["thread_id"] == tid]) + w(f"| {tid} | {t_lats['count']:.0f} | {t_lats['p50']:.2f} | {t_lats['p90']:.2f} | {t_lats['p99']:.2f} |") + w() + + # --- Rows returned by SELECT --- + w("## Information Schema Rows Returned") + w() + row_counts = [r["select_rows"] for r in select_ok] + if row_counts: + w(f"- **Min rows**: {min(row_counts)}") + w(f"- **Max rows**: {max(row_counts)}") + w(f"- **Mean rows**: {statistics.mean(row_counts):.1f}") + w() + + # --- Error Analysis --- + w("## Error Analysis") + w() + all_errors = [] + for r 
in all_results: + if not r["select_success"]: + all_errors.append({"table": r["table"], "iteration": r.get("iteration", "?"), + "error_type": r["select_error_type"], "error_message": r["select_error_message"]}) + + if not all_errors: + w("No errors encountered.") + else: + error_groups = defaultdict(list) + for e in all_errors: + error_groups[e["error_type"]].append(e) + w("| Error Type | Count | % of Total | Sample Message |") + w("|------------|-------|------------|----------------|") + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + pct = len(records) / total_ops * 100 + sample = records[0]["error_message"][:200] if records[0]["error_message"] else "N/A" + w(f"| {etype} | {len(records)} | {pct:.1f}% | {sample} |") + w() + + w("### Error Detail") + w() + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + w(f"**{etype}** ({len(records)} occurrences)") + w() + for e in records[:3]: + w(f"- Table: {e['table']}, Iteration: {e['iteration']}") + w(f" Message: {e['error_message']}") + if len(records) > 3: + w(f"- ... 
and {len(records) - 3} more") + w() + w() + + # --- Retry Analysis --- + ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P\w+) attempt (?P\d+)/(?P\d+)") + SUCCESS_RE = re.compile(r"\[PROFILE\] (?P\w+) succeeded on attempt (?P\d+) in (?P[0-9.]+)s") + SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P\d+), command=(?P[^,]+),") + RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P\w+) retry sleep=(?P[0-9.]+)s, attempt=(?P\d+)/(?P\d+)") + + parsed_events = [] + for r in profile_handler.records: + msg = r["message"] + for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE), + ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]: + m = regex.search(msg) + if m: + event = {"type": etype, "thread": r["thread"], "timestamp": r["timestamp"], "message": msg} + event.update(m.groupdict()) + if "attempt" in event: + event["attempt"] = int(event["attempt"]) + if "status" in event: + event["status"] = int(event["status"]) + parsed_events.append(event) + break + + exec_events = [e for e in parsed_events if e.get("cmd") == "ExecuteStatement"] + exec_retry_sleeps = [e for e in exec_events if e["type"] == "retry_sleep"] + exec_should_retry = [e for e in exec_events if e["type"] == "should_retry"] + exec_success_after_retry = [e for e in exec_events if e["type"] == "success" and e["attempt"] > 1] + exec_total_attempts = [e for e in exec_events if e["type"] == "attempt"] + exec_successes = [e for e in exec_events if e["type"] == "success"] + + w("## Statement Retry Analysis (ExecuteStatement only)") + w() + w("*Includes information_schema SELECTs and one warmup SELECT 1 per thread.*") + w() + w(f"- **Total [PROFILE] events (all commands)**: {len(parsed_events)}") + w(f"- **ExecuteStatement attempts**: {len(exec_total_attempts)}") + w(f"- **ExecuteStatement successes**: {len(exec_successes)}") + w(f"- **ExecuteStatement retry sleeps**: {len(exec_retry_sleeps)}") + w(f"- **ExecuteStatement succeeded after retry (attempt > 1)**: 
{len(exec_success_after_retry)}") + w(f"- **should_retry evaluations**: {len(exec_should_retry)}") + w() + + if exec_retry_sleeps: + w("### Retry Events") + w() + w("| Timestamp | Thread | Attempt | Sleep (s) | Message |") + w("|-----------|--------|---------|-----------|---------|") + for e in exec_retry_sleeps[:50]: + ts = datetime.fromtimestamp(e["timestamp"]).strftime("%H:%M:%S.%f")[:-3] + w(f"| {ts} | {e['thread']} | {e['attempt']} | {e.get('sleep', '?')} | {e['message'][:150]} |") + if len(exec_retry_sleeps) > 50: + w(f"| ... | ... | ... | ... | {len(exec_retry_sleeps) - 50} more |") + w() + + if exec_should_retry: + w("### should_retry Decisions") + w() + status_counts = defaultdict(int) + for e in exec_should_retry: + status_counts[e["status"]] += 1 + w("| HTTP Status | Count |") + w("|-------------|-------|") + for status, count in sorted(status_counts.items()): + w(f"| {status} | {count} |") + w() + + # --- Footer --- + w("---") + w(f"*Generated by profile_read_then_write_tags.py on {datetime.now().isoformat()}*") + + report_text = "\n".join(lines) + with open(report_path, "w") as f: + f.write(report_text) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Profile read-from-information_schema then write-column-tag pattern" + ) + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") + parser.add_argument("--iterations", type=int, required=True, help="Number of times to repeat the full sweep") + parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration") + args = parser.parse_args() + + if args.validate: + args.iterations = 1 + print("=== VALIDATION MODE: 1 iteration only ===\n") + + os.makedirs(RESULTS_DIR, exist_ok=True) + prefix = f"rw_n{args.threads}_i{args.iterations}" + report_path = 
os.path.join(RESULTS_DIR, f"{prefix}_report.md") + data_path = os.path.join(RESULTS_DIR, f"{prefix}_data.jsonl") + log_path = os.path.join(RESULTS_DIR, f"{prefix}_retries.log") + + profile_handler = setup_logging(log_path) + + print(f"Profile (information_schema.column_tags): threads={args.threads}, iterations={args.iterations}") + print(f"SELECTs per iteration: {NUM_TABLES} (1 per table)") + print(f"Total SELECTs: {NUM_TABLES * args.iterations}") + print(f"Output: {report_path}") + print() + + all_results = [] + iteration_durations = [] + + for i in range(1, args.iterations + 1): + print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) + results, duration = run_iteration(iteration=i, num_threads=args.threads) + all_results.extend(results) + iteration_durations.append(duration) + + errs = len([r for r in results if not r["select_success"]]) + rps = len(results) / duration if duration > 0 else 0 + print(f"done in {duration:.2f}s ({len(results)} SELECTs, {errs} errors, {rps:.1f} SELECTs/sec)") + + print() + + with open(data_path, "w") as f: + for r in all_results: + f.write(json.dumps(r) + "\n") + + generate_report(args, all_results, iteration_durations, profile_handler, report_path) + + print(f"Report written to: {report_path}") + print(f"Raw data written to: {data_path}") + print(f"Retry log written to: {log_path}") + + ok = [r for r in all_results if r["select_success"]] + if ok: + s = latency_stats([r["select_latency_ms"] for r in ok]) + total_dur = sum(iteration_durations) + print() + print("=== Summary ===") + print(f" SELECTs: {len(all_results)} ({len(ok)} ok, {len(all_results) - len(ok)} failed)") + print(f" Latency: p50={s['p50']:.1f}ms p90={s['p90']:.1f}ms p99={s['p99']:.1f}ms max={s['max']:.1f}ms") + print(f" Throughput: {len(all_results) / total_dur:.1f} SELECTs/sec") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/profile_table_tags.py b/examples/profile_table_tags.py new file mode 100644 index 
NUM_TABLES = 64
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "table_tags")


# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------

class ProfileLogHandler(logging.Handler):
    """Captures [PROFILE] log lines for retry analysis."""

    def __init__(self):
        super().__init__()
        # Each entry: {"timestamp": float, "thread": str, "message": str}
        self.records: list = []

    def emit(self, record):
        msg = record.getMessage()
        if "[PROFILE]" in msg:
            self.records.append(
                {
                    "timestamp": record.created,
                    "thread": record.threadName,
                    "message": msg,
                }
            )


def setup_logging(log_path: str) -> ProfileLogHandler:
    """Configure logging: file handler for all connector logs, profile handler for [PROFILE] lines."""
    profile_handler = ProfileLogHandler()
    profile_handler.setLevel(logging.INFO)

    file_handler = logging.FileHandler(log_path, mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s")
    )

    for logger_name in [
        "databricks.sql.backend.thrift_backend",
        "databricks.sql.auth.retry",
        "databricks.sql.client",
    ]:
        lgr = logging.getLogger(logger_name)
        lgr.setLevel(logging.DEBUG)
        lgr.addHandler(profile_handler)
        lgr.addHandler(file_handler)

    return profile_handler


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def conn_params() -> dict:
    """Connection kwargs for databricks.sql.connect().

    NOTE(review): _tls_no_verify=True disables TLS certificate verification;
    acceptable only for these throwaway profiling runs.
    """
    return {
        "server_hostname": SERVER_HOSTNAME,
        "http_path": HTTP_PATH,
        "access_token": ACCESS_TOKEN,
        "_tls_no_verify": True,
    }


def random_tag_value(length: int = 5) -> str:
    """Return a random lowercase ASCII string of *length* characters."""
    return "".join(random.choices(string.ascii_lowercase, k=length))


def build_table_tag_sql(table_fqn: str, num_tags: int) -> str:
    """Build an ALTER TABLE ... SET TAGS statement with *num_tags* random-valued
    tags keyed key1..keyN. *table_fqn* must already be backtick-quoted."""
    tags = ", ".join(f"'key{i}' = '{random_tag_value()}'" for i in range(1, num_tags + 1))
    return f"ALTER TABLE {table_fqn} SET TAGS ({tags})"


def percentile(data: list, p: float) -> float:
    """Return the p-th percentile (0-100) of data, interpolating linearly.

    Returns 0.0 for empty input so report generation never raises.
    """
    if not data:
        return 0.0
    sorted_data = sorted(data)
    k = (len(sorted_data) - 1) * (p / 100.0)
    f = int(k)
    c = f + 1
    if c >= len(sorted_data):
        return sorted_data[f]
    return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f])


def latency_stats(latencies: list) -> dict:
    """Compute full latency statistics for a list of ms values."""
    if not latencies:
        return {k: 0.0 for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]}
    return {
        "count": len(latencies),
        "min": min(latencies),
        "max": max(latencies),
        "mean": statistics.mean(latencies),
        "stdev": statistics.stdev(latencies) if len(latencies) > 1 else 0.0,
        "p50": percentile(latencies, 50),
        "p90": percentile(latencies, 90),
        "p95": percentile(latencies, 95),
        "p99": percentile(latencies, 99),
    }


# ---------------------------------------------------------------------------
# Worker
# ---------------------------------------------------------------------------

def worker(
    thread_id: int,
    table_queue: Queue,
    num_tags: int,
    alter_results: list,
    results_lock: threading.Lock,
):
    """Worker thread: pulls tables from queue, sets table-level tags, records metrics."""
    local_results = []
    table_fqn_prefix = f"`{CATALOG}`.`{SCHEMA}`"

    with sql.connect(**conn_params()) as connection:
        with connection.cursor() as cursor:
            # Warmup so session setup cost is not attributed to the first ALTER.
            cursor.execute("SELECT 1")

            while True:
                try:
                    table_name = table_queue.get_nowait()
                except Empty:
                    break

                table_fqn = f"{table_fqn_prefix}.{table_name}"
                alter_sql = build_table_tag_sql(table_fqn, num_tags)

                cmd_start = time.perf_counter()
                success = True
                error_type = None
                error_message = None
                error_context = None

                try:
                    cursor.execute(alter_sql)
                except Exception as e:
                    success = False
                    error_type = type(e).__name__
                    error_message = str(e)[:500]  # cap to keep JSONL rows small
                    error_context = getattr(e, "context", None)

                cmd_end = time.perf_counter()
                latency_ms = (cmd_end - cmd_start) * 1000

                local_results.append(
                    {
                        "table": table_name,
                        "thread_id": thread_id,
                        "latency_ms": round(latency_ms, 2),
                        "success": success,
                        "error_type": error_type,
                        "error_message": error_message,
                        "error_context": str(error_context) if error_context else None,
                        "timestamp": cmd_start,
                    }
                )

    with results_lock:
        alter_results.extend(local_results)


# ---------------------------------------------------------------------------
# Run one iteration
# ---------------------------------------------------------------------------

def run_iteration(
    iteration: int,
    num_tags: int,
    num_threads: int,
) -> tuple:
    """Run a single iteration: distribute 64 tables across threads."""
    table_queue = Queue()
    for t in range(1, NUM_TABLES + 1):
        table_queue.put(f"table{t}")

    alter_results: list = []
    results_lock = threading.Lock()

    iter_start = time.perf_counter()

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [
            executor.submit(worker, tid, table_queue, num_tags, alter_results, results_lock)
            for tid in range(num_threads)
        ]
        for f in as_completed(futures):
            f.result()  # raise any thread exceptions

    iter_end = time.perf_counter()
    duration_s = iter_end - iter_start

    for r in alter_results:
        r["iteration"] = iteration

    return alter_results, duration_s


# ---------------------------------------------------------------------------
# Report generation
# ---------------------------------------------------------------------------

# Patterns for mining the [PROFILE] lines captured from the connector logs.
# Compiled once at module level (they were previously rebuilt per report).
# NOTE(review): the group names were reconstructed from how the parsed events
# are consumed below (e.get("cmd"), event["attempt"], event["status"],
# e.get("sleep")) — confirm against the exact [PROFILE] format emitted by the
# instrumented thrift_backend/retry code.
ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) attempt (?P<attempt>\d+)/(?P<total>\d+)")
SUCCESS_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) succeeded on attempt (?P<attempt>\d+) in (?P<elapsed>[0-9.]+)s")
SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P<status>\d+), command=(?P<cmd>[^,]+),")
RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) retry sleep=(?P<sleep>[0-9.]+)s, attempt=(?P<attempt>\d+)/(?P<total>\d+)")


def generate_report(
    args,
    all_results: list,
    iteration_durations: list,
    profile_handler: ProfileLogHandler,
    report_path: str,
):
    """Generate the markdown report."""
    lines = []

    def w(text=""):
        lines.append(text)

    total_alters = len(all_results)
    total_duration = sum(iteration_durations)
    successful = [r for r in all_results if r["success"]]
    failed = [r for r in all_results if not r["success"]]
    success_latencies = [r["latency_ms"] for r in successful]

    # --- Header ---
    w(f"# Table Tags Profile: T={args.tags}, N={args.threads}, I={args.iterations}")
    w()
    w("## Configuration")
    w(f"- **Server**: `{SERVER_HOSTNAME}`")
    w(f"- **HTTP Path**: `{HTTP_PATH}`")
    w(f"- **Catalog.Schema**: `{CATALOG}.{SCHEMA}`")
    w(f"- **Tables**: {NUM_TABLES}")
    w(f"- **Tags per ALTER**: {args.tags}")
    w(f"- **Threads**: {args.threads}")
    w(f"- **Iterations**: {args.iterations}")
    w(f"- **Total ALTERs**: {total_alters}")
    w(f"- **Date**: {datetime.now().isoformat()}")
    w()

    # --- Overall ALTER Latency ---
    w("## Per-ALTER Latency — All Iterations (ms)")
    w()
    stats = latency_stats(success_latencies)
    w("| Metric | Value |")
    w("|--------|-------|")
    for k, v in stats.items():
        w(f"| {k} | {v:.2f} |")
    w()

    # --- Throughput ---
    w("## Throughput")
    w()
    w(f"- **Total ALTERs**: {total_alters}")
    w(f"- **Successful**: {len(successful)}")
    w(f"- **Failed**: {len(failed)}")
    w(f"- **Total wall-clock**: {total_duration:.2f}s")
    if total_duration > 0:
        w(f"- **ALTERs/sec**: {total_alters / total_duration:.2f}")
    w()

    # --- Cold Start vs Steady State ---
    if args.iterations > 1:
        w("## Cold Start vs Steady State")
        w()
        iter1 = [r["latency_ms"] for r in successful if r["iteration"] == 1]
        iter_rest = [r["latency_ms"] for r in successful if r["iteration"] > 1]
        w("| Phase | ALTERs | Mean (ms) | P50 (ms) | P99 (ms) |")
        w("|-------|--------|-----------|----------|----------|")
        s1 = latency_stats(iter1)
        sr = latency_stats(iter_rest)
        w(f"| Iteration 1 | {s1['count']:.0f} | {s1['mean']:.2f} | {s1['p50']:.2f} | {s1['p99']:.2f} |")
        w(f"| Iterations 2-{args.iterations} | {sr['count']:.0f} | {sr['mean']:.2f} | {sr['p50']:.2f} | {sr['p99']:.2f} |")
        w()

    # --- Per-Iteration Summary ---
    w("## Per-Iteration Summary")
    w()
    w("| Iteration | ALTERs | Mean (ms) | P50 (ms) | P99 (ms) | Errors | Duration (s) | ALTERs/sec |")
    w("|-----------|--------|-----------|----------|----------|--------|--------------|------------|")
    for i in range(1, args.iterations + 1):
        iter_lats = [r["latency_ms"] for r in successful if r["iteration"] == i]
        iter_errs = len([r for r in failed if r["iteration"] == i])
        s = latency_stats(iter_lats)
        dur = iteration_durations[i - 1]
        alters_in_iter = len([r for r in all_results if r["iteration"] == i])
        rps = alters_in_iter / dur if dur > 0 else 0
        w(
            f"| {i} | {s['count']:.0f} | {s['mean']:.2f} | {s['p50']:.2f} | {s['p99']:.2f} "
            f"| {iter_errs} | {dur:.2f} | {rps:.2f} |"
        )
    w()

    # --- Per-ALTER Latency by Table ---
    w("## Per-ALTER Latency by Table (ms)")
    w()
    w("| Table | Count | Min | Max | Mean | P50 | P90 | P95 | P99 |")
    w("|-------|-------|-----|-----|------|-----|-----|-----|-----|")
    # Sort numerically by table suffix so table10 sorts after table9.
    tables_seen = sorted(set(r["table"] for r in successful), key=lambda x: int(x.replace("table", "")))
    for tbl in tables_seen:
        tbl_lats = [r["latency_ms"] for r in successful if r["table"] == tbl]
        s = latency_stats(tbl_lats)
        w(
            f"| {tbl} | {s['count']:.0f} | {s['min']:.2f} | {s['max']:.2f} "
            f"| {s['mean']:.2f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p95']:.2f} | {s['p99']:.2f} |"
        )
    w()

    # --- Per-ALTER Latency by Thread ---
    w("## Per-ALTER Latency by Thread (ms)")
    w()
    w("| Thread | Count | Min | Max | Mean | P50 | P90 | P95 | P99 |")
    w("|--------|-------|-----|-----|------|-----|-----|-----|-----|")
    threads_seen = sorted(set(r["thread_id"] for r in successful))
    for tid in threads_seen:
        thr_lats = [r["latency_ms"] for r in successful if r["thread_id"] == tid]
        s = latency_stats(thr_lats)
        w(
            f"| {tid} | {s['count']:.0f} | {s['min']:.2f} | {s['max']:.2f} "
            f"| {s['mean']:.2f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p95']:.2f} | {s['p99']:.2f} |"
        )
    w()

    # --- Error Analysis ---
    w("## Error Analysis")
    w()
    if not failed:
        w("No errors encountered.")
    else:
        error_groups = defaultdict(list)
        for r in failed:
            error_groups[r["error_type"]].append(r)
        w("| Error Type | Count | % of Total | Sample Message |")
        w("|------------|-------|------------|----------------|")
        for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])):
            pct = len(records) / total_alters * 100
            sample = records[0]["error_message"][:200] if records[0]["error_message"] else "N/A"
            w(f"| {etype} | {len(records)} | {pct:.1f}% | {sample} |")
        w()

        w("### Error Detail")
        w()
        for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])):
            w(f"**{etype}** ({len(records)} occurrences)")
            w()
            for r in records[:3]:
                w(f"- Table: {r['table']}, Iteration: {r['iteration']}")
                w(f"  Latency: {r['latency_ms']:.2f}ms")
                w(f"  Message: {r['error_message']}")
                if r["error_context"]:
                    w(f"  Context: {r['error_context']}")
            if len(records) > 3:
                w(f"- ... and {len(records) - 3} more")
            w()
    w()

    # --- Retry Analysis ---
    parsed_events = []
    for r in profile_handler.records:
        msg = r["message"]
        for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE),
                             ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]:
            m = regex.search(msg)
            if m:
                event = {"type": etype, "thread": r["thread"], "timestamp": r["timestamp"], "message": msg}
                event.update(m.groupdict())
                if "attempt" in event:
                    event["attempt"] = int(event["attempt"])
                if "status" in event:
                    event["status"] = int(event["status"])
                parsed_events.append(event)
                break  # first matching pattern wins

    exec_events = [e for e in parsed_events if e.get("cmd") == "ExecuteStatement"]
    exec_retry_sleeps = [e for e in exec_events if e["type"] == "retry_sleep"]
    exec_should_retry = [e for e in exec_events if e["type"] == "should_retry"]
    exec_success_after_retry = [e for e in exec_events if e["type"] == "success" and e["attempt"] > 1]
    exec_total_attempts = [e for e in exec_events if e["type"] == "attempt"]
    exec_successes = [e for e in exec_events if e["type"] == "success"]

    w("## Statement Retry Analysis (ExecuteStatement only)")
    w()
    w("*Includes benchmarked ALTERs + one warmup SELECT 1 per worker thread.*")
    w()
    w(f"- **Total [PROFILE] events (all commands)**: {len(parsed_events)}")
    w(f"- **ExecuteStatement attempts**: {len(exec_total_attempts)}")
    w(f"- **ExecuteStatement successes**: {len(exec_successes)}")
    w(f"- **ExecuteStatement retry sleeps**: {len(exec_retry_sleeps)}")
    w(f"- **ExecuteStatement succeeded after retry (attempt > 1)**: {len(exec_success_after_retry)}")
    w(f"- **should_retry evaluations**: {len(exec_should_retry)}")
    w()

    if exec_retry_sleeps:
        w("### Retry Events")
        w()
        w("| Timestamp | Thread | Attempt | Sleep (s) | Message |")
        w("|-----------|--------|---------|-----------|---------|")
        for e in exec_retry_sleeps[:50]:
            ts = datetime.fromtimestamp(e["timestamp"]).strftime("%H:%M:%S.%f")[:-3]
            w(f"| {ts} | {e['thread']} | {e['attempt']} | {e.get('sleep', '?')} | {e['message'][:150]} |")
        if len(exec_retry_sleeps) > 50:
            w(f"| ... | ... | ... | ... | {len(exec_retry_sleeps) - 50} more |")
        w()

    if exec_should_retry:
        w("### should_retry Decisions")
        w()
        status_counts = defaultdict(int)
        for e in exec_should_retry:
            status_counts[e["status"]] += 1
        w("| HTTP Status | Count |")
        w("|-------------|-------|")
        for status, count in sorted(status_counts.items()):
            w(f"| {status} | {count} |")
        w()

    # --- Footer ---
    w("---")
    w(f"*Generated by profile_table_tags.py on {datetime.now().isoformat()}*")

    report_text = "\n".join(lines)
    with open(report_path, "w") as f:
        f.write(report_text)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(description="Profile SET TABLE TAGS performance")
    parser.add_argument("--tags", type=int, required=True, help="Number of tags per ALTER command (1, 2, 4)")
    parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads (1, 2, 4, 8, 16)")
    parser.add_argument("--iterations", type=int, required=True, help="Number of times to repeat the full sweep")
    parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration, print result")
    args = parser.parse_args()

    if args.validate:
        args.iterations = 1
        print("=== VALIDATION MODE: 1 iteration only ===\n")

    # File paths
    os.makedirs(RESULTS_DIR, exist_ok=True)
    prefix = f"tt_t{args.tags}_n{args.threads}_i{args.iterations}"
    report_path = os.path.join(RESULTS_DIR, f"{prefix}_report.md")
    data_path = os.path.join(RESULTS_DIR, f"{prefix}_data.jsonl")
    log_path = os.path.join(RESULTS_DIR, f"{prefix}_retries.log")

    # Logging
    profile_handler = setup_logging(log_path)

    print(f"Profile (TABLE TAGS): tags={args.tags}, threads={args.threads}, iterations={args.iterations}")
    print(f"ALTERs per iteration: {NUM_TABLES} (one per table)")
    print(f"Total ALTERs: {NUM_TABLES * args.iterations}")
    print(f"Output: {report_path}")
    print()

    # Run iterations
    all_results = []
    iteration_durations = []

    for i in range(1, args.iterations + 1):
        print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True)
        results, duration = run_iteration(
            iteration=i,
            num_tags=args.tags,
            num_threads=args.threads,
        )
        all_results.extend(results)
        iteration_durations.append(duration)

        errors = len([r for r in results if not r["success"]])
        rps = len(results) / duration if duration > 0 else 0
        print(f"done in {duration:.2f}s ({len(results)} ALTERs, {errors} errors, {rps:.1f} ALTERs/sec)")

    print()

    # Write raw data
    with open(data_path, "w") as f:
        for r in all_results:
            f.write(json.dumps(r) + "\n")

    # Generate report
    generate_report(args, all_results, iteration_durations, profile_handler, report_path)

    print(f"Report written to: {report_path}")
    print(f"Raw data written to: {data_path}")
    print(f"Retry log written to: {log_path}")

    # Print summary to stdout
    success_lats = [r["latency_ms"] for r in all_results if r["success"]]
    if success_lats:
        s = latency_stats(success_lats)
        total_dur = sum(iteration_durations)
        print()
        print("=== Summary ===")
        print(f"  ALTERs: {len(all_results)} ({len(success_lats)} ok, {len(all_results) - len(success_lats)} failed)")
        print(f"  Latency: p50={s['p50']:.1f}ms p90={s['p90']:.1f}ms p95={s['p95']:.1f}ms p99={s['p99']:.1f}ms max={s['max']:.1f}ms")
        print(f"  Throughput: {len(all_results) / total_dur:.1f} ALTERs/sec")


if __name__ == "__main__":
    main()
+logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(name)s %(levelname)s %(message)s") + +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] + +print("Connecting...") +t0 = time.time() + +with sql.connect( + server_hostname=SERVER_HOSTNAME, + http_path=HTTP_PATH, + access_token=ACCESS_TOKEN, + _tls_no_verify=True, +) as conn: + print(f"Connected in {time.time() - t0:.2f}s") + + with conn.cursor() as cursor: + cursor.execute(f"USE CATALOG {CATALOG}") + print(f"USE CATALOG done in {time.time() - t0:.2f}s") + + cursor.execute(f"USE SCHEMA {SCHEMA}") + print(f"USE SCHEMA done in {time.time() - t0:.2f}s") + + t1 = time.time() + cursor.execute("CREATE TABLE IF NOT EXISTS test_conn_check (col1 STRING, col2 STRING)") + print(f"CREATE TABLE done in {time.time() - t1:.2f}s") + + t1 = time.time() + cursor.execute("SELECT 1") + print(f"SELECT 1 done in {time.time() - t1:.2f}s") + + t1 = time.time() + cursor.execute("DROP TABLE IF EXISTS test_conn_check") + print(f"DROP TABLE done in {time.time() - t1:.2f}s") + +print(f"Total: {time.time() - t0:.2f}s") \ No newline at end of file diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index b0c2f497d..a8b921296 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -231,11 +231,15 @@ def new( # Include urllib3's current state in our __init__ params databricks_init_params["urllib3_kwargs"].update(**urllib3_init_params) # type: ignore[attr-defined] - return type(self).__private_init__( + new_instance = type(self).__private_init__( retry_start_time=self._retry_start_time, command_type=self.command_type, **databricks_init_params, ) + # Carry profiling state across retries + new_instance.thrift_method_name = getattr(self, "thrift_method_name", None) + new_instance.last_sql_statement = getattr(self, 
"last_sql_statement", None) + return new_instance @property def command_type(self) -> Optional[CommandType]: @@ -294,9 +298,16 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: else: proposed_wait = self.get_backoff_time() - proposed_wait = max(proposed_wait, self.delay_max) + proposed_wait = min(proposed_wait, self.delay_max) self.check_proposed_wait(proposed_wait) - logger.debug(f"Retrying after {proposed_wait} seconds") + logger.info( + "[PROFILE] urllib3_retry sleep=%.1fs, command=%s, method=%s, sql=%s, retry_after=%s", + proposed_wait, + self.command_type and self.command_type.value, + getattr(self, "thrift_method_name", "unknown"), + getattr(self, "last_sql_statement", None), + retry_after, + ) time.sleep(proposed_wait) return True @@ -358,6 +369,14 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: if status_code // 100 <= 3: return False, "2xx/3xx codes are not retried" + logger.info( + "[PROFILE] should_retry: status=%d, command=%s, method=%s, sql=%s, evaluating_retry", + status_code, + self.command_type and self.command_type.value, + getattr(self, "thrift_method_name", "unknown"), + getattr(self, "last_sql_statement", None), + ) + if status_code == 400: return ( False, @@ -416,6 +435,13 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: logger.debug( f"This request should be retried: {self.command_type and self.command_type.value}" ) + logger.info( + "[PROFILE] should_retry: status=%d, command=%s, method=%s, sql=%s, decision=True, reason=default_retry_policy", + status_code, + self.command_type and self.command_type.value, + getattr(self, "thrift_method_name", "unknown"), + getattr(self, "last_sql_statement", None), + ) return ( True, "Failed requests are retried by default per configured DatabricksRetryPolicy", diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e23f3389b..171ef4f0d 100644 --- 
a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -354,6 +354,16 @@ def _handle_request_error(self, error_info, attempt, elapsed): error_info.retry_delay, full_error_info_context ) ) + logger.info( + "[PROFILE] %s retry sleep=%.1fs, attempt=%d/%d, elapsed=%.1fs/%ds, http_code=%s", + error_info.method, + error_info.retry_delay, + attempt, + max_attempts, + elapsed, + max_duration_s, + error_info.http_code, + ) time.sleep(error_info.retry_delay) # FUTURE: Consider moving to https://github.com/litl/backoff or @@ -410,6 +420,14 @@ def attempt_request(attempt): logger.debug("Sending request: {}()".format(this_method_name)) unsafe_logger.debug("Sending request: {}".format(request)) + # Always set the method name and SQL text for profiling + if hasattr(self._transport, "retry_policy") and self._transport.retry_policy: + self._transport.retry_policy.thrift_method_name = this_method_name + sql_statement = getattr(request, "statement", None) + self._transport.retry_policy.last_sql_statement = ( + sql_statement[:200] if sql_statement else None + ) + # These three lines are no-ops if the v3 retry policy is not in use if self.enable_v3_retries: this_command_type = CommandType.get(this_method_name) @@ -506,6 +524,13 @@ def attempt_request(attempt): # use index-1 counting for logging/human consistency for attempt in range(1, max_attempts + 1): + logger.info( + "[PROFILE] %s attempt %d/%d (elapsed=%.3fs)", + getattr(method, "__name__", "unknown"), + attempt, + max_attempts, + get_elapsed(), + ) # We have a lock here because .cancel can be called from a separate thread. 
# We do not want threads to be simultaneously sharing the Thrift Transport # because we use its state to determine retries @@ -515,7 +540,12 @@ def attempt_request(attempt): # conditions: success, non-retry-able, no-attempts-left, no-time-left, delay+retry if not isinstance(response_or_error_info, RequestErrorInfo): - # log nothing here, presume that main request logging covers + logger.info( + "[PROFILE] %s succeeded on attempt %d in %.3fs", + getattr(method, "__name__", "unknown"), + attempt, + elapsed, + ) response = response_or_error_info ThriftDatabricksClient._check_response_for_error(response, self._host) return response @@ -1059,6 +1089,14 @@ def execute_command( ) resp = self.make_request(self._client.ExecuteStatement, req) + if resp.operationHandle: + _cmd_id = CommandId.from_thrift_handle(resp.operationHandle) + logger.info( + "[PROFILE] ExecuteStatement statement_id=%s, sql=%s", + _cmd_id, + operation[:200] if operation else None, + ) + if async_op: self._handle_execute_response_async(resp, cursor) return None From 53a3088bcf07cbe30395488d8c9b9d7a4b8c6ccd Mon Sep 17 00:00:00 2001 From: "tejas.sp" <241722411+tejassp-db@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:36:17 +0530 Subject: [PATCH 2/3] Disable telemetry in all profiling scripts Pass enable_telemetry=False to sql.connect() to avoid telemetry HTTP calls during profiling runs. 
Co-authored-by: Isaac --- examples/cleanup_column_tags.py | 1 + examples/profile_column_tags.py | 1 + examples/profile_read_then_write_table_tags.py | 1 + examples/profile_read_then_write_tags.py | 1 + examples/profile_table_tags.py | 1 + examples/test_connection.py | 1 + 6 files changed, 6 insertions(+) diff --git a/examples/cleanup_column_tags.py b/examples/cleanup_column_tags.py index e816d0a40..5fadd97df 100644 --- a/examples/cleanup_column_tags.py +++ b/examples/cleanup_column_tags.py @@ -33,6 +33,7 @@ def cleanup_table(table_name): http_path=HTTP_PATH, access_token=ACCESS_TOKEN, _tls_no_verify=True, + enable_telemetry=False, ) as conn: with conn.cursor() as cursor: # --- Clean up column tags --- diff --git a/examples/profile_column_tags.py b/examples/profile_column_tags.py index 3f779f0ac..e377be8fd 100644 --- a/examples/profile_column_tags.py +++ b/examples/profile_column_tags.py @@ -117,6 +117,7 @@ def conn_params() -> dict: "http_path": HTTP_PATH, "access_token": ACCESS_TOKEN, "_tls_no_verify": True, + "enable_telemetry": False, } diff --git a/examples/profile_read_then_write_table_tags.py b/examples/profile_read_then_write_table_tags.py index a8561f7a3..8ecec9183 100644 --- a/examples/profile_read_then_write_table_tags.py +++ b/examples/profile_read_then_write_table_tags.py @@ -105,6 +105,7 @@ def conn_params() -> dict: "http_path": HTTP_PATH, "access_token": ACCESS_TOKEN, "_tls_no_verify": True, + "enable_telemetry": False, } diff --git a/examples/profile_read_then_write_tags.py b/examples/profile_read_then_write_tags.py index a80f7a994..39b584f02 100644 --- a/examples/profile_read_then_write_tags.py +++ b/examples/profile_read_then_write_tags.py @@ -106,6 +106,7 @@ def conn_params() -> dict: "http_path": HTTP_PATH, "access_token": ACCESS_TOKEN, "_tls_no_verify": True, + "enable_telemetry": False, } diff --git a/examples/profile_table_tags.py b/examples/profile_table_tags.py index a1ee5dfa6..6113dda54 100644 --- a/examples/profile_table_tags.py +++ 
b/examples/profile_table_tags.py @@ -116,6 +116,7 @@ def conn_params() -> dict: "http_path": HTTP_PATH, "access_token": ACCESS_TOKEN, "_tls_no_verify": True, + "enable_telemetry": False, } diff --git a/examples/test_connection.py b/examples/test_connection.py index 67adb49cc..0332a2980 100644 --- a/examples/test_connection.py +++ b/examples/test_connection.py @@ -23,6 +23,7 @@ http_path=HTTP_PATH, access_token=ACCESS_TOKEN, _tls_no_verify=True, + enable_telemetry=False, ) as conn: print(f"Connected in {time.time() - t0:.2f}s") From d0dcfd223daa3f145b2a7047bfa457a6a7194afc Mon Sep 17 00:00:00 2001 From: "tejas.sp" <241722411+tejassp-db@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:30:44 +0530 Subject: [PATCH 3/3] Add --tables-per-iteration arg, separate chart PNGs, disable telemetry - All 4 profiling scripts now accept --tables-per-iteration (defaults to --threads, i.e. 1 table per thread per iteration). - NUM_TABLES bumped to 128. - plot_comparison.py generates separate PNGs for table-level comparison (wall-clock, tables/sec) and individual operation detail (ops/sec, P50, P99, max). Column tags and table tags get their own PNGs. - Chart labels spell out parameters (columns, tags_per_column, tables). - All scripts pass enable_telemetry=False to sql.connect(). 
Co-authored-by: Isaac --- examples/plot_comparison.py | 112 +++++++++++++----- examples/profile_column_tags.py | 36 ++++-- .../profile_read_then_write_table_tags.py | 28 +++-- examples/profile_read_then_write_tags.py | 28 +++-- examples/profile_table_tags.py | 32 +++-- 5 files changed, 166 insertions(+), 70 deletions(-) diff --git a/examples/plot_comparison.py b/examples/plot_comparison.py index 2f22cfd82..a57853da2 100644 --- a/examples/plot_comparison.py +++ b/examples/plot_comparison.py @@ -38,7 +38,7 @@ def parse_report(filepath): m = re.search(r"\*\*(ALTERs/sec|SELECTs/sec|Operations/sec)\*\*:\s*([\d.]+)", content) if m: - metrics["throughput"] = float(m.group(2)) + metrics["throughput_ops"] = float(m.group(2)) for pct in ["p50", "p90", "p95", "p99"]: m = re.search(rf"\|\s*{pct}\s*\|\s*([\d.]+)\s*\|", content) @@ -65,6 +65,17 @@ def parse_report(filepath): if m: metrics["columns"] = int(m.group(1)) + m = re.search(r"\*\*Tables per iteration\*\*:\s*(\d+)", content) + if m: + metrics["tables_per_iteration"] = int(m.group(1)) + + # Also match older reports that used "Tables": N + if "tables_per_iteration" not in metrics: + m = re.search(r"\*\*Total SELECTs\*\*:\s*(\d+)", content) + iters = metrics.get("iterations", 1) + if m and iters: + metrics["tables_per_iteration"] = int(float(m.group(1))) // iters + m = re.search(r"\*\*Tags per ALTER\*\*:\s*(\d+)", content) if m: metrics["tags"] = int(m.group(1)) @@ -111,17 +122,19 @@ def discover_reports(): threads = metrics["threads"] + tbl = metrics.get("tables_per_iteration", "?") + if report_type == "alter" and category == "column": cols = metrics.get("columns", "?") tags = metrics.get("tags", "?") - label = f"ALTER column tags (c={cols}, t={tags})" + label = f"ALTER column tags (columns={cols}, tags_per_column={tags}, tables={tbl})" elif report_type == "alter" and category == "table": tags = metrics.get("tags", "?") - label = f"ALTER table tags (t={tags})" + label = f"ALTER table tags (tags={tags}, tables={tbl})" elif 
report_type == "info_schema" and category == "column": - label = "info_schema column_tags SELECT" + label = f"info_schema column_tags SELECT (tables={tbl})" elif report_type == "info_schema" and category == "table": - label = "info_schema table_tags SELECT" + label = f"info_schema table_tags SELECT (tables={tbl})" else: continue @@ -130,23 +143,24 @@ def discover_reports(): if existing and metrics.get("iterations", 0) <= existing.get("iterations", 0): continue + # Compute tables/sec from wall-clock and tables_per_iteration + tpi = metrics.get("tables_per_iteration") + wc = metrics.get("wall_clock_s") + if tpi and wc and wc > 0: + metrics["tables_per_sec"] = round(tpi / wc, 2) + categories[category][label][threads] = metrics print(f" [{category}] {label} threads={threads}: " f"wall={metrics.get('wall_clock_s', '?')}s, " f"p50={metrics.get('p50', '?')}ms, " - f"throughput={metrics.get('throughput', '?')} ops/s " + f"tables/s={metrics.get('tables_per_sec', '?')} " f"[{fname}]") return categories -def plot_category(category_name, series, output_path): - """Generate a 2x2 chart PNG for one category (column or table).""" - if not series: - print(f" No data for {category_name}, skipping.") - return - - # Color/style assignment +def build_style_map(series): + """Assign colors and styles to series labels.""" colors_info = ["#d62728", "#ff7f0e"] colors_alter = ["#1f77b4", "#2ca02c", "#9467bd", "#17becf", "#8c564b"] info_idx = 0 @@ -161,16 +175,20 @@ def plot_category(category_name, series, output_path): style_map[label] = {"color": colors_alter[alter_idx % len(colors_alter)], "marker": "s", "linestyle": "-"} alter_idx += 1 - fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + return style_map - chart_configs = [ - (axes[0][0], "wall_clock_s", "Wall-Clock Time (seconds)", "Wall-Clock Time vs Thread Count"), - (axes[0][1], "throughput", "Operations / second", "Throughput vs Thread Count"), - (axes[1][0], "p50", "P50 Latency (ms)", "P50 Latency vs Thread Count"), - (axes[1][1], 
"p99", "P99 Latency (ms)", "P99 Latency vs Thread Count"), - ] - for ax, metric_key, ylabel, title in chart_configs: +def plot_charts(series, style_map, chart_configs, suptitle, output_path): + """Generate a chart PNG with len(chart_configs) subplots.""" + n = len(chart_configs) + cols = 2 + rows = (n + 1) // 2 + fig, axes = plt.subplots(rows, cols, figsize=(16, 6 * rows)) + if rows == 1: + axes = [axes] + + for idx, (metric_key, ylabel, title) in enumerate(chart_configs): + ax = axes[idx // cols][idx % cols] for label, thread_data in sorted(series.items()): threads = sorted(thread_data.keys()) values = [thread_data[t].get(metric_key) for t in threads] @@ -184,15 +202,51 @@ def plot_category(category_name, series, output_path): ax.legend(fontsize=8) ax.grid(True, alpha=0.3) - title_label = "Column Tags" if category_name == "column" else "Table Tags" - plt.suptitle(f"SET TAGS Profiling: {title_label} — info_schema SELECT vs Direct ALTER", - fontsize=14, fontweight="bold") + # Hide unused subplot if odd number of charts + if n % 2 == 1: + axes[rows - 1][1].set_visible(False) + + plt.suptitle(suptitle, fontsize=14, fontweight="bold") plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) print(f" Chart saved to: {output_path}") +def plot_category(category_name, series, output_dir): + """Generate two PNGs per category: table-level comparison + individual operation detail.""" + if not series: + print(f" No data for {category_name}, skipping.") + return + + style_map = build_style_map(series) + title_label = "Column Tags" if category_name == "column" else "Table Tags" + + # Chart 1: Table-level comparison (apples-to-apples across approaches) + table_charts = [ + ("wall_clock_s", "Wall-Clock Time (seconds)", "Wall-Clock Time vs Thread Count (Lower is better)"), + ("tables_per_sec", "Tables / second", "Tables Processed per Second vs Thread Count (Higher is better)"), + ] + plot_charts( + series, style_map, table_charts, + f"{title_label}: 
Table-Level Comparison — info_schema SELECT vs Direct ALTER", + os.path.join(output_dir, f"comparison_{category_name}_tags_tables.png"), + ) + + # Chart 2: Individual operation detail (per-op latency) + op_charts = [ + ("throughput_ops", "Individual Operations / second", "Individual Op Throughput vs Thread Count (Higher is better)"), + ("p50", "P50 Latency per Op (ms)", "P50 Latency vs Thread Count (Lower is better)"), + ("p99", "P99 Latency per Op (ms)", "P99 Latency vs Thread Count (Lower is better)"), + ("max", "Max Latency per Op (ms)", "Max Latency vs Thread Count (Lower is better)"), + ] + plot_charts( + series, style_map, op_charts, + f"{title_label}: Individual Operation Detail", + os.path.join(output_dir, f"comparison_{category_name}_tags_ops.png"), + ) + + if __name__ == "__main__": print("Discovering results...\n") categories = discover_reports() @@ -202,14 +256,12 @@ def plot_category(category_name, series, output_path): print(f"\nFound {total_series} series across {total_points} data points.\n") if "column" in categories: - print("Generating column tags chart...") - plot_category("column", categories["column"], - os.path.join(RESULTS_DIR, "comparison_column_tags.png")) + print("Generating column tags charts...") + plot_category("column", categories["column"], RESULTS_DIR) if "table" in categories: - print("Generating table tags chart...") - plot_category("table", categories["table"], - os.path.join(RESULTS_DIR, "comparison_table_tags.png")) + print("Generating table tags charts...") + plot_category("table", categories["table"], RESULTS_DIR) if not categories: print("No results found. 
Run experiments first.") diff --git a/examples/profile_column_tags.py b/examples/profile_column_tags.py index e377be8fd..49a19ebc4 100644 --- a/examples/profile_column_tags.py +++ b/examples/profile_column_tags.py @@ -55,7 +55,7 @@ SCHEMA = _creds["SCHEMA"] # ============================================================ -NUM_TABLES = 64 +NUM_TABLES = 128 # total tables available (table1..table128) MAX_COLUMNS = 128 # tables always created with this many columns RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "column_tags") @@ -279,11 +279,14 @@ def run_iteration( num_columns: int, num_tags: int, num_threads: int, + tables_per_iteration: int, ) -> tuple: - """Run a single iteration: distribute 20 tables across threads.""" + """Run a single iteration: tables_per_iteration tables distributed across num_threads threads.""" table_queue = Queue() - for t in range(1, NUM_TABLES + 1): - table_queue.put(f"table{t}") + start = ((iteration - 1) * tables_per_iteration) % NUM_TABLES + for i in range(tables_per_iteration): + table_idx = start + i + 1 + table_queue.put(f"table{table_idx}") alter_results: list = [] table_results: list = [] @@ -353,7 +356,7 @@ def w(text=""): w(f"- **Server**: `{SERVER_HOSTNAME}`") w(f"- **HTTP Path**: `{HTTP_PATH}`") w(f"- **Catalog.Schema**: `{CATALOG}.{SCHEMA}`") - w(f"- **Tables**: {NUM_TABLES}") + w(f"- **Tables per iteration**: {args.tables_per_iteration}") w(f"- **Columns tagged per table**: {args.columns}") w(f"- **Tags per ALTER**: {args.tags}") w(f"- **Threads**: {args.threads}") @@ -583,18 +586,26 @@ def w(text=""): def main(): parser = argparse.ArgumentParser(description="Profile SET COLUMN TAGS performance") - parser.add_argument("--columns", type=int, required=True, help="Number of columns to tag per table (1, 2, 4)") - parser.add_argument("--tags", type=int, required=True, help="Number of tags per ALTER command (1, 2, 4)") - parser.add_argument("--threads", type=int, required=True, help="Number of 
concurrent threads (1, 2, 4, 8, 16)") - parser.add_argument("--iterations", type=int, required=True, help="Number of times to repeat the full sweep") + parser.add_argument("--columns", type=int, required=True, help="Number of columns to tag per table") + parser.add_argument("--tags", type=int, required=True, help="Number of tags per ALTER command") + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") + parser.add_argument("--iterations", type=int, required=True, help="Number of iterations") + parser.add_argument("--tables-per-iteration", type=int, default=None, help="Tables to process per iteration (default = --threads, i.e. 1 table per thread)") parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration, print result") parser.add_argument("--skip-setup", action="store_true", help="Skip table creation (tables already exist)") args = parser.parse_args() + if args.tables_per_iteration is None: + args.tables_per_iteration = args.threads + if args.columns > MAX_COLUMNS: print(f"Error: --columns {args.columns} exceeds MAX_COLUMNS={MAX_COLUMNS}") sys.exit(1) + if args.tables_per_iteration > NUM_TABLES: + print(f"Error: --tables-per-iteration {args.tables_per_iteration} exceeds NUM_TABLES={NUM_TABLES}") + sys.exit(1) + if args.validate: args.iterations = 1 print("=== VALIDATION MODE: 1 iteration only ===\n") @@ -609,9 +620,9 @@ def main(): # Logging profile_handler = setup_logging(log_path) - print(f"Profile: columns={args.columns}, tags={args.tags}, threads={args.threads}, iterations={args.iterations}") - print(f"ALTERs per iteration: {NUM_TABLES * args.columns}") - print(f"Total ALTERs: {NUM_TABLES * args.columns * args.iterations}") + print(f"Profile: columns={args.columns}, tags={args.tags}, threads={args.threads}, iterations={args.iterations}, tables_per_iteration={args.tables_per_iteration}") + print(f"ALTERs per iteration: {args.tables_per_iteration * args.columns} 
({args.tables_per_iteration} tables x {args.columns} columns)") + print(f"Total ALTERs: {args.tables_per_iteration * args.columns * args.iterations}") print(f"Output: {report_path}") print() @@ -631,6 +642,7 @@ def main(): num_columns=args.columns, num_tags=args.tags, num_threads=args.threads, + tables_per_iteration=args.tables_per_iteration, ) all_alter_results.extend(alter_results) all_table_results.extend(table_results) diff --git a/examples/profile_read_then_write_table_tags.py b/examples/profile_read_then_write_table_tags.py index 8ecec9183..5aaf729c4 100644 --- a/examples/profile_read_then_write_table_tags.py +++ b/examples/profile_read_then_write_table_tags.py @@ -45,7 +45,7 @@ SCHEMA = _creds["SCHEMA"] # ============================================================ -NUM_TABLES = 64 +NUM_TABLES = 128 RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "read_then_write_table_tags") SELECT_TEMPLATE = """SELECT tag_name, tag_value @@ -215,10 +215,12 @@ def worker( # Run one iteration # --------------------------------------------------------------------------- -def run_iteration(iteration: int, num_threads: int) -> tuple: +def run_iteration(iteration: int, num_threads: int, tables_per_iteration: int) -> tuple: table_queue = Queue() - for t in range(1, NUM_TABLES + 1): - table_queue.put(f"table{t}") + start = ((iteration - 1) * tables_per_iteration) % NUM_TABLES + for i in range(tables_per_iteration): + table_idx = start + i + 1 + table_queue.put(f"table{table_idx}") results: list = [] results_lock = threading.Lock() @@ -484,10 +486,18 @@ def main(): description="Profile read-from-information_schema then write-table-tag pattern" ) parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") - parser.add_argument("--iterations", type=int, required=True, help="Number of times to repeat the full sweep") + parser.add_argument("--iterations", type=int, required=True, help="Number of iterations") + 
parser.add_argument("--tables-per-iteration", type=int, default=None, help="Tables per iteration (default = --threads)") parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration") args = parser.parse_args() + if args.tables_per_iteration is None: + args.tables_per_iteration = args.threads + + if args.tables_per_iteration > NUM_TABLES: + print(f"Error: --tables-per-iteration {args.tables_per_iteration} exceeds NUM_TABLES={NUM_TABLES}") + sys.exit(1) + if args.validate: args.iterations = 1 print("=== VALIDATION MODE: 1 iteration only ===\n") @@ -500,9 +510,9 @@ def main(): profile_handler = setup_logging(log_path) - print(f"Profile (information_schema.table_tags): threads={args.threads}, iterations={args.iterations}") - print(f"SELECTs per iteration: {NUM_TABLES} (1 per table)") - print(f"Total SELECTs: {NUM_TABLES * args.iterations}") + print(f"Profile (information_schema.table_tags): threads={args.threads}, iterations={args.iterations}, tables_per_iteration={args.tables_per_iteration}") + print(f"SELECTs per iteration: {args.tables_per_iteration} (1 per table)") + print(f"Total SELECTs: {args.tables_per_iteration * args.iterations}") print(f"Output: {report_path}") print() @@ -511,7 +521,7 @@ def main(): for i in range(1, args.iterations + 1): print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) - results, duration = run_iteration(iteration=i, num_threads=args.threads) + results, duration = run_iteration(iteration=i, num_threads=args.threads, tables_per_iteration=args.tables_per_iteration) all_results.extend(results) iteration_durations.append(duration) diff --git a/examples/profile_read_then_write_tags.py b/examples/profile_read_then_write_tags.py index 39b584f02..0ac138a18 100644 --- a/examples/profile_read_then_write_tags.py +++ b/examples/profile_read_then_write_tags.py @@ -46,7 +46,7 @@ SCHEMA = _creds["SCHEMA"] # ============================================================ -NUM_TABLES = 64 +NUM_TABLES 
= 128 RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "read_then_write") SELECT_TEMPLATE = """SELECT column_name, tag_name, tag_value @@ -216,10 +216,12 @@ def worker( # Run one iteration # --------------------------------------------------------------------------- -def run_iteration(iteration: int, num_threads: int) -> tuple: +def run_iteration(iteration: int, num_threads: int, tables_per_iteration: int) -> tuple: table_queue = Queue() - for t in range(1, NUM_TABLES + 1): - table_queue.put(f"table{t}") + start = ((iteration - 1) * tables_per_iteration) % NUM_TABLES + for i in range(tables_per_iteration): + table_idx = start + i + 1 + table_queue.put(f"table{table_idx}") results: list = [] results_lock = threading.Lock() @@ -485,10 +487,18 @@ def main(): description="Profile read-from-information_schema then write-column-tag pattern" ) parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") - parser.add_argument("--iterations", type=int, required=True, help="Number of times to repeat the full sweep") + parser.add_argument("--iterations", type=int, required=True, help="Number of iterations") + parser.add_argument("--tables-per-iteration", type=int, default=None, help="Tables per iteration (default = --threads)") parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration") args = parser.parse_args() + if args.tables_per_iteration is None: + args.tables_per_iteration = args.threads + + if args.tables_per_iteration > NUM_TABLES: + print(f"Error: --tables-per-iteration {args.tables_per_iteration} exceeds NUM_TABLES={NUM_TABLES}") + sys.exit(1) + if args.validate: args.iterations = 1 print("=== VALIDATION MODE: 1 iteration only ===\n") @@ -501,9 +511,9 @@ def main(): profile_handler = setup_logging(log_path) - print(f"Profile (information_schema.column_tags): threads={args.threads}, iterations={args.iterations}") - print(f"SELECTs per iteration: 
{NUM_TABLES} (1 per table)") - print(f"Total SELECTs: {NUM_TABLES * args.iterations}") + print(f"Profile (information_schema.column_tags): threads={args.threads}, iterations={args.iterations}, tables_per_iteration={args.tables_per_iteration}") + print(f"SELECTs per iteration: {args.tables_per_iteration} (1 per table)") + print(f"Total SELECTs: {args.tables_per_iteration * args.iterations}") print(f"Output: {report_path}") print() @@ -512,7 +522,7 @@ def main(): for i in range(1, args.iterations + 1): print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) - results, duration = run_iteration(iteration=i, num_threads=args.threads) + results, duration = run_iteration(iteration=i, num_threads=args.threads, tables_per_iteration=args.tables_per_iteration) all_results.extend(results) iteration_durations.append(duration) diff --git a/examples/profile_table_tags.py b/examples/profile_table_tags.py index 6113dda54..f73d6074e 100644 --- a/examples/profile_table_tags.py +++ b/examples/profile_table_tags.py @@ -55,7 +55,7 @@ SCHEMA = _creds["SCHEMA"] # ============================================================ -NUM_TABLES = 64 +NUM_TABLES = 128 RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "table_tags") @@ -230,11 +230,14 @@ def run_iteration( iteration: int, num_tags: int, num_threads: int, + tables_per_iteration: int, ) -> tuple: - """Run a single iteration: distribute 64 tables across threads.""" + """Run a single iteration: tables_per_iteration tables distributed across threads.""" table_queue = Queue() - for t in range(1, NUM_TABLES + 1): - table_queue.put(f"table{t}") + start = ((iteration - 1) * tables_per_iteration) % NUM_TABLES + for i in range(tables_per_iteration): + table_idx = start + i + 1 + table_queue.put(f"table{table_idx}") alter_results: list = [] results_lock = threading.Lock() @@ -500,12 +503,20 @@ def w(text=""): def main(): parser = argparse.ArgumentParser(description="Profile SET TABLE TAGS performance") - 
parser.add_argument("--tags", type=int, required=True, help="Number of tags per ALTER command (1, 2, 4)") - parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads (1, 2, 4, 8, 16)") - parser.add_argument("--iterations", type=int, required=True, help="Number of times to repeat the full sweep") + parser.add_argument("--tags", type=int, required=True, help="Number of tags per ALTER command") + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") + parser.add_argument("--iterations", type=int, required=True, help="Number of iterations") + parser.add_argument("--tables-per-iteration", type=int, default=None, help="Tables per iteration (default = --threads)") parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration, print result") args = parser.parse_args() + if args.tables_per_iteration is None: + args.tables_per_iteration = args.threads + + if args.tables_per_iteration > NUM_TABLES: + print(f"Error: --tables-per-iteration {args.tables_per_iteration} exceeds NUM_TABLES={NUM_TABLES}") + sys.exit(1) + if args.validate: args.iterations = 1 print("=== VALIDATION MODE: 1 iteration only ===\n") @@ -520,9 +531,9 @@ def main(): # Logging profile_handler = setup_logging(log_path) - print(f"Profile (TABLE TAGS): tags={args.tags}, threads={args.threads}, iterations={args.iterations}") - print(f"ALTERs per iteration: {NUM_TABLES} (one per table)") - print(f"Total ALTERs: {NUM_TABLES * args.iterations}") + print(f"Profile (TABLE TAGS): tags={args.tags}, threads={args.threads}, iterations={args.iterations}, tables_per_iteration={args.tables_per_iteration}") + print(f"ALTERs per iteration: {args.tables_per_iteration} (one per table)") + print(f"Total ALTERs: {args.tables_per_iteration * args.iterations}") print(f"Output: {report_path}") print() @@ -536,6 +547,7 @@ def main(): iteration=i, num_tags=args.tags, num_threads=args.threads, + 
tables_per_iteration=args.tables_per_iteration, ) all_results.extend(results) iteration_durations.append(duration)