import time
import uuid
from contextlib import asynccontextmanager

from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn

MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"

# Populated once at startup by the lifespan handler below.
tokenizer = None
model = None
model_loaded_at: int = 0
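# Load the tokenizer and weights once at startup and keep them resident for the
# life of the process; device_map="auto" lets transformers place the model on a
# GPU when one is available (this typically requires the accelerate package).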
@asynccontextmanager
async def lifespan(app: FastAPI):
    global tokenizer, model, model_loaded_at
    print(f"Loading {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
    model_loaded_at = int(time.time())
    print("Model ready.")
    yield


app = FastAPI(title="Ollama-compatible LLM API", lifespan=lifespan)
# --- Shared schemas ---

class Message(BaseModel):
    role: str
    content: str


# --- OpenAI-compatible ---
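# Subset of the OpenAI chat-completions request body. "stream" is accepted for
# client compatibility but streaming is not implemented; responses are always
# returned in full.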
class ChatCompletionRequest(BaseModel):
    model: str = MODEL_ID
    messages: list[Message]
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 1.0
    stream: bool = False
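# Minimal OpenAI-style model listing so clients can discover the single local model.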
@app.get("/v1/models") |
|
async def list_models(): |
|
return { |
|
"object": "list", |
|
"data": [{ |
|
"id": MODEL_ID, |
|
"object": "model", |
|
"created": model_loaded_at, |
|
"owned_by": "local", |
|
}], |
|
} |
|
|
|
|
|
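# OpenAI-style chat completions: render the messages with the model's chat
# template, generate, then decode only the tokens produced after the prompt.
# Prompt and completion token counts feed the "usage" block.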
@app.post("/v1/chat/completions") |
|
async def chat_completions(request: ChatCompletionRequest): |
|
messages_dict = [{"role": m.role, "content": m.content} for m in request.messages] |
|
prompt = tokenizer.apply_chat_template(messages_dict, tokenize=False, add_generation_prompt=True) |
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) |
|
prompt_tokens = inputs.input_ids.shape[1] |
|
|
|
outputs = model.generate( |
|
**inputs, |
|
max_new_tokens=request.max_tokens, |
|
temperature=request.temperature, |
|
top_p=request.top_p, |
|
do_sample=request.temperature > 0, |
|
pad_token_id=tokenizer.eos_token_id, |
|
) |
|
|
|
generated_tokens = outputs[0][prompt_tokens:] |
|
response_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) |
|
completion_tokens = len(generated_tokens) |
|
|
|
return { |
|
"id": f"chatcmpl-{uuid.uuid4().hex}", |
|
"object": "chat.completion", |
|
"created": int(time.time()), |
|
"model": request.model, |
|
"choices": [{ |
|
"index": 0, |
|
"message": {"role": "assistant", "content": response_text}, |
|
"finish_reason": "stop", |
|
}], |
|
"usage": { |
|
"prompt_tokens": prompt_tokens, |
|
"completion_tokens": completion_tokens, |
|
"total_tokens": prompt_tokens + completion_tokens, |
|
}, |
|
} |
|
|
|
|
|
# --- Ollama-compatible ---
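# Ollama puts sampling parameters in a free-form "options" dict rather than in
# top-level request fields.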
class OllamaGenerateRequest(BaseModel):
    model: str = MODEL_ID
    prompt: str
    stream: bool = False
    options: dict = Field(default_factory=dict)


class OllamaChatRequest(BaseModel):
    model: str = MODEL_ID
    messages: list[Message]
    stream: bool = False
    options: dict = Field(default_factory=dict)
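# Map Ollama option names (num_predict, temperature, top_p) onto the keyword
# arguments expected by transformers' generate(). Sampling is disabled when
# temperature is 0 so output is greedy and deterministic.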
def _gen_params(options: dict) -> dict:
    temp = options.get("temperature", 0.7)
    return {
        "max_new_tokens": options.get("num_predict", 512),
        "temperature": temp,
        "top_p": options.get("top_p", 1.0),
        "do_sample": temp > 0,
        "pad_token_id": tokenizer.eos_token_id,
    }
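# Static model metadata echoed by the Ollama listing endpoints below.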
def _model_details() -> dict:
    return {
        "format": "transformers",
        "family": "qwen2",
        "parameter_size": "0.5B",
        "quantization_level": "none",
    }
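# Ollama's model listing: clients call GET /api/tags to see which models are
# available locally.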
@app.get("/api/tags") |
|
async def api_tags(): |
|
return { |
|
"models": [{ |
|
"name": MODEL_ID, |
|
"model": MODEL_ID, |
|
"modified_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(model_loaded_at)), |
|
"size": 0, |
|
"details": _model_details(), |
|
}], |
|
} |
|
|
|
|
|
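# Ollama's "running models" listing. The model here is always resident, so a
# far-future expires_at is reported.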
@app.get("/api/ps") |
|
async def api_ps(): |
|
return { |
|
"models": [{ |
|
"name": MODEL_ID, |
|
"model": MODEL_ID, |
|
"size": 0, |
|
"digest": "", |
|
"details": _model_details(), |
|
"expires_at": "9999-12-31T23:59:59Z", |
|
"size_vram": 0, |
|
}], |
|
} |
|
|
|
|
|
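# Ollama-style raw completion: the prompt is tokenized verbatim, with no chat
# template applied.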
@app.post("/api/generate") |
|
async def api_generate(request: OllamaGenerateRequest): |
|
inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device) |
|
prompt_tokens = inputs.input_ids.shape[1] |
|
|
|
outputs = model.generate(**inputs, **_gen_params(request.options)) |
|
generated_tokens = outputs[0][prompt_tokens:] |
|
response_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) |
|
|
|
return { |
|
"model": request.model, |
|
"created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), |
|
"response": response_text, |
|
"done": True, |
|
"done_reason": "stop", |
|
"prompt_eval_count": prompt_tokens, |
|
"eval_count": len(generated_tokens), |
|
} |
|
|
|
|
|
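# Ollama-style chat: same generation path as /v1/chat/completions, but the
# response uses Ollama's message/done shape.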
@app.post("/api/chat") |
|
async def api_chat(request: OllamaChatRequest): |
|
messages_dict = [{"role": m.role, "content": m.content} for m in request.messages] |
|
prompt = tokenizer.apply_chat_template(messages_dict, tokenize=False, add_generation_prompt=True) |
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) |
|
prompt_tokens = inputs.input_ids.shape[1] |
|
|
|
outputs = model.generate(**inputs, **_gen_params(request.options)) |
|
generated_tokens = outputs[0][prompt_tokens:] |
|
response_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) |
|
|
|
return { |
|
"model": request.model, |
|
"created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), |
|
"message": {"role": "assistant", "content": response_text}, |
|
"done": True, |
|
"done_reason": "stop", |
|
"prompt_eval_count": prompt_tokens, |
|
"eval_count": len(generated_tokens), |
|
} |
|
|
|
|
|
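# Serve on 11434, Ollama's default port, so existing Ollama clients can point
# at this server without reconfiguration.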
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=11434)