Add compare to cli
This commit is contained in:
@@ -5,6 +5,7 @@ from pydantic import ValidationError
|
||||
|
||||
from beaky import _ansi
|
||||
from beaky.config import Config
|
||||
from beaky.image_classifier.classifier import img_classify
|
||||
from beaky.link_classifier.classifier import LinkClassifier
|
||||
from beaky.resolvers.resolver import TicketResolver, TicketVerdict
|
||||
from beaky.scanner.scanner import Links
|
||||
@@ -36,7 +37,7 @@ def main() -> None:
|
||||
parser = argparse.ArgumentParser(prog="beaky")
|
||||
parser.add_argument("--config", help="Path to config file.", default="config/application.yml")
|
||||
parser.add_argument("--id", type=int, help="Resolve a single ticket by id (only used with resolve mode).")
|
||||
parser.add_argument("mode", choices=["screenshotter", "parser", "class", "resolve"], help="Mode of operation.")
|
||||
parser.add_argument("mode", choices=["screenshotter", "parser", "class", "resolve", "compare"], help="Mode of operation.")
|
||||
|
||||
args = parser.parse_args()
|
||||
config = load_config(args.config)
|
||||
@@ -72,6 +73,19 @@ def main() -> None:
|
||||
for k, v in vars(bet).items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
if args.mode == "compare":
|
||||
linkclassifier = LinkClassifier()
|
||||
links = [l for l in data.links if l.id == args.id] if args.id is not None else data.links
|
||||
if args.id is not None and not links:
|
||||
print(f"ERROR: ticket id {args.id} not found")
|
||||
return
|
||||
for link in links:
|
||||
linkClass = linkclassifier.classify(link)
|
||||
imgClass = img_classify(["./data/screenshots/{link.id}.png"], ticket_id=link.id)
|
||||
|
||||
print(linkClass)
|
||||
print(imgClass)
|
||||
|
||||
if args.mode == "resolve":
|
||||
classifier = LinkClassifier()
|
||||
resolver = TicketResolver(config.resolver)
|
||||
|
||||
@@ -132,20 +132,16 @@ def classify(text: str) -> Bet:
|
||||
return UnknownTicket(ticketType=BetType.UNKNOWN, raw_text=text, **base_args)
|
||||
|
||||
|
||||
def img_classify(path: str, ticket_id: int) -> Ticket:
|
||||
def img_classify(paths: list[str], ticket_id: int) -> Ticket:
|
||||
"""Given a path to an image and a date, return a list of Tickets that are
|
||||
relevant to that image and date."""
|
||||
# Define valid image extensions to ignore system files or text documents
|
||||
ticket = Ticket(id=ticket_id, bets=[])
|
||||
valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"}
|
||||
path_obj = Path(path)
|
||||
|
||||
if not path_obj.is_dir():
|
||||
print(f"Error: The path '{path}' is not a valid directory.")
|
||||
return ticket
|
||||
|
||||
# Iterate through all files in the folder
|
||||
for file_path in path_obj.iterdir():
|
||||
for file in paths:
|
||||
file_path = Path(file)
|
||||
if file_path.is_file() and file_path.suffix.lower() in valid_extensions:
|
||||
# 1. Extract the text (called separately)
|
||||
extracted_text = img_to_text(str(file_path))
|
||||
@@ -195,4 +191,4 @@ def img_classify(path: str, ticket_id: int) -> Ticket:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
img_classify("./data/screenshots/", ticket_id=1)
|
||||
img_classify(["./data/screenshots/2.png"], ticket_id=1)
|
||||
|
||||
Reference in New Issue
Block a user