Add compare to cli

This commit is contained in:
2026-03-22 11:26:22 +01:00
parent 04987555e5
commit 765e90be11
2 changed files with 19 additions and 9 deletions

View File

@@ -5,6 +5,7 @@ from pydantic import ValidationError
from beaky import _ansi from beaky import _ansi
from beaky.config import Config from beaky.config import Config
from beaky.image_classifier.classifier import img_classify
from beaky.link_classifier.classifier import LinkClassifier from beaky.link_classifier.classifier import LinkClassifier
from beaky.resolvers.resolver import TicketResolver, TicketVerdict from beaky.resolvers.resolver import TicketResolver, TicketVerdict
from beaky.scanner.scanner import Links from beaky.scanner.scanner import Links
@@ -36,7 +37,7 @@ def main() -> None:
parser = argparse.ArgumentParser(prog="beaky") parser = argparse.ArgumentParser(prog="beaky")
parser.add_argument("--config", help="Path to config file.", default="config/application.yml") parser.add_argument("--config", help="Path to config file.", default="config/application.yml")
parser.add_argument("--id", type=int, help="Resolve a single ticket by id (only used with resolve mode).") parser.add_argument("--id", type=int, help="Resolve a single ticket by id (only used with resolve mode).")
parser.add_argument("mode", choices=["screenshotter", "parser", "class", "resolve"], help="Mode of operation.") parser.add_argument("mode", choices=["screenshotter", "parser", "class", "resolve", "compare"], help="Mode of operation.")
args = parser.parse_args() args = parser.parse_args()
config = load_config(args.config) config = load_config(args.config)
@@ -72,6 +73,19 @@ def main() -> None:
for k, v in vars(bet).items(): for k, v in vars(bet).items():
print(f" {k}: {v}") print(f" {k}: {v}")
if args.mode == "compare":
linkclassifier = LinkClassifier()
links = [l for l in data.links if l.id == args.id] if args.id is not None else data.links
if args.id is not None and not links:
print(f"ERROR: ticket id {args.id} not found")
return
for link in links:
linkClass = linkclassifier.classify(link)
imgClass = img_classify(["./data/screenshots/{link.id}.png"], ticket_id=link.id)
print(linkClass)
print(imgClass)
if args.mode == "resolve": if args.mode == "resolve":
classifier = LinkClassifier() classifier = LinkClassifier()
resolver = TicketResolver(config.resolver) resolver = TicketResolver(config.resolver)

View File

@@ -132,20 +132,16 @@ def classify(text: str) -> Bet:
return UnknownTicket(ticketType=BetType.UNKNOWN, raw_text=text, **base_args) return UnknownTicket(ticketType=BetType.UNKNOWN, raw_text=text, **base_args)
def img_classify(path: str, ticket_id: int) -> Ticket: def img_classify(paths: list[str], ticket_id: int) -> Ticket:
"""Given a path to an image and a date, return a list of Tickets that are """Given a path to an image and a date, return a list of Tickets that are
relevant to that image and date.""" relevant to that image and date."""
# Define valid image extensions to ignore system files or text documents # Define valid image extensions to ignore system files or text documents
ticket = Ticket(id=ticket_id, bets=[]) ticket = Ticket(id=ticket_id, bets=[])
valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"} valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"}
path_obj = Path(path)
if not path_obj.is_dir():
print(f"Error: The path '{path}' is not a valid directory.")
return ticket
# Iterate through all files in the folder # Iterate through all files in the folder
for file_path in path_obj.iterdir(): for file in paths:
file_path = Path(file)
if file_path.is_file() and file_path.suffix.lower() in valid_extensions: if file_path.is_file() and file_path.suffix.lower() in valid_extensions:
# 1. Extract the text (called separately) # 1. Extract the text (called separately)
extracted_text = img_to_text(str(file_path)) extracted_text = img_to_text(str(file_path))
@@ -195,4 +191,4 @@ def img_classify(path: str, ticket_id: int) -> Ticket:
if __name__ == "__main__": if __name__ == "__main__":
img_classify("./data/screenshots/", ticket_id=1) img_classify(["./data/screenshots/2.png"], ticket_id=1)