forked from lab/TPM
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
260 lines
8.1 KiB
260 lines
8.1 KiB
|
2 months ago
|
"""
|
||
|
|
市場基準資料模組
|
||
|
|
|
||
|
|
從資料庫取得實際的市場基準資料(台股加權指數、S&P 500)
|
||
|
|
用於 Context Engineering 的市場環境背景
|
||
|
|
"""
|
||
|
|
|
||
|
|
import logging
import os
from datetime import datetime, timedelta
from typing import Dict, Any, Optional

import numpy as np
import pandas as pd
import psycopg2
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)


# Database settings come from the project's config module when it is importable.
try:
    from config import SQL_CONFIG
except ImportError:
    # Fallback settings. Every field can be overridden through an environment
    # variable; the literal defaults keep the previous behaviour unchanged.
    # NOTE(review): a plaintext password literal is committed as the default
    # below. It should be rotated and removed from source control, leaving
    # POSTGRES_PASSWORD as the only source.
    SQL_CONFIG = {
        "database": os.environ.get("POSTGRES_DB", "portfolio_platform"),
        "user": os.environ.get("POSTGRES_USER", "postgres"),
        "host": os.environ.get("POSTGRES_HOST", "db"),
        "port": os.environ.get("POSTGRES_PORT", "5432"),
        "password": os.environ.get("POSTGRES_PASSWORD", "thiispassword1qaz!QAZ"),
    }
|
||
|
|
|
||
|
|
|
||
|
|
class MarketBenchmark:
    """Market benchmark data (Taiwan 0050.TW / US SPY) computed from the database.

    Computed contexts are cached per market for ``cache_timeout`` seconds.
    When loading or computing fails, a static fallback context is returned
    instead; the fallback is not cached, so the next call retries the database.
    """

    # Approximate number of trading days per year, used for annualisation.
    TRADING_DAYS_PER_YEAR = 252
    # Rows fetched per query: roughly five years of trading days.
    HISTORY_LIMIT = 1260
    # Roughly three months of trading days; look-back for the sentiment label.
    SENTIMENT_WINDOW = 63
    # Cumulative return over SENTIMENT_WINDOW above which the market is labelled
    # "bull"; below its negative, "bear".
    SENTIMENT_THRESHOLD = 0.05
    # Minimum number of valid price rows for the metrics (volatility needs at
    # least two daily returns) to be defined.
    MIN_ROWS = 3

    def __init__(self, cache_timeout: float = 3600):
        """Create an empty cache.

        Args:
            cache_timeout: Seconds a computed context stays valid (default: 1 hour).
        """
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.cache_timeout = cache_timeout
        # cache key -> POSIX timestamp at which the entry was stored
        self.cache_time: Dict[str, float] = {}

    def _is_cache_valid(self, key: str) -> bool:
        """Return True if ``key`` was cached less than ``cache_timeout`` seconds ago."""
        if key not in self.cache_time:
            return False
        return (datetime.now().timestamp() - self.cache_time[key]) < self.cache_timeout

    def get_market_context(self, tw: bool = True, force_refresh: bool = False) -> Dict[str, Any]:
        """Get the market background context, computed from actual database data.

        Args:
            tw: True for the Taiwan market, False for the US market.
            force_refresh: Recompute even if a valid cached entry exists.

        Returns:
            A market context dict. On any failure this is the static fallback
            context (marked with ``"is_fallback": True``).
        """
        market = 'TW' if tw else 'US'
        cache_key = f"market_{market.lower()}"

        if not force_refresh and self._is_cache_valid(cache_key):
            logger.info("Using cached market context for %s", market)
            # Hand out a copy so callers cannot mutate the cached entry.
            return dict(self.cache[cache_key])

        try:
            context = self._get_tw_market_context() if tw else self._get_us_market_context()
        except Exception as e:
            # Deliberate best-effort: any DB or computation failure degrades
            # to static data rather than propagating to the caller.
            logger.exception("Error getting market context: %s", e)
            return self._get_fallback_context(tw)

        self.cache[cache_key] = context
        self.cache_time[cache_key] = datetime.now().timestamp()
        logger.info("Calculated market context for %s: %s", market, context)
        return dict(context)

    def _get_tw_market_context(self) -> Dict[str, Any]:
        """Compute the Taiwan benchmark context from 0050.TW prices."""
        df = self._load_prices("stock_price_tw", "0050.TW")
        return self._compute_context(df, "台灣加權指數(0050.TW)")

    def _get_us_market_context(self) -> Dict[str, Any]:
        """Compute the US benchmark context from SPY prices."""
        df = self._load_prices("stock_price", "SPY")
        return self._compute_context(df, "S&P 500(SPY)")

    def _load_prices(self, table: str, ticker: str) -> pd.DataFrame:
        """Fetch the newest ``HISTORY_LIMIT`` (date, price) rows for ``ticker``.

        ``table`` is interpolated into the SQL text, so it must be a trusted
        constant (only literal table names are passed within this class). The
        ticker and the row limit are bound as query parameters.
        """
        query = (
            f"SELECT date, price FROM {table} "
            "WHERE ticker = %s ORDER BY date DESC LIMIT %s"
        )
        conn = psycopg2.connect(**SQL_CONFIG)
        try:
            return pd.read_sql(query, conn, params=(ticker, self.HISTORY_LIMIT))
        finally:
            conn.close()

    @classmethod
    def _compute_context(cls, df: pd.DataFrame, market_name: str) -> Dict[str, Any]:
        """Compute benchmark metrics from a (date, price) history.

        Rows may arrive in any order. Dates are normalised with
        ``pd.to_datetime`` and prices with ``pd.to_numeric`` (so NUMERIC/Decimal
        values become floats). Rows whose date or price is missing or
        unparsable are dropped.

        Args:
            df: DataFrame with ``date`` and ``price`` columns.
            market_name: Value stored under ``market_name`` in the result.

        Returns:
            Dict with ``market_name``, ``ytd_return``, ``avg_5y_return``
            (annualised return over the whole history), ``current_price``,
            ``volatility`` (annualised std of daily simple returns),
            ``sentiment`` ("bull"/"bear"/"neutral"), ``last_update``
            (YYYY-MM-DD) and ``data_points``.

        Raises:
            ValueError: If fewer than ``MIN_ROWS`` valid rows remain.
        """
        data = pd.DataFrame({
            'date': pd.to_datetime(df['date'], errors='coerce'),
            'price': pd.to_numeric(df['price'], errors='coerce'),
        }).dropna().sort_values('date').reset_index(drop=True)

        if len(data) < cls.MIN_ROWS:
            raise ValueError(
                f"Not enough price history for {market_name}: {len(data)} valid rows"
            )

        prices = data['price']
        dates = data['date']
        latest_price = prices.iloc[-1]
        latest_date = dates.iloc[-1]

        # Year-to-date: measured from the last close of the previous calendar
        # year; if the history does not reach back that far, from the first
        # close of the current year.
        before_this_year = prices[dates.dt.year < latest_date.year]
        if before_this_year.empty:
            ytd_base = prices[dates.dt.year == latest_date.year].iloc[0]
        else:
            ytd_base = before_this_year.iloc[-1]
        ytd_return = latest_price / ytd_base - 1

        # Annualised (CAGR) return over the whole fetched history; n rows span
        # n - 1 daily periods.
        years = (len(prices) - 1) / cls.TRADING_DAYS_PER_YEAR
        avg_5y_return = (latest_price / prices.iloc[0]) ** (1 / years) - 1

        volatility = prices.pct_change().std() * np.sqrt(cls.TRADING_DAYS_PER_YEAR)

        # Sentiment: compounded return over the last SENTIMENT_WINDOW periods
        # (price ratio, rather than a sum of simple returns).
        window_start = max(0, len(prices) - 1 - cls.SENTIMENT_WINDOW)
        recent_return = latest_price / prices.iloc[window_start] - 1
        if recent_return > cls.SENTIMENT_THRESHOLD:
            sentiment = "bull"
        elif recent_return < -cls.SENTIMENT_THRESHOLD:
            sentiment = "bear"
        else:
            sentiment = "neutral"

        return {
            "market_name": market_name,
            "ytd_return": float(ytd_return),
            "avg_5y_return": float(avg_5y_return),
            "current_price": float(latest_price),
            "volatility": float(volatility),
            "sentiment": sentiment,
            "last_update": latest_date.strftime("%Y-%m-%d"),
            "data_points": len(data),
        }

    def _get_fallback_context(self, tw: bool) -> Dict[str, Any]:
        """Static fallback data, used when the database query or computation fails."""
        if tw:
            return {
                "market_name": "台灣加權指數",
                "ytd_return": 0.18,
                "avg_5y_return": 0.09,
                "volatility": 0.15,
                "sentiment": "neutral",
                "last_update": "static",
                "is_fallback": True
            }
        else:
            return {
                "market_name": "S&P 500",
                "ytd_return": 0.22,
                "avg_5y_return": 0.12,
                "volatility": 0.14,
                "sentiment": "bull",
                "last_update": "static",
                "is_fallback": True
            }
