文章目录
- Toolkit 作用
- Toolkit 逐函数解析
- 1. 获取默认配置
- 2. update_config
- 3. config
- 4. `__init__`
- 5. get_reddit_news
- 6. get_finnhub_news
- 7. get_reddit_stock_info
- 8. get_chinese_social_sentiment
- 9. get_finnhub_company_insider_sentiment
- 10. get_YFin_data
- 11. get_YFin_data_online
- 12. get_stockstats_indicators_report
- 13. get_stockstats_indicators_report_online
- 14. get_simfin_balance_sheet 资产负债表
- 15. get_simfin_income_statement利润表
- 16. get_simfin_cashflow_statement 现金流量表
- 17. get_simfin_ratios 财务指标比率
- 18. get_simfin_company_info 公司基本信息
- 19. get_simfin_shareprices 历史股价数据
- 20. get_fundamentals_openai 财务+估值
- 21. get_china_fundamentals A 股财务数据
- 22. get_stock_fundamentals_unified 整合财务+估值
- create_msg_delete函数(非Toolkit类)
- ChromaDBManager
- 类说明
- 逐函数解析
- 类属性
- `__new__(cls)`
- `__init__(self)`
- get_or_create_collection(self, name: str)
- FinancialSituationMemory类
- 类说明
- 简化后代码
- Example
- 逐函数解析
- `__init__(self, config)`
- `_smart_text_truncation(self, text, max_length)`
- get_embedding(self, text, max_length=1000)
- get_embedding_config_status(self)
- get_last_text_info(self)
- add_situations(self, situations)
- get_cache_info(self)
- get_memories(self, situation, n_results=5)
Toolkit 作用
-
Toolkit(tradingagents/agents/utils/agent_utils.py)
是 TradingAgents 框架里的工具集,封装了各种外部数据源的调用接口(新闻、财务、行情、情绪、指标等),并通过@tool
装饰器暴露给 LLM 使用。 -
TradingAgents Toolkit 功能总览表
模块类别 | 主要函数/工具 | 功能描述 | 数据来源 | 状态 |
---|---|---|---|---|
市场行情 (Market) | get_market_data_unified | 获取股票/指数的市场行情数据(价格、成交量、技术指标) | Yahoo Finance, Tushare, 东方财富等 | |
get_price_history | 获取历史 K 线数据 | 同上 | ||
get_technical_indicators | 生成技术指标(MA, RSI, MACD 等) | 本地计算 | ||
新闻 (News) | get_news_articles | 抓取金融新闻 | Google News, 东方财富新闻等 | |
summarize_news | 对新闻做摘要 | LLM 处理 | ||
社交媒体 (Social) | get_social_sentiment | 获取社交媒体情绪(Twitter, 微博) | API / 爬虫 | 数据源可能受限 |
analyze_sentiment | 使用 LLM 对评论、帖子做情绪分析 | LLM | ||
基本面 (Fundamentals) | get_stock_fundamentals_unified | 统一接口:获取美股/A股/港股的财务数据与估值 | Yahoo Finance, SimFin, Tushare, AKShare, 东方财富 | |
get_fundamentals_openai | 使用 OpenAI Agent 调用财务数据 | OpenAI Agent | 废弃 | |
get_china_fundamentals | 获取中国 A 股财务数据(旧接口) | Tushare, AKShare | 废弃 | |
风险分析 (Risk) | get_risk_metrics | 计算风险指标(波动率、夏普比率、回撤) | 本地计算 | |
stress_test | 压力测试(不同市场情景下的资产表现) | 本地模拟 | ||
portfolio_risk_analysis | 投资组合风险分析 | 本地计算 + 历史行情 |
Toolkit 逐函数解析
- 大部分函数都是通过
tradingagents.dataflows.interface
获取数据。
1. 获取默认配置
from tradingagents.default_config import DEFAULT_CONFIG
_config = DEFAULT_CONFIG.copy()
- 获取默认配置,默认配置如下:
import osDEFAULT_CONFIG = {"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),"data_dir": os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data"),"data_cache_dir": os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),"dataflows/data_cache",),# LLM settings"llm_provider": "openai","deep_think_llm": "o4-mini","quick_think_llm": "gpt-4o-mini","backend_url": "https://api.openai.com/v1",# Debate and discussion settings"max_debate_rounds": 1,"max_risk_discuss_rounds": 1,"max_recur_limit": 100,# Tool settings"online_tools": True,# Note: Database and cache configuration is now managed by .env file and config.database_manager# No database/cache settings in default config to avoid configuration conflicts
}
2. update_config
@classmethod
def update_config(cls, config):"""Update the class-level configuration."""cls._config.update(config)
- 更新
Toolkit
的全局配置(类级别),比如数据源 API key、默认参数等。
3. config
@property
def config(self):"""Access the configuration."""return self._config
- 返回当前配置。
4. __init__
def __init__(self, config=None):if config:self.update_config(config)
- 构造函数,可传入配置并更新默认配置。
5. get_reddit_news
@staticmethod
@tool
def get_reddit_news(curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
) -> str:"""Retrieve global news from Reddit within a specified time frame.Args:curr_date (str): Date you want to get news for in yyyy-mm-dd formatReturns:str: A formatted dataframe containing the latest global news from Reddit in the specified time frame."""global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)return global_news_result
- 从 Reddit 获取某天起过去 7 天内的 全球新闻(最多 5 条),主要用于宏观舆情分析。
6. get_finnhub_news
@staticmethod
@tool
def get_finnhub_news(ticker: Annotated[str, "Search query of a company, e.g. 'AAPL, TSM, etc."],start_date: Annotated[str, "Start date in yyyy-mm-dd format"],end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):"""Retrieve the latest news about a given stock from Finnhub within a date rangeArgs:ticker (str): Ticker of a company. e.g. AAPL, TSMstart_date (str): Start date in yyyy-mm-dd formatend_date (str): End date in yyyy-mm-dd formatReturns:str: A formatted dataframe containing news about the company within the date range from start_date to end_date"""end_date_str = end_dateend_date = datetime.strptime(end_date, "%Y-%m-%d")start_date = datetime.strptime(start_date, "%Y-%m-%d")look_back_days = (end_date - start_date).daysfinnhub_news_result = interface.get_finnhub_news(ticker, end_date_str, look_back_days)return finnhub_news_result
- 调用 Finnhub API 获取指定股票在
start_date ~ end_date
的新闻。
7. get_reddit_stock_info
@staticmethod
@tool
def get_reddit_stock_info(ticker: Annotated[str, "Ticker of a company. e.g. AAPL, TSM"],curr_date: Annotated[str, "Current date you want to get news for"],
) -> str:"""Retrieve the latest news about a given stock from Reddit, given the current date.Args:ticker (str): Ticker of a company. e.g. AAPL, TSMcurr_date (str): current date in yyyy-mm-dd format to get news forReturns:str: A formatted dataframe containing the latest news about the company on the given date"""stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)return stock_news_results
- 从 Reddit 获取某只股票的近期新闻和讨论。
8. get_chinese_social_sentiment
@staticmethod
@tool
def get_chinese_social_sentiment(ticker: Annotated[str, "股票代码,例如 '600519.SS' 或 '000001.SZ'"],curr_date: Annotated[str, "要获取的日期,yyyy-mm-dd 格式"],
) -> str:"""获取某只股票在中国社交媒体上的情绪分析。Args:ticker (str): 股票代码 (A股格式)curr_date (str): 要获取的日期Returns:str: 包含中国社交媒体情绪分析的格式化 dataframe"""try:sentiment_result = interface.get_chinese_sentiment(ticker, curr_date, 7)except Exception as e:logger.warning(f"中国舆情数据获取失败,回退到 Reddit: {e}")sentiment_result = interface.get_reddit_company_news(ticker, curr_date, 7, 5)return sentiment_result
- 获取 A 股股票在 中国本土社交媒体 上的舆情/情绪数据。
- 如果中国数据源失败 → 自动回退到 Reddit。
- 常用于国内公司投资情绪分析。
9. get_finnhub_company_insider_sentiment
@staticmethod
@tool
def get_finnhub_company_insider_sentiment(ticker: Annotated[str, "股票代码,例如 'AAPL', 'TSM'"],
) -> str:"""获取公司内部人买卖行为的情绪数据。Args:ticker (str): 公司代码Returns:str: 格式化的内部人情绪数据"""insider_sentiment = interface.get_finnhub_insider_sentiment(ticker)return insider_sentiment
- 调用 Finnhub API 获取 内部人买卖股票的情绪指标(比如 CEO、CFO 买入/卖出, 属于“聪明钱”指标)。
10. get_YFin_data
@staticmethod
@tool
def get_YFin_data(ticker: Annotated[str, "公司股票代码,例如 'AAPL', 'TSM'"],period: Annotated[str, "数据周期,例如 '1mo', '6mo', '1y'"],
) -> str:"""获取 Yahoo Finance 历史行情数据。Args:ticker (str): 股票代码period (str): 时间范围Returns:str: 格式化的历史行情 dataframe"""df = interface.get_yahoo_finance_history(ticker, period)return df.to_string()
- 从 Yahoo Finance 获取某只股票的历史行情(K 线数据)。
11. get_YFin_data_online
@staticmethod
@tool
def get_YFin_data_online(ticker: Annotated[str, "公司股票代码"],start_date: Annotated[str, "开始日期 yyyy-mm-dd"],end_date: Annotated[str, "结束日期 yyyy-mm-dd"],
) -> str:"""获取 Yahoo Finance 区间行情。"""df = interface.get_yahoo_finance_range(ticker, start_date, end_date)return df.to_string()
- 类似
get_YFin_data
,但支持 指定开始和结束日期。
12. get_stockstats_indicators_report
@staticmethod
@tool
def get_stockstats_indicators_report(ticker: Annotated[str, "公司股票代码"],period: Annotated[str, "分析周期,例如 '6mo'"],
) -> str:"""获取技术指标分析报告(基于 stockstats)。"""df = interface.get_stockstats_indicators(ticker, period)return df.to_string()
- 基于 stockstats 库计算技术指标(均线、RSI、MACD 等)。
13. get_stockstats_indicators_report_online
@staticmethod
@tool
def get_stockstats_indicators_report_online(ticker: Annotated[str, "公司股票代码"],start_date: Annotated[str, "开始日期"],end_date: Annotated[str, "结束日期"],
) -> str:"""获取指定区间的技术指标分析。"""df = interface.get_stockstats_indicators_range(ticker, start_date, end_date)return df.to_string()
get_stockstats_indicators_report
类似,但支持 自定义区间。
14. get_simfin_balance_sheet 资产负债表
@staticmethod
@tool
def get_simfin_balance_sheet(ticker: Annotated[str, "公司股票代码,例如 'AAPL', 'TSM'"],report_type: Annotated[str, "报告类型,例如 'annual', 'quarterly'"],
) -> str:"""获取公司资产负债表 (Balance Sheet)。"""df = interface.get_simfin_balance_sheet(ticker, report_type)return df.to_string()
- 从 SimFin API 获取 资产负债表。
- 可选年度 (annual) 或季度 (quarterly)。
- 返回格式化的 dataframe。
15. get_simfin_income_statement利润表
@staticmethod
@tool
def get_simfin_income_statement(ticker: Annotated[str, "公司股票代码"],report_type: Annotated[str, "annual 或 quarterly"],
) -> str:"""获取公司利润表 (Income Statement)。"""df = interface.get_simfin_income_statement(ticker, report_type)return df.to_string()
- 获取 利润表(营业收入、净利润、毛利率等)。
- 主要用于盈利能力分析。
16. get_simfin_cashflow_statement 现金流量表
@staticmethod
@tool
def get_simfin_cashflow_statement(ticker: Annotated[str, "公司股票代码"],report_type: Annotated[str, "annual 或 quarterly"],
) -> str:"""获取公司现金流量表 (Cashflow Statement)。"""df = interface.get_simfin_cashflow_statement(ticker, report_type)return df.to_string()
- 获取 现金流量表(经营活动、投资活动、融资活动现金流)。
- 用于衡量公司“造血能力”和资金链稳定性。
17. get_simfin_ratios 财务指标比率
@staticmethod
@tool
def get_simfin_ratios(ticker: Annotated[str, "公司股票代码"],report_type: Annotated[str, "annual 或 quarterly"],
) -> str:"""获取公司财务比率 (Ratios),例如 PE、ROE、负债率。"""df = interface.get_simfin_ratios(ticker, report_type)return df.to_string()
- 获取 财务指标比率(PE, PB, ROE, 负债率, 流动比率)。
- 适合做跨公司对比。
18. get_simfin_company_info 公司基本信息
@staticmethod
@tool
def get_simfin_company_info(ticker: Annotated[str, "公司股票代码"]
) -> str:"""获取公司基本信息(行业、地区、规模等)。"""df = interface.get_simfin_company_info(ticker)return df.to_string()
- 获取公司的 基本信息(行业分类、上市地、公司规模)。
- 在做行业对比或聚类分析时很有用。
19. get_simfin_shareprices 历史股价数据
@staticmethod
@tool
def get_simfin_shareprices(ticker: Annotated[str, "公司股票代码"],start_date: Annotated[str, "开始日期"],end_date: Annotated[str, "结束日期"],
) -> str:"""获取公司历史股价 (Share Prices)。"""df = interface.get_simfin_shareprices(ticker, start_date, end_date)return df.to_string()
- 获取 SimFin 的历史股价数据。
- 类似 Yahoo Finance,但数据源不同。
20. get_fundamentals_openai 财务+估值
@staticmethod
@tool
def get_fundamentals_openai(ticker: Annotated[str, "公司代码,例如 AAPL, TSLA"],report_type: Annotated[str, "annual 或 quarterly"] = "annual",
) -> str:"""使用 OpenAI Agent 获取公司财务和估值信息。(已废弃,推荐使用 get_stock_fundamentals_unified)"""logger.warning("⚠️ [DEPRECATED] 推荐使用 get_stock_fundamentals_unified() 代替")return interface.get_fundamentals_openai(ticker, report_type)
- 原始版本,用 OpenAI Agent 来获取财务+估值。
- 已经 废弃,现在统一整合进
get_stock_fundamentals_unified
。
- 已经 废弃,现在统一整合进
21. get_china_fundamentals A 股财务数据
@staticmethod
@tool
def get_china_fundamentals(ticker: Annotated[str, "中国股票代码,例如 600519"],report_type: Annotated[str, "年度/季度"] = "annual",
) -> str:"""获取中国 A 股财务数据(通过 Tushare 或 AKShare)。(已废弃,推荐使用 get_stock_fundamentals_unified)"""logger.warning("⚠️ [DEPRECATED] 推荐使用 get_stock_fundamentals_unified() 代替")return interface.get_china_fundamentals(ticker, report_type)
- 早期用于获取 A 股财务数据。
- 数据源:Tushare / AKShare。
- 已 废弃,功能已被统一接口替代。
22. get_stock_fundamentals_unified 整合财务+估值
@staticmethod
@tool
def get_stock_fundamentals_unified(ticker: Annotated[str, "股票代码,例如 AAPL, 600519, 00700.HK"],market: Annotated[str, "市场类型:us / cn / hk"] = "us",report_type: Annotated[str, "annual 或 quarterly"] = "annual",
) -> str:"""统一接口:自动识别市场并获取财务数据与估值。"""return interface.get_stock_fundamentals_unified(ticker, market, report_type)
- 核心统一入口,整合了所有财务+估值工具:
- 美股 → Yahoo Finance + SimFin
- A 股 → Tushare + AKShare
- 港股 → Yahoo Finance + 东方财富
- 自动根据
market
参数(us/cn/hk)选择合适数据源。 - 返回内容包括:
- 财报(资产负债表、利润表、现金流)
- 估值指标(PE, PB, PEG, ROE 等)
- 行业对比
create_msg_delete函数(非Toolkit类)
def create_msg_delete():def delete_messages(state):"""Clear messages and add placeholder for Anthropic compatibility"""messages = state["messages"]# Remove all messagesremoval_operations = [RemoveMessage(id=m.id) for m in messages]# Add a minimal placeholder messageplaceholder = HumanMessage(content="Continue")return {"messages": removal_operations + [placeholder]}return delete_messages
- 清空消息历史,并插入一个占位的
HumanMessage("Continue")
,保证在像 Anthropic 这类模型里保持对话兼容性。
ChromaDBManager
类说明
- 目标:保证整个项目里只存在一个 ChromaDB 客户端,避免多线程或多进程同时初始化带来的冲突。
- 关键点:用了 单例模式 + 线程锁 来保证全局唯一性。
class ChromaDBManager:"""单例ChromaDB管理器,避免并发创建集合的冲突"""_instance = None_lock = threading.Lock()_collections: Dict[str, any] = {}_client = Nonedef __new__(cls):if cls._instance is None:...cls._instance = super(ChromaDBManager, cls).__new__(cls)cls._instance._initialized = Falsereturn cls._instancedef __init__(self):if not self._initialized:try:...self._initialized = Truedef get_or_create_collection(self, name: str):"""线程安全地获取或创建集合"""with self._lock:if name in self._collections:logger.info(f"📚 [ChromaDB] 使用缓存集合: {name}")return self._collections[name]try:# 尝试获取现有集合collection = self._client.get_collection(name=name)logger.info(f"📚 [ChromaDB] 获取现有集合: {name}")except Exception:try:# 创建新集合...# 缓存集合self._collections[name] = collectionreturn collection
逐函数解析
类属性
_instance = None # 存储类的唯一实例
_lock = threading.Lock() # 线程锁,保证并发安全
_collections: Dict[str, any] = {} # 缓存已经创建/获取的集合,避免重复创建
_client = None # 底层 ChromaDB 客户端实例
__new__(cls)
def __new__(cls):if cls._instance is None:with cls._lock:if cls._instance is None:cls._instance = super(ChromaDBManager, cls).__new__(cls)cls._instance._initialized = Falsereturn cls._instance
-
作用:实现单例模式,确保只创建一个实例。
-
逻辑:
- 如果
_instance
还没创建 → 加锁。 - 再次检查
_instance
(双重检查锁 DCL,避免竞态)。 - 创建实例,并标记
_initialized=False
,表示还没初始化。
- 如果
-
返回值:类的唯一实例。
__init__(self)
def __init__(self):if not self._initialized:try:# 自动检测操作系统版本并使用最优配置import platformsystem = platform.system()if system == "Windows":# 使用改进的Windows 11检测from .chromadb_win11_config import is_windows_11if is_windows_11():# Windows 11 或更新版本,使用优化配置from .chromadb_win11_config import get_win11_chromadb_clientself._client = get_win11_chromadb_client()logger.info(f"📚 [ChromaDB] Windows 11优化配置初始化完成 (构建号: {platform.version()})")else:# Windows 10 或更老版本,使用兼容配置from .chromadb_win10_config import get_win10_chromadb_clientself._client = get_win10_chromadb_client()logger.info(f"📚 [ChromaDB] Windows 10兼容配置初始化完成")else:# 非Windows系统,使用标准配置settings = Settings(allow_reset=True,anonymized_telemetry=False,is_persistent=False)self._client = chromadb.Client(settings)logger.info(f"📚 [ChromaDB] {system}标准配置初始化完成")self._initialized = Trueexcept Exception as e:logger.error(f"❌ [ChromaDB] 初始化失败: {e}")# 使用最简单的配置作为备用try:settings = Settings(allow_reset=True,anonymized_telemetry=False, # 关键:禁用遥测is_persistent=False)self._client = chromadb.Client(settings)logger.info(f"📚 [ChromaDB] 使用备用配置初始化完成")except Exception as backup_error:# 最后的备用方案self._client = chromadb.Client()logger.warning(f"⚠️ [ChromaDB] 使用最简配置初始化: {backup_error}")self._initialized = True
-
作用:在第一次创建时初始化
ChromaDB
客户端。 -
逻辑:
- 检测操作系统(Windows / Linux / Mac)。
- Windows:进一步区分 Windows 11 优化配置 和 Windows 10 兼容配置。
- 其它系统:用标准配置
Settings(...)
。 - 初始化失败 → 尝试 备用配置(禁用遥测、非持久化)。
- 如果还失败 → 用 最简配置
chromadb.Client()
。
-
输出:无(初始化
_client
)。
get_or_create_collection(self, name: str)
def get_or_create_collection(self, name: str):"""线程安全地获取或创建集合(输出ChromaDB 集合对象)"""with self._lock:if name in self._collections:logger.info(f"📚 [ChromaDB] 使用缓存集合: {name}")return self._collections[name]try:# 尝试获取现有集合collection = self._client.get_collection(name=name)logger.info(f"📚 [ChromaDB] 获取现有集合: {name}")except Exception:try:# 创建新集合collection = self._client.create_collection(name=name)logger.info(f"📚 [ChromaDB] 创建新集合: {name}")except Exception as e:# 可能是并发创建,再次尝试获取try:collection = self._client.get_collection(name=name)logger.info(f"📚 [ChromaDB] 并发创建后获取集合: {name}")except Exception as final_error:logger.error(f"❌ [ChromaDB] 集合操作失败: {name}, 错误: {final_error}")raise final_error# 缓存集合self._collections[name] = collectionreturn collection
-
作用:获取或创建一个集合(类似数据库里的表)。
-
线程安全:加锁保证多个线程不会同时创建同一个集合。
-
逻辑:
- 如果缓存
_collections
里已有 → 直接返回。 - 否则尝试
get_collection(name)
。 - 如果失败(说明不存在) →
create_collection(name)
。 - 如果创建也失败(可能是并发竞争) → 再尝试
get_collection(name)
。 - 如果还是失败 → 抛出异常。
- 成功则缓存集合,并返回。
- 如果缓存
FinancialSituationMemory类
类说明
- 类可以视为一个 “财务情况记忆库”:
- 写入:新情况 + 建议 → 生成 embedding → 存入 ChromaDB
- 读取:新情况 → 生成 embedding → 相似度搜索 → 找到最相关的历史建议
简化后代码
class FinancialSituationMemory:def __init__(self, name, config):self.config = configself.llm_provider = config.get("llm_provider", "openai").lower()# 配置向量缓存的长度限制(向量缓存默认启用长度检查)self.max_embedding_length = int(os.getenv('MAX_EMBEDDING_CONTENT_LENGTH', '50000')) # 默认50K字符self.enable_embedding_length_check = os.getenv('ENABLE_EMBEDDING_LENGTH_CHECK', 'true').lower() == 'true' # 向量缓存默认启用# 根据LLM提供商选择嵌入模型和客户端self.fallback_available = False # 初始化降级选项标志if self.llm_provider == "dashscope" or self.llm_provider == "alibaba":self.embedding = "text-embedding-v3"self.client = None # DashScope不需要OpenAI客户端# 设置DashScope API密钥dashscope_key = os.getenv('DASHSCOPE_API_KEY')if dashscope_key:try:# 尝试导入和初始化DashScopeimport dashscopefrom dashscope import TextEmbeddingdashscope.api_key = dashscope_keylogger.info(f"✅ DashScope API密钥已配置,启用记忆功能")except ImportError as e:# DashScope包未安装 ...except Exception as e:# 其他初始化错误 ...else:# 没有DashScope密钥,禁用记忆功能 ...elif self.llm_provider == "deepseek":...# 使用单例ChromaDB管理器self.chroma_manager = ChromaDBManager()self.situation_collection = self.chroma_manager.get_or_create_collection(name)def _smart_text_truncation(self, text, max_length=8192):"""智能文本截断,保持语义完整性和缓存兼容性"""if len(text) <= max_length:return text, False # 返回原文本和是否截断的标志# 尝试在句子边界截断sentences = text.split('。')...# 尝试在段落边界截断paragraphs = text.split('\n')...return truncated, Truedef get_embedding(self, text):"""Get embedding for a text using the configured provider"""# 检查记忆功能是否被禁用if self.client == "DISABLED":# 内存功能已禁用,返回空向量 ...if len(text) == 0: ... # 输入文本长度为0,返回空向量# 检查是否启用长度限制if self.enable_embedding_length_check and text_length > self.max_embedding_length: # 文本过长跳过向量化并存储跳过信息 ...return [0.0] * 1024# 存储文本处理信息self._last_text_info = { ... }if (self.llm_provider == "dashscope" orself.llm_provider == "alibaba" or(self.llm_provider == "google" and self.client is None) or(self.llm_provider == "deepseek" and self.client is None) or(self.llm_provider == "openrouter" and self.client is None)):# 使用阿里百炼的嵌入模型try: ...return embeddingelse:...return [0.0] * 1024def get_embedding_config_status(self):"""获取向量缓存配置状态"""return {'enabled': self.enable_embedding_length_check,'max_embedding_length': self.max_embedding_length,'max_embedding_length_formatted': f"{self.max_embedding_length:,}字符",'provider': self.llm_provider,'client_status': 'DISABLED' if self.client == "DISABLED" else 'ENABLED'}def add_situations(self, situations_and_advice):"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""...self.situation_collection.add(documents=situations,metadatas=[{"recommendation": rec} for rec in advice],embeddings=embeddings,ids=ids,)
Example
# Example usagematcher = FinancialSituationMemory()# Example dataexample_data = [("High inflation rate with rising interest rates and declining consumer spending","Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",),("Tech sector showing high volatility with increasing institutional selling pressure","Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",),("Strong dollar affecting emerging markets with increasing forex volatility","Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",),("Market showing signs of sector rotation with rising yields","Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",),]# Add the example situations and recommendationsmatcher.add_situations(example_data)# Example querycurrent_situation = """Market showing increased volatility in tech sector, with institutional investors reducing positions and rising interest rates affecting growth stock valuations"""try:recommendations = matcher.get_memories(current_situation, n_matches=2)for i, rec in enumerate(recommendations, 1):logger.info(f"\nMatch {i}:")logger.info(f"Similarity Score: {rec.get('similarity', 0):.2f}")logger.info(f"Matched Situation: {rec.get('situation', '')}")logger.info(f"Recommendation: {rec.get('recommendation', '')}")except Exception as e:logger.error(f"Error during recommendation: {str(e)}")
逐函数解析
__init__(self, config)
-
作用
初始化类,设置向量数据库(ChromaDB),并根据config
和环境变量自动选择合适的 Embedding 服务。 -
主要逻辑
- 保存
config
。 - 初始化一些内部状态(provider、model、status、last_text_info 等)。
- 根据
config["provider"]
来选择 embedding 服务:- DashScope / Alibaba → 用阿里云的 embedding API
- DeepSeek → 优先用 DashScope,其次 OpenAI,最后 DeepSeek 自己的 embedding
- Google → 优先 DashScope,如果配置里有
openai_api_key
就启用 fallback - OpenRouter → DashScope embedding
- 本地 Ollama (localhost:11434) → 使用 nomic-embed-text
- 默认 → 尝试 OpenAI embedding,失败则禁用
- 保存
- 创建一个 ChromaDB 客户端,并建立/获取一个集合
financial_situations
。
- 返回值 : 无(构造函数)。
def __init__(self, name, config):self.config = configself.llm_provider = config.get("llm_provider", "openai").lower()# 配置向量缓存的长度限制(向量缓存默认启用长度检查)self.max_embedding_length = int(os.getenv('MAX_EMBEDDING_CONTENT_LENGTH', '50000')) # 默认50K字符self.enable_embedding_length_check = os.getenv('ENABLE_EMBEDDING_LENGTH_CHECK', 'true').lower() == 'true' # 向量缓存默认启用# 根据LLM提供商选择嵌入模型和客户端# 初始化降级选项标志self.fallback_available = Falseif self.llm_provider == "dashscope" or self.llm_provider == "alibaba":self.embedding = "text-embedding-v3"self.client = None # DashScope不需要OpenAI客户端# 设置DashScope API密钥dashscope_key = os.getenv('DASHSCOPE_API_KEY')if dashscope_key:try:# 尝试导入和初始化DashScopeimport dashscopefrom dashscope import TextEmbeddingdashscope.api_key = dashscope_keylogger.info(f"✅ DashScope API密钥已配置,启用记忆功能")# 可选:测试API连接(简单验证)# 这里不做实际调用,只验证导入和密钥设置except ImportError as e:# DashScope包未安装logger.error(f"❌ DashScope包未安装: {e}")self.client = "DISABLED"logger.warning(f"⚠️ 记忆功能已禁用")except Exception as e:# 其他初始化错误logger.error(f"❌ DashScope初始化失败: {e}")self.client = "DISABLED"logger.warning(f"⚠️ 记忆功能已禁用")else:# 没有DashScope密钥,禁用记忆功能self.client = "DISABLED"logger.warning(f"⚠️ 未找到DASHSCOPE_API_KEY,记忆功能已禁用")logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")elif self.llm_provider == "deepseek":# 检查是否强制使用OpenAI嵌入force_openai = os.getenv('FORCE_OPENAI_EMBEDDING', 'false').lower() == 'true'if not force_openai:# 尝试使用阿里百炼嵌入dashscope_key = os.getenv('DASHSCOPE_API_KEY')if dashscope_key:try:# 测试阿里百炼是否可用import dashscopefrom dashscope import TextEmbeddingdashscope.api_key = dashscope_key# 验证TextEmbedding可用性(不需要实际调用)self.embedding = "text-embedding-v3"self.client = Nonelogger.info(f"💡 DeepSeek使用阿里百炼嵌入服务")except ImportError as e:logger.error(f"⚠️ DashScope包未安装: {e}")dashscope_key = None # 强制降级except Exception as e:logger.error(f"⚠️ 阿里百炼嵌入初始化失败: {e}")dashscope_key = None # 强制降级else:dashscope_key = None # 跳过阿里百炼if not dashscope_key or force_openai:# 降级到OpenAI嵌入self.embedding = "text-embedding-3-small"openai_key = os.getenv('OPENAI_API_KEY')if openai_key:self.client = OpenAI(api_key=openai_key,base_url=config.get("backend_url", "https://api.openai.com/v1"))logger.warning(f"⚠️ DeepSeek回退到OpenAI嵌入服务")else:# 最后尝试DeepSeek自己的嵌入deepseek_key = os.getenv('DEEPSEEK_API_KEY')if deepseek_key:try:self.client = OpenAI(api_key=deepseek_key,base_url="https://api.deepseek.com")logger.info(f"💡 DeepSeek使用自己的嵌入服务")except Exception as e:logger.error(f"❌ DeepSeek嵌入服务不可用: {e}")# 禁用内存功能self.client = "DISABLED"logger.info(f"🚨 内存功能已禁用,系统将继续运行但不保存历史记忆")else:# 禁用内存功能而不是抛出异常self.client = "DISABLED"logger.info(f"🚨 未找到可用的嵌入服务,内存功能已禁用")elif self.llm_provider == "google":# Google AI使用阿里百炼嵌入(如果可用),否则禁用记忆功能dashscope_key = os.getenv('DASHSCOPE_API_KEY')openai_key = os.getenv('OPENAI_API_KEY')if dashscope_key:try:# 尝试初始化DashScopeimport dashscopefrom dashscope import TextEmbeddingself.embedding = "text-embedding-v3"self.client = Nonedashscope.api_key = dashscope_key# 检查是否有OpenAI密钥作为降级选项if openai_key:logger.info(f"💡 Google AI使用阿里百炼嵌入服务(OpenAI作为降级选项)")self.fallback_available = Trueself.fallback_client = OpenAI(api_key=openai_key, base_url=config["backend_url"])self.fallback_embedding = "text-embedding-3-small"else:logger.info(f"💡 Google AI使用阿里百炼嵌入服务(无降级选项)")self.fallback_available = Falseexcept ImportError as e:logger.error(f"❌ DashScope包未安装: {e}")self.client = "DISABLED"logger.warning(f"⚠️ Google AI记忆功能已禁用")except Exception as e:logger.error(f"❌ DashScope初始化失败: {e}")self.client = "DISABLED"logger.warning(f"⚠️ Google AI记忆功能已禁用")else:# 没有DashScope密钥,禁用记忆功能self.client = "DISABLED"self.fallback_available = Falselogger.warning(f"⚠️ Google AI未找到DASHSCOPE_API_KEY,记忆功能已禁用")logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")elif self.llm_provider == "openrouter":# OpenRouter支持:优先使用阿里百炼嵌入,否则禁用记忆功能dashscope_key = os.getenv('DASHSCOPE_API_KEY')if dashscope_key:try:# 尝试使用阿里百炼嵌入import dashscopefrom dashscope import TextEmbeddingself.embedding = "text-embedding-v3"self.client = Nonedashscope.api_key = dashscope_keylogger.info(f"💡 OpenRouter使用阿里百炼嵌入服务")except ImportError as e:logger.error(f"❌ DashScope包未安装: {e}")self.client = "DISABLED"logger.warning(f"⚠️ OpenRouter记忆功能已禁用")except Exception as e:logger.error(f"❌ DashScope初始化失败: {e}")self.client = "DISABLED"logger.warning(f"⚠️ OpenRouter记忆功能已禁用")else:# 没有DashScope密钥,禁用记忆功能self.client = "DISABLED"logger.warning(f"⚠️ OpenRouter未找到DASHSCOPE_API_KEY,记忆功能已禁用")logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")elif config["backend_url"] == "http://localhost:11434/v1":self.embedding = "nomic-embed-text"self.client = OpenAI(base_url=config["backend_url"])else:self.embedding = "text-embedding-3-small"openai_key = os.getenv('OPENAI_API_KEY')if openai_key:self.client = OpenAI(api_key=openai_key,base_url=config["backend_url"])else:self.client = "DISABLED"logger.warning(f"⚠️ 未找到OPENAI_API_KEY,记忆功能已禁用")# 使用单例ChromaDB管理器self.chroma_manager = ChromaDBManager()self.situation_collection = self.chroma_manager.get_or_create_collection(name)
_smart_text_truncation(self, text, max_length)
def _smart_text_truncation(self, text, max_length=8192):"""智能文本截断,保持语义完整性和缓存兼容性"""if len(text) <= max_length:return text, False # 返回原文本和是否截断的标志# 尝试在句子边界截断sentences = text.split('。')if len(sentences) > 1:truncated = ""for sentence in sentences:if len(truncated + sentence + '。') <= max_length - 50: # 留50字符余量truncated += sentence + '。'else:breakif len(truncated) > max_length // 2: # 至少保留一半内容logger.info(f"📝 智能截断:在句子边界截断,保留{len(truncated)}/{len(text)}字符")return truncated, True# 尝试在段落边界截断paragraphs = text.split('\n')if len(paragraphs) > 1:truncated = ""for paragraph in paragraphs:if len(truncated + paragraph + '\n') <= max_length - 50:truncated += paragraph + '\n'else:breakif len(truncated) > max_length // 2:logger.info(f"📝 智能截断:在段落边界截断,保留{len(truncated)}/{len(text)}字符")return truncated, True# 最后选择:保留前半部分和后半部分的关键信息front_part = text[:max_length//2]back_part = text[-(max_length//2-100):] # 留100字符给连接符truncated = front_part + "\n...[内容截断]...\n" + back_partlogger.warning(f"⚠️ 强制截断:保留首尾关键信息,{len(text)}字符截断为{len(truncated)}字符")return truncated, True
-
作用
保证输入文本不会超过max_length
,但尽量保持语义完整(比如按句子、段落截断)。 -
主要逻辑
- 如果文本长度 ≤
max_length
→ 原样返回。 - 如果太长:
- 尝试在句号(。.!?)之后截断,保留前面一段完整句子。
- 否则尝试按段落
\n
截断。 - 如果都不行,就取前
max_length//2
和后max_length//2
拼接,中间插入...
。
- 把截断后的文本保存到
self.last_text_info
,记录是否截断、采用哪种策略、原始/最终长度等。
- 如果文本长度 ≤
-
输入
text
(str):原始文本max_length
(int):允许的最大长度
-
输出
- 截断后的文本 (str)
get_embedding(self, text, max_length=1000)
def get_embedding(self, text):"""Get embedding for a text using the configured provider"""# 检查记忆功能是否被禁用if self.client == "DISABLED":# 内存功能已禁用,返回空向量logger.debug(f"⚠️ 记忆功能已禁用,返回空向量")return [0.0] * 1024 # 返回1024维的零向量# 验证输入文本if not text or not isinstance(text, str):logger.warning(f"⚠️ 输入文本为空或无效,返回空向量")return [0.0] * 1024text_length = len(text)if text_length == 0:logger.warning(f"⚠️ 输入文本长度为0,返回空向量")return [0.0] * 1024# 检查是否启用长度限制if self.enable_embedding_length_check and text_length > self.max_embedding_length:logger.warning(f"⚠️ 文本过长({text_length:,}字符 > {self.max_embedding_length:,}字符),跳过向量化")# 存储跳过信息self._last_text_info = {'original_length': text_length,'processed_length': 0,'was_truncated': False,'was_skipped': True,'provider': self.llm_provider,'strategy': 'length_limit_skip','max_length': self.max_embedding_length}return [0.0] * 1024# 记录文本信息(不进行任何截断)if text_length > 8192:logger.info(f"📝 处理长文本: {text_length}字符,提供商: {self.llm_provider}")# 存储文本处理信息self._last_text_info = {'original_length': text_length,'processed_length': text_length, # 不截断,保持原长度'was_truncated': False, # 永不截断'was_skipped': False,'provider': self.llm_provider,'strategy': 'no_truncation_with_fallback' # 标记策略}if (self.llm_provider == "dashscope" orself.llm_provider == "alibaba" or(self.llm_provider == "google" and self.client is None) or(self.llm_provider == "deepseek" and self.client is None) or(self.llm_provider == "openrouter" and self.client is None)):# 使用阿里百炼的嵌入模型try:# 导入DashScope模块import dashscopefrom dashscope import TextEmbedding# 检查DashScope API密钥是否可用if not hasattr(dashscope, 'api_key') or not dashscope.api_key:logger.warning(f"⚠️ DashScope API密钥未设置,记忆功能降级")return [0.0] * 1024 # 返回空向量# 尝试调用DashScope APIresponse = TextEmbedding.call(model=self.embedding,input=text)# 检查响应状态if response.status_code == 200:# 成功获取embeddingembedding = response.output['embeddings'][0]['embedding']logger.debug(f"✅ DashScope embedding成功,维度: {len(embedding)}")return embeddingelse:# API返回错误状态码error_msg = f"{response.code} - {response.message}"# 检查是否为长度限制错误if any(keyword in error_msg.lower() for keyword in ['length', 'token', 'limit', 'exceed']):logger.warning(f"⚠️ DashScope长度限制: {error_msg}")# 检查是否有降级选项if hasattr(self, 'fallback_available') and self.fallback_available:logger.info(f"💡 尝试使用OpenAI降级处理长文本")try:response = self.fallback_client.embeddings.create(model=self.fallback_embedding,input=text)embedding = response.data[0].embeddinglogger.info(f"✅ OpenAI降级成功,维度: {len(embedding)}")return embeddingexcept Exception as fallback_error:logger.error(f"❌ OpenAI降级失败: {str(fallback_error)}")logger.info(f"💡 所有降级选项失败,记忆功能降级")return [0.0] * 1024else:logger.info(f"💡 无可用降级选项,记忆功能降级")return [0.0] * 1024else:logger.error(f"❌ DashScope API错误: {error_msg}")return [0.0] * 1024 # 返回空向量而不是抛出异常except Exception as e:error_str = str(e).lower()# 检查是否为长度限制错误if any(keyword in error_str for keyword in ['length', 'token', 'limit', 'exceed', 'too long']):logger.warning(f"⚠️ DashScope长度限制异常: {str(e)}")# 检查是否有降级选项if hasattr(self, 'fallback_available') and self.fallback_available:logger.info(f"💡 尝试使用OpenAI降级处理长文本")try:response = self.fallback_client.embeddings.create(model=self.fallback_embedding,input=text)embedding = response.data[0].embeddinglogger.info(f"✅ OpenAI降级成功,维度: {len(embedding)}")return embeddingexcept Exception as fallback_error:logger.error(f"❌ OpenAI降级失败: {str(fallback_error)}")logger.info(f"💡 所有降级选项失败,记忆功能降级")return [0.0] * 1024else:logger.info(f"💡 无可用降级选项,记忆功能降级")return [0.0] * 1024elif 'import' in error_str:logger.error(f"❌ DashScope包未安装: {str(e)}")elif 'connection' in error_str:logger.error(f"❌ DashScope网络连接错误: {str(e)}")elif 'timeout' in error_str:logger.error(f"❌ DashScope请求超时: {str(e)}")else:logger.error(f"❌ DashScope embedding异常: {str(e)}")logger.warning(f"⚠️ 记忆功能降级,返回空向量")return [0.0] * 1024else:# 使用OpenAI兼容的嵌入模型if self.client is None:logger.warning(f"⚠️ 嵌入客户端未初始化,返回空向量")return [0.0] * 1024 # 返回空向量elif self.client == "DISABLED":# 内存功能已禁用,返回空向量logger.debug(f"⚠️ 内存功能已禁用,返回空向量")return [0.0] * 1024 # 返回1024维的零向量# 尝试调用OpenAI兼容的embedding APItry:response = self.client.embeddings.create(model=self.embedding,input=text)embedding = response.data[0].embeddinglogger.debug(f"✅ {self.llm_provider} embedding成功,维度: {len(embedding)}")return embeddingexcept Exception as e:error_str = str(e).lower()# 检查是否为长度限制错误length_error_keywords = ['token', 'length', 'too long', 'exceed', 'maximum', 'limit','context', 'input too large', 'request too large']is_length_error = any(keyword in error_str for keyword in length_error_keywords)if is_length_error:# 长度限制错误:直接降级,不截断重试logger.warning(f"⚠️ {self.llm_provider}长度限制: {str(e)}")logger.info(f"💡 为保证分析准确性,不截断文本,记忆功能降级")else:# 其他类型的错误if 'attributeerror' in error_str:logger.error(f"❌ {self.llm_provider} API调用错误: {str(e)}")elif 'connectionerror' in error_str or 'connection' in error_str:logger.error(f"❌ {self.llm_provider}网络连接错误: {str(e)}")elif 'timeout' in error_str:logger.error(f"❌ {self.llm_provider}请求超时: {str(e)}")elif 'keyerror' in error_str:logger.error(f"❌ {self.llm_provider}响应格式错误: {str(e)}")else:logger.error(f"❌ {self.llm_provider} embedding异常: {str(e)}")logger.warning(f"⚠️ 记忆功能降级,返回空向量")return [0.0] * 1024
-
作用 : 把文本转成向量 embedding。
-
主要逻辑
- 调用
_smart_text_truncation
保证长度安全。 - 根据 provider 选择对应的 API:
- DashScope → 调用
dashscope.TextEmbedding.call
,模型是text-embedding-v2
- OpenAI → 调用
client.embeddings.create(model="text-embedding-3-small")
- DeepSeek → 可能走 DeepSeek 自己的 embedding API
- Ollama → 调用
client.embeddings.create(model="nomic-embed-text")
- 其它情况 → 返回
[0.0] * 1024
(代表禁用)
- DashScope → 调用
- 如果调用失败,捕获异常并返回零向量。
- 调用
-
输入
text
(str):输入文本max_length
(int):最大允许长度(默认 1000)
-
输出
- 向量 embedding (list[float])
get_embedding_config_status(self)
def get_embedding_config_status(self):"""获取向量缓存配置状态"""return {'enabled': self.enable_embedding_length_check,'max_embedding_length': self.max_embedding_length,'max_embedding_length_formatted': f"{self.max_embedding_length:,}字符",'provider': self.llm_provider,'client_status': 'DISABLED' if self.client == "DISABLED" else 'ENABLED'}
- 作用 返回当前 embedding 的配置信息,主要用于调试/检查状态。
get_last_text_info(self)
def get_last_text_info(self):"""获取最后处理的文本信息"""return getattr(self, '_last_text_info', None)
- 作用:返回最近一次
_smart_text_truncation
的信息(调试用)。
add_situations(self, situations)
def add_situations(self, situations_and_advice):"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""situations = []advice = []ids = []embeddings = []offset = self.situation_collection.count()for i, (situation, recommendation) in enumerate(situations_and_advice):situations.append(situation)advice.append(recommendation)ids.append(str(offset + i))embeddings.append(self.get_embedding(situation))self.situation_collection.add(documents=situations,metadatas=[{"recommendation": rec} for rec in advice],embeddings=embeddings,ids=ids,)
- 作用:把一组新的“财务情况 + 建议”存入数据库。
主要逻辑
-
遍历
situations
,每个元素是(situation, recommendation)
。 -
对
situation
文本生成 embedding。 -
构造数据项:
{"id": str(uuid.uuid4()),"embedding": 向量,"metadata": {"recommendation": 建议} }
-
批量写入
self.collection
(ChromaDB)。
get_cache_info(self)
def get_cache_info(self):"""获取缓存相关信息,用于调试和监控"""info = {'collection_count': self.situation_collection.count(),'client_status': 'enabled' if self.client != "DISABLED" else 'disabled','embedding_model': self.embedding,'provider': self.llm_provider}# 添加最后一次文本处理信息if hasattr(self, '_last_text_info'):info['last_text_processing'] = self._last_text_inforeturn info
- 作用
返回当前 ChromaDB 集合的一些元信息(比如名称、模型、provider 等),帮助确认系统运行状态。
get_memories(self, situation, n_results=5)
def get_memories(self, current_situation, n_matches=1):"""Find matching recommendations using embeddings with smart truncation handling"""# 获取当前情况的embeddingquery_embedding = self.get_embedding(current_situation)# 检查是否为空向量(记忆功能被禁用或出错)if all(x == 0.0 for x in query_embedding):logger.debug(f"⚠️ 查询embedding为空向量,返回空结果")return []# 检查是否有足够的数据进行查询collection_count = self.situation_collection.count()if collection_count == 0:logger.debug(f"📭 记忆库为空,返回空结果")return []# 调整查询数量,不能超过集合中的文档数量actual_n_matches = min(n_matches, collection_count)try:# 执行相似度查询results = self.situation_collection.query(query_embeddings=[query_embedding],n_results=actual_n_matches)# 处理查询结果memories = []if results and 'documents' in results and results['documents']:documents = results['documents'][0]metadatas = results.get('metadatas', [[]])[0]distances = results.get('distances', [[]])[0]for i, doc in enumerate(documents):metadata = metadatas[i] if i < len(metadatas) else {}distance = distances[i] if i < len(distances) else 1.0memory_item = {'situation': doc,'recommendation': metadata.get('recommendation', ''),'similarity': 1.0 - distance, # 转换为相似度分数'distance': distance}memories.append(memory_item)# 记录查询信息if hasattr(self, '_last_text_info') and self._last_text_info.get('was_truncated'):logger.info(f"🔍 截断文本查询完成,找到{len(memories)}个相关记忆")logger.debug(f"📊 原文长度: {self._last_text_info['original_length']}, "f"处理后长度: {self._last_text_info['processed_length']}")else:logger.debug(f"🔍 记忆查询完成,找到{len(memories)}个相关记忆")return memoriesexcept Exception as e:logger.error(f"❌ 记忆查询失败: {str(e)}")return []
-
作用:检索最相似的历史情况,返回对应的建议。
-
主要逻辑
- 对输入
situation
生成 embedding。 - 调用
self.collection.query
,检索最相似的n_results
条记录。 - 解析返回结果:提取文本、embedding 相似度/距离、推荐建议。
- 返回一个结果列表。
- 对输入