Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions examples/prediction_batch_example.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,61 @@
import pandas as pd
import matplotlib.pyplot as plt
import sys

import matplotlib.pyplot as plt
import pandas as pd

sys.path.append("../")
from model import Kronos, KronosTokenizer, KronosPredictor
from model import Kronos, KronosPredictor, KronosTokenizer


def plot_prediction(kline_df, pred_df):
pred_df.index = kline_df.index[-pred_df.shape[0]:]
sr_close = kline_df['close']
sr_pred_close = pred_df['close']
sr_close.name = 'Ground Truth'
pred_df.index = kline_df.index[-pred_df.shape[0] :]
sr_close = kline_df["close"]
sr_pred_close = pred_df["close"]
sr_close.name = "Ground Truth"
sr_pred_close.name = "Prediction"

sr_volume = kline_df['volume']
sr_pred_volume = pred_df['volume']
sr_volume.name = 'Ground Truth'
sr_volume = kline_df["volume"]
sr_pred_volume = pred_df["volume"]
sr_volume.name = "Ground Truth"
sr_pred_volume.name = "Prediction"

close_df = pd.concat([sr_close, sr_pred_close], axis=1)
volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax1.set_ylabel('Close Price', fontsize=14)
ax1.legend(loc='lower left', fontsize=12)
ax1.plot(
close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5
)
ax1.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
ax1.set_ylabel("Close Price", fontsize=14)
ax1.legend(loc="lower left", fontsize=12)
ax1.grid(True)

ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax2.set_ylabel('Volume', fontsize=14)
ax2.legend(loc='upper left', fontsize=12)
ax2.plot(
volume_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5
)
ax2.plot(volume_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
ax2.set_ylabel("Volume", fontsize=14)
ax2.legend(loc="upper left", fontsize=12)
ax2.grid(True)

plt.tight_layout()
plt.show()


# 1. Load Model and Tokenizer
tokenizer = KronosTokenizer.from_pretrained('/home/csc/huggingface/Kronos-Tokenizer-base/')
tokenizer = KronosTokenizer.from_pretrained(
"/home/csc/huggingface/Kronos-Tokenizer-base/"
)
model = Kronos.from_pretrained("/home/csc/huggingface/Kronos-base/")

# 2. Instantiate Predictor
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)