|
||
|
|
|
||
|
|
|
||
|
|
# Singleton: a single MarketBenchmark shared by every caller, created lazily.
_market_benchmark_instance = None


def get_market_benchmark() -> MarketBenchmark:
    """Return the shared MarketBenchmark, constructing it on the first call."""
    global _market_benchmark_instance
    instance = _market_benchmark_instance
    if instance is None:
        instance = MarketBenchmark()
        _market_benchmark_instance = instance
    return instance
|
||
|
|
|
||
|
|
|
||
|
|
# Convenience function (kept for backward compatibility).
def get_market_context(tw: bool = True) -> Dict[str, Any]:
    """Return the market background context for Taiwan or the US.

    Has the same signature as the function in prompts/investment_advice_v2.py,
    so it can be used there as a drop-in replacement. Delegates to the shared
    instance returned by ``get_market_benchmark()``.

    Args:
        tw: True for the Taiwan market, False for the US market.
    """
    return get_market_benchmark().get_market_context(tw=tw)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Manual smoke test: print the Taiwan and US contexts as JSON.
    import json

    logging.basicConfig(level=logging.INFO)

    separator = "=" * 80
    print(separator)
    print("測試市場基準資料模組")
    print(separator)

    # Taiwan first, then the US, matching the original order of calls.
    for heading, is_tw in (("\n台灣市場基準:", True), ("\n美國市場基準:", False)):
        print(heading)
        print(json.dumps(get_market_context(tw=is_tw), indent=2, ensure_ascii=False))

    print("\n" + separator)
    print("測試完成!")
    print(separator)
|