import json, math, argparse
from pathlib import Path
import cv2
import numpy as np

def get_contours_from_image(image_path, element_type="unknown"):
    """Extract external contours from a binary mask image.

    Loads *image_path* as grayscale, thresholds it at 127 to produce a
    clean 0/255 mask, and returns each disconnected outline as its own
    entry labeled with *element_type*.

    Returns:
        dict: ``{"custom_contours": [{"class": str, "points": [...]}]}``
        where each point is ``{"x": int, "y": int}`` in mask pixel
        coordinates.

    Raises:
        ValueError: If the image cannot be read.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Cannot load image {image_path}")

    # Binarize so findContours sees an unambiguous foreground/background.
    _, mask = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)

    # RETR_EXTERNAL: outer outlines only; CHAIN_APPROX_SIMPLE compresses
    # straight segments to their endpoints.
    found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Keep only contours with enough vertices to form a polygon (>= 3);
    # each disconnected contour becomes a separate object.
    contours = [
        {
            "class": element_type,
            "points": [{"x": int(pt[0][0]), "y": int(pt[0][1])} for pt in c],
        }
        for c in found
        if len(c) > 2
    ]

    return {"custom_contours": contours}

def min_pt_dist(pts1, pts2, early_break=None):
    """Return the smallest Euclidean distance between any point of
    *pts1* and any point of *pts2*.

    Points are dicts with ``"x"``/``"y"`` keys. If *early_break* is
    given, the scan returns as soon as a distance at or below that
    bound is found — an early-exit shortcut for callers that only need
    to know whether the sets come within a cutoff, so the value
    returned may not be the true minimum.

    Returns ``float('inf')`` when either point list is empty.
    """
    smallest = float("inf")
    for a in pts1:
        for b in pts2:
            dist = math.hypot(a["x"] - b["x"], a["y"] - b["y"])
            if dist >= smallest:
                continue
            smallest = dist
            # Short-circuit once we are at or under the caller's bound.
            if early_break is not None and smallest <= early_break:
                return smallest
    return smallest

def merge_with_scaled_mask(orig_img_path, mask_path, pred_json, output_json, threshold=50, element_type="deck"):
    """Merge contours extracted from a mask image into a prediction JSON.

    The mask may be at a different resolution than the original image,
    so every contour point is rescaled from mask coordinates to
    original-image coordinates before being appended to the JSON's
    ``predictions`` list.

    Args:
        orig_img_path: Path to the original (full-resolution) image.
        mask_path: Path to the mask image (read as grayscale).
        pred_json: Path to the existing prediction JSON; must contain a
            top-level ``"predictions"`` list.
        output_json: Path where the merged JSON is written.
        threshold: Distance threshold for proximity-based merging.
            NOTE(review): merging a contour into the nearest existing
            prediction was disabled (the code was commented out), so
            every contour is appended as a new prediction and this
            value currently has no effect. Kept for interface
            compatibility.
        element_type: Class label assigned to each extracted contour.

    Raises:
        ValueError: If either image cannot be loaded.
    """
    orig = cv2.imread(orig_img_path)
    if orig is None:
        raise ValueError(f"Cannot load image {orig_img_path}")
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise ValueError(f"Cannot load image {mask_path}")

    oh, ow = orig.shape[:2]
    mh, mw = mask.shape
    # Scale factors mapping mask-space coordinates to original-image space.
    sx, sy = ow / mw, oh / mh

    preds = json.loads(Path(pred_json).read_text())
    contours = get_contours_from_image(mask_path, element_type)

    for contour in contours["custom_contours"]:
        # Rescale each point to the original image size, rounding to the
        # nearest pixel (+0.5 before int-truncation).
        scaled_pts = [
            {"x": int(pt["x"] * sx + 0.5), "y": int(pt["y"] * sy + 0.5)}
            for pt in contour["points"]
        ]
        # Append every contour as a brand-new prediction of its class.
        # (The former min_pt_dist/threshold nearest-prediction search was
        # removed: its result was never used, so it was pure dead work.)
        preds["predictions"].append({
            "class": contour["class"],
            "points": scaled_pts,
        })

    Path(output_json).write_text(json.dumps(preds, indent=2))
    print(f"Merged JSON written to {output_json}")


if __name__ == "__main__":
    # CLI entry point: merge scaled mask contours into a prediction JSON.
    cli = argparse.ArgumentParser()
    cli.add_argument("--orig", required=True, help="Path to original image")
    cli.add_argument("--mask", required=True, help="Path to mask image")
    cli.add_argument("--pred", required=True, help="Path to prediction JSON")
    cli.add_argument("--out", required=True, help="Output merged JSON")
    cli.add_argument("--threshold", type=int, default=50, help="Distance threshold")
    cli.add_argument("--type", default="deck", help="Element type")
    opts = cli.parse_args()

    merge_with_scaled_mask(
        orig_img_path=opts.orig,
        mask_path=opts.mask,
        pred_json=opts.pred,
        output_json=opts.out,
        threshold=opts.threshold,
        element_type=opts.type,
    )