# 3. Prepare Data
df = pd.read_csv("./data/XSHG_5min_600977.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])
df["timestamps"] = pd.to_datetime(df["timestamps"])

lookback = 400
pred_len = 120
Expand All @@ -56,9 +64,14 @@ def plot_prediction(kline_df, pred_df):
xtsp = []
ytsp = []
for i in range(5):
idf = df.loc[(i*400):(i*400+lookback-1), ['open', 'high', 'low', 'close', 'volume', 'amount']]
i_x_timestamp = df.loc[(i*400):(i*400+lookback-1), 'timestamps']
i_y_timestamp = df.loc[(i*400+lookback):(i*400+lookback+pred_len-1), 'timestamps']
idf = df.loc[
(i * 400) : (i * 400 + lookback - 1),
["open", "high", "low", "close", "volume", "amount"],
]
i_x_timestamp = df.loc[(i * 400) : (i * 400 + lookback - 1), "timestamps"]
i_y_timestamp = df.loc[
(i * 400 + lookback) : (i * 400 + lookback + pred_len - 1), "timestamps"
]

dfs.append(idf)
xtsp.append(i_x_timestamp)
Expand Down
78 changes: 51 additions & 27 deletions examples/prediction_cn_markets_day.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
prediction_cn_markets_day.py

Expand All @@ -21,15 +20,17 @@
python3 prediction_cn_markets_day.py --symbol 002594
"""

import os
import argparse
import os
import sys
import time
import pandas as pd

import akshare as ak
import matplotlib.pyplot as plt
import sys
import pandas as pd

sys.path.append("../")
from model import Kronos, KronosTokenizer, KronosPredictor
from model import Kronos, KronosPredictor, KronosTokenizer

save_dir = "./outputs"
os.makedirs(save_dir, exist_ok=True)
Expand All @@ -45,6 +46,7 @@
TOP_P = 0.9
SAMPLE_COUNT = 1


def load_data(symbol: str) -> pd.DataFrame:
print(f"📥 Fetching {symbol} daily data from akshare ...")

Expand All @@ -63,18 +65,23 @@ def load_data(symbol: str) -> pd.DataFrame:

# If still empty after retries
if df is None or df.empty:
print(f"❌ Failed to fetch data for {symbol} after {max_retries} attempts. Exiting.")
print(
f"❌ Failed to fetch data for {symbol} after {max_retries} attempts. Exiting."
)
sys.exit(1)

df.rename(columns={
"日期": "date",
"开盘": "open",
"收盘": "close",
"最高": "high",
"最低": "low",
"成交量": "volume",
"成交额": "amount"
}, inplace=True)

df.rename(
columns={
"日期": "date",
"开盘": "open",
"收盘": "close",
"最高": "high",
"最低": "low",
"成交量": "volume",
"成交额": "amount",
},
inplace=True,
)

df["date"] = pd.to_datetime(df["date"])
df = df.sort_values("date").reset_index(drop=True)
Expand All @@ -101,7 +108,9 @@ def load_data(symbol: str) -> pd.DataFrame:
if df["amount"].isna().all() or (df["amount"] == 0).all():
df["amount"] = df["close"] * df["volume"]

print(f"✅ Data loaded: {len(df)} rows, range: {df['date'].min()} ~ {df['date'].max()}")
print(
f"✅ Data loaded: {len(df)} rows, range: {df['date'].min()} ~ {df['date'].max()}"
)

print("Data Head:")
print(df.head())
Expand All @@ -110,13 +119,16 @@ def load_data(symbol: str) -> pd.DataFrame:


def prepare_inputs(df):
x_df = df.iloc[-LOOKBACK:][["open","high","low","close","volume","amount"]]
x_df = df.iloc[-LOOKBACK:][["open", "high", "low", "close", "volume", "amount"]]
x_timestamp = df.iloc[-LOOKBACK:]["date"]
y_timestamp = pd.bdate_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=PRED_LEN)
y_timestamp = pd.bdate_range(
start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=PRED_LEN
)
return x_df, pd.Series(x_timestamp), pd.Series(y_timestamp)


def apply_price_limits(pred_df, last_close, limit_rate=0.1):
print(f"🔒 Applying ±{limit_rate*100:.0f}% price limit ...")
print(f"🔒 Applying ±{limit_rate * 100:.0f}% price limit ...")

# Ensure integer index
pred_df = pred_df.reset_index(drop=True)
Expand All @@ -143,7 +155,13 @@ def apply_price_limits(pred_df, last_close, limit_rate=0.1):
def plot_result(df_hist, df_pred, symbol):
plt.figure(figsize=(12, 6))
plt.plot(df_hist["date"], df_hist["close"], label="Historical", color="blue")
plt.plot(df_pred["date"], df_pred["close"], label="Predicted", color="red", linestyle="--")
plt.plot(
df_pred["date"],
df_pred["close"],
label="Predicted",
color="red",
linestyle="--",
)
plt.title(f"Kronos Prediction for {symbol}")
plt.xlabel("Date")
plt.ylabel("Close Price")
Expand All @@ -157,10 +175,14 @@ def plot_result(df_hist, df_pred, symbol):


def predict_future(symbol):
print(f"🚀 Loading Kronos tokenizer:{TOKENIZER_PRETRAINED} model:{MODEL_PRETRAINED} ...")
print(
f"🚀 Loading Kronos tokenizer:{TOKENIZER_PRETRAINED} model:{MODEL_PRETRAINED} ..."
)
tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_PRETRAINED)
model = Kronos.from_pretrained(MODEL_PRETRAINED)
predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=MAX_CONTEXT)
predictor = KronosPredictor(
model, tokenizer, device=DEVICE, max_context=MAX_CONTEXT
)

df = load_data(symbol)
x_df, x_timestamp, y_timestamp = prepare_inputs(df)
Expand All @@ -184,10 +206,12 @@ def predict_future(symbol):
pred_df = apply_price_limits(pred_df, last_close, limit_rate=0.1)

# Merge historical and predicted data
df_out = pd.concat([
df[["date", "open", "high", "low", "close", "volume", "amount"]],
pred_df[["date", "open", "high", "low", "close", "volume", "amount"]]
]).reset_index(drop=True)
df_out = pd.concat(
[
df[["date", "open", "high", "low", "close", "volume", "amount"]],
pred_df[["date", "open", "high", "low", "close", "volume", "amount"]],
]
).reset_index(drop=True)

# Save CSV
out_file = os.path.join(save_dir, f"pred_{symbol.replace('.', '_')}_data.csv")
Expand Down
55 changes: 30 additions & 25 deletions examples/prediction_example.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,43 @@
import pandas as pd
import matplotlib.pyplot as plt
import sys

import matplotlib.pyplot as plt
import pandas as pd

sys.path.append("../")
from model import Kronos, KronosTokenizer, KronosPredictor
from model import Kronos, KronosPredictor, KronosTokenizer


def plot_prediction(kline_df, pred_df):
pred_df.index = kline_df.index[-pred_df.shape[0]:]
sr_close = kline_df['close']
sr_pred_close = pred_df['close']
sr_close.name = 'Ground Truth'
pred_df.index = kline_df.index[-pred_df.shape[0] :]
sr_close = kline_df["close"]
sr_pred_close = pred_df["close"]
sr_close.name = "Ground Truth"
sr_pred_close.name = "Prediction"

sr_volume = kline_df['volume']
sr_pred_volume = pred_df['volume']
sr_volume.name = 'Ground Truth'
sr_volume = kline_df["volume"]
sr_pred_volume = pred_df["volume"]
sr_volume.name = "Ground Truth"
sr_pred_volume.name = "Prediction"

close_df = pd.concat([sr_close, sr_pred_close], axis=1)
volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax1.set_ylabel('Close Price', fontsize=14)
ax1.legend(loc='lower left', fontsize=12)
ax1.plot(
close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5
)
ax1.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
ax1.set_ylabel("Close Price", fontsize=14)
ax1.legend(loc="lower left", fontsize=12)
ax1.grid(True)

ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax2.set_ylabel('Volume', fontsize=14)
ax2.legend(loc='upper left', fontsize=12)
ax2.plot(
volume_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5
)
ax2.plot(volume_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
ax2.set_ylabel("Volume", fontsize=14)
ax2.legend(loc="upper left", fontsize=12)
ax2.grid(True)

plt.tight_layout()
Expand All @@ -47,14 +53,14 @@ def plot_prediction(kline_df, pred_df):

# 3. Prepare Data
df = pd.read_csv("./data/XSHG_5min_600977.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])
df["timestamps"] = pd.to_datetime(df["timestamps"])

lookback = 400
pred_len = 120

x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']]
x_timestamp = df.loc[:lookback-1, 'timestamps']
y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']
x_df = df.loc[: lookback - 1, ["open", "high", "low", "close", "volume", "amount"]]
x_timestamp = df.loc[: lookback - 1, "timestamps"]
y_timestamp = df.loc[lookback : lookback + pred_len - 1, "timestamps"]

# 4. Make Prediction
pred_df = predictor.predict(
Expand All @@ -65,16 +71,15 @@ def plot_prediction(kline_df, pred_df):
T=1.0,
top_p=0.9,
sample_count=1,
verbose=True
verbose=True,
)

# 5. Visualize Results
print("Forecasted Data Head:")
print(pred_df.head())

# Combine historical and forecasted data for plotting
kline_df = df.loc[:lookback+pred_len-1]
kline_df = df.loc[: lookback + pred_len - 1]

# visualize
plot_prediction(kline_df, pred_df)

Loading