74 lines
2.2 KiB
Python
74 lines
2.2 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import List, Tuple
|
|
import math
|
|
|
|
|
|
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Return the cosine similarity of two vectors.

    Degenerate inputs short-circuit to ``0.0`` instead of being compared:
    either vector empty, the two vectors of different lengths, or either
    vector having zero magnitude.
    """
    if not (vec_a and vec_b and len(vec_a) == len(vec_b)):
        return 0.0

    dot_product = sum(x * y for x, y in zip(vec_a, vec_b))
    norm_a = math.sqrt(sum(x * x for x in vec_a))
    norm_b = math.sqrt(sum(y * y for y in vec_b))

    # A zero-magnitude vector has no direction, so similarity is defined as 0.
    if 0.0 in (norm_a, norm_b):
        return 0.0
    return dot_product / (norm_a * norm_b)
|
|
|
|
|
|
@dataclass()
class IndexRow:
    """A single record of the JSONL-backed index.

    Fields are declared in the order ``id, text, vector, source``; that is
    also the positional order accepted by the generated ``__init__``.
    """

    id: str  # record identifier, stored under the "id" key
    text: str  # text kept alongside the vector, stored under "text"
    vector: List[float]  # numeric vector scored by cosine_similarity
    source: str  # origin label; JsonlIndex loads "" when the key is absent
|
|
|
|
|
|
class JsonlIndex:
    """In-memory vector index backed by a JSON Lines file.

    Each non-blank line of the file holds one JSON object with the keys
    ``id``, ``text``, ``vector`` and, optionally, ``source``. Rows are loaded
    on construction into ``self.rows``; ``append`` writes new rows to the end
    of the file and to ``self.rows``; ``search`` ranks ``self.rows`` by cosine
    similarity to a query vector.
    """

    def __init__(self, path: str) -> None:
        """Open the index at *path* and load any rows it already contains.

        A missing file is not an error: the index starts empty, and the file
        (plus its parent directories) is created by the first ``append``.
        """
        self.path = Path(path)
        self.rows: List[IndexRow] = []
        self._load()

    @staticmethod
    def _row_from_obj(obj: dict) -> IndexRow:
        """Build an ``IndexRow`` from one decoded JSON line."""
        return IndexRow(
            id=obj["id"],
            text=obj["text"],
            vector=obj["vector"],
            source=obj.get("source", ""),
        )

    @staticmethod
    def _row_to_obj(row: IndexRow) -> dict:
        """Return the dict serialized for *row*; inverse of ``_row_from_obj``."""
        return {"id": row.id, "text": row.text, "vector": row.vector, "source": row.source}

    def _load(self) -> None:
        """Replace the contents of ``self.rows`` with the rows stored in the file.

        Blank lines are skipped. Errors from ``json.loads`` (malformed line)
        or from a record missing ``id``/``text``/``vector`` propagate to the
        caller; in that case ``self.rows`` keeps its previous contents.
        """
        loaded: List[IndexRow] = []
        if self.path.exists():
            with self.path.open("r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    loaded.append(self._row_from_obj(json.loads(line)))
        # Swap in the new contents only after the whole file parsed, so a bad
        # line cannot leave a half-populated index; slice assignment keeps the
        # same list object for anyone already holding ``self.rows``.
        self.rows[:] = loaded

    def search(self, query_vec: List[float], top_k: int = 5) -> List[Tuple[IndexRow, float]]:
        """Return up to *top_k* ``(row, score)`` pairs, highest cosine score first.

        Rows with equal scores keep their relative order from ``self.rows``
        (the sort is stable). A non-positive *top_k* returns an empty list.
        """
        if top_k <= 0:
            # Without this guard a negative bound would slice off the tail and
            # return "all but the last |top_k|" rows.
            return []
        scored = [(row, cosine_similarity(query_vec, row.vector)) for row in self.rows]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:top_k]

    def _missing_trailing_newline(self) -> bool:
        """Return True if the file exists, is non-empty and does not end in a newline."""
        try:
            if self.path.stat().st_size == 0:
                return False
        except FileNotFoundError:
            return False
        with self.path.open("rb") as fb:
            fb.seek(-1, 2)  # whence=2: offset relative to the end of the file
            return fb.read(1) != b"\n"

    def append(self, new_rows: List[IndexRow]) -> int:
        """Write *new_rows* to the end of the file and add them to ``self.rows``.

        Any iterable of rows is accepted. Returns the number of rows appended;
        for an empty input nothing is written and 0 is returned.
        """
        # Materialize first: a generator would otherwise be exhausted by the
        # write and leave nothing for ``extend``/``len``.
        rows = list(new_rows)
        if not rows:
            return 0
        # Serialize the whole batch before touching the file, so a row that
        # json cannot encode raises without leaving a partial batch on disk.
        payload = "".join(
            json.dumps(self._row_to_obj(r), ensure_ascii=False) + "\n" for r in rows
        )
        self.path.parent.mkdir(parents=True, exist_ok=True)
        if self._missing_trailing_newline():
            # Otherwise the first new record would be glued onto the last line.
            payload = "\n" + payload
        with self.path.open("a", encoding="utf-8") as f:
            f.write(payload)
        self.rows.extend(rows)
        return len(rows)

    def reload(self) -> int:
        """Re-read the file from disk and return the resulting row count."""
        self._load()
        return len(self.rows)
|
|
|