Smolagents视频代码

微信：adoresever

使用 Gemini 进行 Text-to-SQL 查询

from smolagents import CodeAgent
from smolagents import tool, LiteLLMModel
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, Float, insert, text

# In-memory SQLite database queried by the `sql_engine` tool.
# NOTE(review): every new DBAPI connection to "sqlite:///:memory:" opens its own empty
# database, so this demo relies on SQLAlchemy's pool handing back the same connection
# (same thread) — confirm if tools are ever executed from another thread.
engine = create_engine("sqlite:///:memory:")
metadata = MetaData()

# Product catalogue: id, name, category, price and stock per product.
products = Table(
    "products",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
    Column("category", String(20)),
    Column("price", Float),
    Column("stock", Integer),
)

# Sales records. product_id holds a products.id value (no ForeignKey is declared);
# sale_date is stored as a 'YYYY-MM-DD' string, as in the sample rows below.
sales = Table(
    "sales",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("product_id", Integer),
    Column("quantity", Integer),
    Column("sale_date", String(10)),
)

metadata.create_all(engine)

# Sample data
product_data = [
    {"id": 1, "name": "游戏本", "category": "电脑", "price": 6999.0, "stock": 100},
    {"id": 2, "name": "机械键盘", "category": "配件", "price": 299.0, "stock": 50},
    {"id": 3, "name": "游戏手柄", "category": "配件", "price": 199.0, "stock": 30},
    {"id": 4, "name": "办公本", "category": "电脑", "price": 4999.0, "stock": 80},
]

sales_data = [
    {"id": 1, "product_id": 1, "quantity": 2, "sale_date": "2024-01-01"},
    {"id": 2, "product_id": 2, "quantity": 5, "sale_date": "2024-01-02"},
    {"id": 3, "product_id": 1, "quantity": 1, "sale_date": "2024-01-03"},
    {"id": 4, "product_id": 4, "quantity": 3, "sale_date": "2024-01-03"},
]

# Seed both tables inside one committed transaction. Passing the list of parameter
# dicts lets SQLAlchemy issue one executemany per table instead of a Python-level
# loop with one INSERT call per row.
with engine.begin() as conn:
    conn.execute(insert(products), product_data)
    conn.execute(insert(sales), sales_data)

@tool
def sql_engine(query: str) -> str:
    """Run a SQL query against the in-memory SQLite database and return the result as text.

    The database contains the following tables:

    Table 'products':
      - id: INTEGER (primary key)
      - name: VARCHAR(50)
      - category: VARCHAR(20)
      - price: FLOAT
      - stock: INTEGER

    Table 'sales':
      - id: INTEGER (primary key)
      - product_id: INTEGER (holds a products.id value)
      - quantity: INTEGER
      - sale_date: VARCHAR(10), formatted as 'YYYY-MM-DD'

    Args:
        query: The SQL statement to execute, e.g. 'SELECT name, stock FROM products'.

    Returns:
        str: A header line of column names, a dashed separator and one ' | '-joined
        line per result row; a notice if the query returned no rows; or an error
        message if execution failed.
    """
    # NOTE: smolagents builds the tool description the LLM sees from this docstring,
    # which is why the table schema is spelled out above.
    try:
        with engine.connect() as conn:
            # NOTE(review): the connection is never committed; under SQLAlchemy 2.x
            # semantics any data-modifying statement is rolled back on close — confirm
            # for the installed version. Statements that return no rows make
            # keys()/fetchall() raise, which is reported via the except branch below.
            result = conn.execute(text(query))
            columns = [str(col) for col in result.keys()]
            rows = result.fetchall()

            if not rows:
                return "查询没有返回任何结果"

            header = " | ".join(columns)
            # The dashed rule is exactly as wide as the header line.
            lines = [header, "-" * len(header)]
            lines.extend(" | ".join(str(val) for val in row) for row in rows)
            return "\n".join(lines)
    except Exception as e:
        # Deliberately broad: the error text is returned to the agent so it can
        # inspect the failure and retry with a corrected query.
        return f"SQL执行错误: {e}"

# Gemini model accessed through LiteLLM using a provider-prefixed model id.
# Local alternative used by the Ollama example in this file:
#   LiteLLMModel(model_id="ollama/qwen2.5:14b", api_base="http://localhost:11434")
model = LiteLLMModel(model_id="gemini/gemini-2.0-flash-exp")

# Code-writing agent whose only tool is the `sql_engine` SQL executor.
# NOTE(review): recent smolagents releases configure logging via `verbosity_level`;
# confirm the installed version still accepts `verbose`.
agent = CodeAgent(
    model=model,
    tools=[sql_engine],
    verbose=True,
)

# Natural-language question: "find the three products with the highest stock".
test_query = "请查找库存量最多的三种商品"
print("执行查询:", test_query)
result = agent.run(test_query)
print("查询结果:", result, sep="\n")

使用 Ollama 本地模型调用 DuckDuckGo 进行网络搜索

from smolagents import CodeAgent, DuckDuckGoSearchTool  
from smolagents import tool, TransformersModel, LiteLLMModel
from typing import Optional

# Qwen 2.5 (14B) served by a local Ollama instance, reached through LiteLLM.
model = LiteLLMModel(
    api_base="http://localhost:11434",  # local Ollama HTTP endpoint
    model_id="ollama/qwen2.5:14b",  # Ollama-format model id ("ollama/" prefix)
)

# Code-writing agent equipped with DuckDuckGo web search.
# NOTE(review): recent smolagents releases configure logging via `verbosity_level`;
# confirm the installed version still accepts `verbose`.
agent = CodeAgent(
    model=model,
    tools=[DuckDuckGoSearchTool()],
    verbose=True,
)

# Ask a web-search question ("China's sixth-generation fighter jet") and print the answer.
answer = agent.run("中国第六代战机")
print(answer)

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注