#!/usr/bin/env python3
"""
TMED-II Phase III — Motor Line Marking Layout Extraction
Extracts machine footprints, column grid, and process station labels from DXF.
Generates: machine_positions.tsv, column_grid.tsv, unmatched_machines.txt, simplified DXF.

Usage:
    python3 extract_layout.py [--datum COL_LABEL] [--dxf PATH]
    python3 extract_layout.py --inspect                          # layer inspection only
    python3 extract_layout.py --machine-layers 0,4 --column-layers stlcol
"""

import argparse
import json
import math
import sys
from collections import Counter
from pathlib import Path

import ezdxf

# Directory of this script; the default input paths below are resolved from it.
SCRIPT_DIR = Path(__file__).parent
# Default source drawing (the usage text shows it can be overridden with --dxf).
DEFAULT_DXF = SCRIPT_DIR / "V38_motor_line_marking_layout.dxf"
# Cross-reference JSON, expected in the directory above this script
# (presumably the file passed to load_xref_data — confirm in main()).
XREF_JSON = SCRIPT_DIR.parent / "cross_reference_output.json"

# --- Layout constants discovered from DXF analysis ---
# The building XREF block A$Cf0d7dff8 is nested inside A$301694157596
# which is INSERT'd at (108501.3, 54881.3) with sub-INSERT at (0,0).
# Because the sub-INSERT sits at the origin, the parent's insertion point is
# used directly as the world offset of building geometry (find_xref_offset).
BUILDING_XREF_BLOCK = 'A$Cf0d7dff8'
BUILDING_XREF_PARENT = 'A$301694157596'

# Motor line zone (world coordinates, mm). Bounds are inclusive; machines,
# labels and grid lines are filtered against this rectangle.
MOTOR_LINE_X_MIN = 213000
MOTOR_LINE_X_MAX = 315000
MOTOR_LINE_Y_MIN = 110000
MOTOR_LINE_Y_MAX = 130000

# Machine footprint size range (mm): extract_machines keeps an INSERT only if
# both footprint sides are >= MIN_MACHINE_DIM and the longer side is
# <= MAX_MACHINE_DIM.
MIN_MACHINE_DIM = 500
MAX_MACHINE_DIM = 20000

# Block name prefixes to exclude from machine detection (str.startswith match)
EXCLUDE_BLOCK_PREFIXES = ['XREF-PTM_POWERTECH']

# Entity types that are annotation, never geometry — always excluded
ANNOTATION_ENTITY_TYPES = {'DIMENSION', 'LEADER', 'MLEADER', 'HATCH', 'VIEWPORT', 'ATTDEF'}

# Default machine layers, used by extract_machines when no list is passed
# (layer 0 contains the equipment INSERT blocks).
DEFAULT_MACHINE_LAYERS = ['0', '2', '4']
# Default label layer.
# NOTE(review): extract_labels does not filter by layer; confirm where (if
# anywhere) this default is applied.
DEFAULT_LABEL_LAYERS = ['TEXT']


def find_xref_offset(doc, msp):
    """Return the (x, y) world offset of the building XREF geometry.

    Preference order:
      1. insertion point of the first INSERT of BUILDING_XREF_PARENT;
      2. otherwise the first direct INSERT of BUILDING_XREF_BLOCK;
      3. otherwise (0, 0).

    ``doc`` is not used.
    """
    fallback_point = None
    for entity in msp:
        if entity.dxftype() != 'INSERT':
            continue
        block = entity.dxf.name
        if block == BUILDING_XREF_PARENT:
            # The parent wrapper always wins, wherever it appears.
            anchor = entity.dxf.insert
            return anchor.x, anchor.y
        if fallback_point is None and block == BUILDING_XREF_BLOCK:
            fallback_point = entity.dxf.insert
    if fallback_point is not None:
        return fallback_point.x, fallback_point.y
    return 0, 0


def _cluster_positions(values, tol=500):
    """Merge nearly-equal coordinates and return one mean value per group.

    Values are sorted; a value joins the current group when it is less than
    *tol* from the group's previous (largest) member, so a chain of close
    values forms one group. Returns the group means in ascending order.
    """
    groups = []
    for v in sorted(values):
        if groups and v - groups[-1][-1] < tol:
            groups[-1].append(v)
        else:
            groups.append([v])
    return [sum(g) / len(g) for g in groups]


def extract_grid_from_xref(doc, offset_x, offset_y):
    """Extract grid lines and column positions from the nested building XREF block.

    Reads the block definition named BUILDING_XREF_BLOCK and translates its
    entities by (offset_x, offset_y) into world coordinates. Rotation and
    scale of the enclosing INSERTs are not applied.

    * LINEs on a layer whose name contains ``00-grid`` (case-insensitive) and
      at least 5000 mm long are grid lines. Nearly vertical ones (|dx| < 100)
      contribute an X position; nearly horizontal ones (|dy| < 100) contribute
      a Y position. Positions closer than 500 mm are merged to their mean.
    * INSERTs on a layer whose name contains ``stlcol`` (case-insensitive)
      are columns; their translated insertion points are returned.

    Returns:
        (grid_x, grid_y, column_positions): ascending merged X positions,
        ascending merged Y positions, and a list of (x, y) column points.
        All three are empty when the block cannot be found.
    """
    try:
        building = doc.blocks.get(BUILDING_XREF_BLOCK)
    except Exception:
        return [], [], []
    if building is None:
        return [], [], []

    vertical_x = []    # X of each vertical grid line
    horizontal_y = []  # Y of each horizontal grid line
    column_positions = []  # (x, y) of each column

    for e in building:
        etype = e.dxftype()
        # DXF layer names are case-insensitive, so match the grid layer the
        # same way as the stlcol test below. An exact-case match silently
        # skipped upper-cased grid layers.
        if etype == 'LINE' and '00-grid' in e.dxf.layer.lower():
            x1 = e.dxf.start.x + offset_x
            y1 = e.dxf.start.y + offset_y
            x2 = e.dxf.end.x + offset_x
            y2 = e.dxf.end.y + offset_y
            if math.hypot(x2 - x1, y2 - y1) < 5000:
                continue  # short segments are tick marks/annotation, not grid lines
            if abs(x2 - x1) < 100:
                vertical_x.append((x1 + x2) / 2)
            elif abs(y2 - y1) < 100:
                horizontal_y.append((y1 + y2) / 2)
        elif etype == 'INSERT' and 'stlcol' in e.dxf.layer.lower():
            column_positions.append((e.dxf.insert.x + offset_x,
                                     e.dxf.insert.y + offset_y))

    return _cluster_positions(vertical_x), _cluster_positions(horizontal_y), column_positions


def filter_grid_to_motor_zone(grid_x, grid_y, margin=5000):
    """Keep only grid lines that pass through or near the motor line zone.

    Args:
        grid_x: X positions of vertical grid lines.
        grid_y: Y positions of horizontal grid lines.
        margin: Distance (mm) by which the zone is widened on every side
            before filtering; bounds are inclusive. Defaults to 5000.

    Returns:
        (filtered_x, filtered_y), preserving the input order.
    """
    x_lo, x_hi = MOTOR_LINE_X_MIN - margin, MOTOR_LINE_X_MAX + margin
    y_lo, y_hi = MOTOR_LINE_Y_MIN - margin, MOTOR_LINE_Y_MAX + margin
    return ([x for x in grid_x if x_lo <= x <= x_hi],
            [y for y in grid_y if y_lo <= y <= y_hi])


# Per-document bbox caches: id(doc) -> (doc, {block_name: bbox}). The document
# is stored next to its cache so its id() cannot be recycled by another
# document while the entry exists. This keeps documents alive, which is
# acceptable for a one-shot CLI run.
_BLOCK_BBOX_CACHES = {}


def get_block_bbox(doc, block_name, cache=None):
    """Compute the axis-aligned bounding box of a block definition (cached).

    Only LINE, LWPOLYLINE, CIRCLE, ARC and SPLINE (control points) entities
    contribute. Nested INSERTs and other entity types are ignored. ARCs are
    bounded by their full circle, so the box may be larger than the drawn arc.
    Entities whose attributes cannot be read are skipped.

    Args:
        doc: DXF document whose ``blocks`` table is queried.
        block_name: Name of the block definition.
        cache: Optional dict mapping block name -> bbox. When omitted, a
            per-document cache is used. (The previous shared mutable default
            was keyed by name only and returned stale boxes for a second
            document defining a block with the same name.)

    Returns:
        ``(min_x, min_y, max_x, max_y)`` in block coordinates, or ``None`` if
        the block is missing or has no supported geometry.
    """
    if cache is None:
        cache = _BLOCK_BBOX_CACHES.setdefault(id(doc), (doc, {}))[1]
    if block_name in cache:
        return cache[block_name]

    try:
        block = doc.blocks.get(block_name)
    except Exception:
        block = None
    if block is None:
        cache[block_name] = None
        return None

    xs, ys = [], []
    for e in block:
        try:
            etype = e.dxftype()
            if etype == 'LINE':
                xs.extend([e.dxf.start.x, e.dxf.end.x])
                ys.extend([e.dxf.start.y, e.dxf.end.y])
            elif etype == 'LWPOLYLINE':
                for pt in e.get_points():
                    xs.append(pt[0])
                    ys.append(pt[1])
            elif etype in ('CIRCLE', 'ARC'):
                cx, cy, r = e.dxf.center.x, e.dxf.center.y, e.dxf.radius
                xs.extend([cx - r, cx + r])
                ys.extend([cy - r, cy + r])
            elif etype == 'SPLINE':
                # Index rather than .x/.y: control points may be numpy arrays.
                for pt in e.control_points:
                    xs.append(float(pt[0]))
                    ys.append(float(pt[1]))
        except Exception:
            continue

    result = (min(xs), min(ys), max(xs), max(ys)) if xs and ys else None
    cache[block_name] = result
    return result


def extract_machines(doc, msp, machine_layers=None):
    """Extract machine footprints from INSERT entities in the motor line zone.

    An INSERT in *msp* is kept when all of the following hold:
      * it is on one of *machine_layers* (default DEFAULT_MACHINE_LAYERS);
      * its block is not anonymous (any ``*``-prefixed name, e.g. ``*D``
        dimension blocks), does not start with an EXCLUDE_BLOCK_PREFIXES
        entry, and is not one of the building wrapper blocks;
      * get_block_bbox returns a bounding box for the block;
      * both scaled footprint sides are >= MIN_MACHINE_DIM and the longer
        side is <= MAX_MACHINE_DIM;
      * the rotated footprint centre lies inside the motor line zone.

    Returns:
        A list of dicts with ``block_name``, ``insert_pos``, ``center`` (world
        x, y), ``width``, ``height``, ``layer``, ``rotation`` and ``label``
        (initialised to None; associate_labels fills it in).
    """
    if machine_layers is None:
        machine_layers = DEFAULT_MACHINE_LAYERS
    machines = []

    for e in msp:
        # Only block references can be machines. This also excludes every
        # annotation entity type (DIMENSION, LEADER, HATCH, ...).
        if e.dxftype() != 'INSERT':
            continue

        block_name = e.dxf.name
        layer = e.dxf.layer
        if layer not in machine_layers:
            continue
        if block_name.startswith('*'):
            continue
        if any(block_name.startswith(pfx) for pfx in EXCLUDE_BLOCK_PREFIXES):
            continue
        if block_name in (BUILDING_XREF_BLOCK, BUILDING_XREF_PARENT):
            continue

        bbox = get_block_bbox(doc, block_name)
        if bbox is None:
            continue

        insert_pt = e.dxf.insert
        rotation = e.dxf.get('rotation', 0)
        xscale = e.dxf.get('xscale', 1.0)
        yscale = e.dxf.get('yscale', 1.0)

        bw = (bbox[2] - bbox[0]) * abs(xscale)
        bh = (bbox[3] - bbox[1]) * abs(yscale)

        # Swap axes for quarter-turn rotations. Rotations are floats and CAD
        # often stores e.g. 89.99999999 or 270.0000001, so compare with a
        # tolerance rather than exact equality.
        if math.isclose(rotation % 180, 90, abs_tol=1e-6):
            bw, bh = bh, bw

        if not (MIN_MACHINE_DIM <= max(bw, bh) <= MAX_MACHINE_DIM and min(bw, bh) >= MIN_MACHINE_DIM):
            continue

        # Rotate/scale the block-local bbox centre and add the insertion point.
        # NOTE(review): assumes the block base point is the origin; a non-zero
        # base point would shift this centre.
        bcx = (bbox[0] + bbox[2]) / 2 * xscale
        bcy = (bbox[1] + bbox[3]) / 2 * yscale
        rad = math.radians(rotation)
        wcx = insert_pt.x + bcx * math.cos(rad) - bcy * math.sin(rad)
        wcy = insert_pt.y + bcx * math.sin(rad) + bcy * math.cos(rad)

        if not (MOTOR_LINE_X_MIN <= wcx <= MOTOR_LINE_X_MAX and
                MOTOR_LINE_Y_MIN <= wcy <= MOTOR_LINE_Y_MAX):
            continue

        machines.append({
            'block_name': block_name,
            'insert_pos': (insert_pt.x, insert_pt.y),
            'center': (wcx, wcy),
            'width': bw,
            'height': bh,
            'layer': layer,
            'rotation': rotation,
            'label': None,
        })

    return machines


def _point_in_motor_zone(x, y):
    """Return True if (x, y) lies inside the motor line zone (bounds inclusive)."""
    return (MOTOR_LINE_X_MIN <= x <= MOTOR_LINE_X_MAX and
            MOTOR_LINE_Y_MIN <= y <= MOTOR_LINE_Y_MAX)


def _mtext_plain_text(entity):
    """Return an MTEXT entity's content without inline formatting, on one line."""
    plain_text = getattr(entity, 'plain_text', None)
    if callable(plain_text):
        # ezdxf's MTEXT parser removes all formatting codes, including nested
        # and mid-string ones, and turns paragraph breaks into newlines.
        return plain_text().replace('\n', ' ').strip()
    # Fallback when plain_text() is unavailable: if the text contains "{\",
    # drop everything up to the first ';' and one trailing '}', then replace
    # \P / \p with spaces.
    text = entity.text
    if '{\\' in text:
        idx = text.find(';')
        if idx >= 0:
            text = text[idx + 1:]
        if text.endswith('}'):
            text = text[:-1]
    return text.replace('\\P', ' ').replace('\\p', ' ').strip()


def extract_labels(msp):
    """Extract MTEXT and TEXT labels whose insertion point is in the motor line zone.

    MTEXT content is converted to plain single-line text; TEXT content is
    only stripped of surrounding whitespace. Labels are not filtered by layer.

    Returns:
        A list of ``{'text', 'pos', 'layer'}`` dicts in modelspace order.
    """
    labels = []
    for e in msp:
        etype = e.dxftype()
        if etype not in ('MTEXT', 'TEXT'):
            continue
        pt = e.dxf.insert
        if not _point_in_motor_zone(pt.x, pt.y):
            continue
        text = _mtext_plain_text(e) if etype == 'MTEXT' else e.dxf.text.strip()
        labels.append({
            'text': text,
            'pos': (pt.x, pt.y),
            'layer': e.dxf.layer,
        })
    return labels


def associate_labels(machines, labels):
    """Set each machine's ``label`` to the text of its nearest label.

    Distance is Euclidean between the machine ``center`` and the label
    ``pos``. Only a label strictly closer than 5000 mm qualifies; on equal
    distances the label that comes first in *labels* wins. A machine with
    no qualifying label gets ``None``. The machine dicts are updated in place.
    """
    radius = 5000

    for machine in machines:
        cx, cy = machine['center']

        def distance_to(label, cx=cx, cy=cy):
            lx, ly = label['pos']
            return math.sqrt((cx - lx) ** 2 + (cy - ly) ** 2)

        # min() returns the first of equally-near labels, matching a scan that
        # only replaces the best on a strictly smaller distance.
        nearest = min(labels, key=distance_to, default=None)
        if nearest is not None and distance_to(nearest) < radius:
            machine['label'] = nearest['text']
        else:
            machine['label'] = None


def load_xref_data(path):
    """Load Motor-line machine records from the cross-reference JSON.

    The JSON's ``"crossref"`` entry is a table: a list of rows whose first
    row is the header. Only rows whose ``Line`` cell contains ``"Motor"``
    are kept. Missing, null or short cells are treated as empty.

    Args:
        path: Path (or str) of the JSON file.

    Returns:
        A dict keyed by the row's ``Item No`` value (column 0 if there is no
        ``Item No`` header). Each value holds ``english``, ``korean``,
        ``dims`` (tuple of floats parsed from ``"LxWxH"``, else None) and
        ``item_no``. Returns ``{}`` if the file is missing or the table has
        no data rows.
    """
    path = Path(path)
    if not path.exists():
        print(f"  WARNING: cross-reference file not found: {path}")
        return {}
    # JSON is UTF-8 and the names are Korean, so do not rely on the platform's
    # default encoding (e.g. cp949/cp1252 on Windows).
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    rows = data.get("crossref") or []
    if len(rows) < 2:
        return {}
    idx = {h: i for i, h in enumerate(rows[0])}

    def cell(row, column):
        """Text of *column* in *row*; '' when the column or cell is absent/null."""
        i = idx.get(column)
        if i is None or i >= len(row) or row[i] is None:
            return ""
        return str(row[i])

    xref = {}
    for row in rows[1:]:
        if "Motor" not in cell(row, "Line"):
            continue
        item_no = row[idx.get("Item No", 0)]
        dims_str = cell(row, "Max Dims LxWxH (mm)")
        dims = None
        if 'x' in dims_str.lower():
            try:
                dims = tuple(float(p) for p in dims_str.lower().split('x'))
            except ValueError:
                pass  # unparseable dimensions -> no dims, as before
        xref[item_no] = {
            "english": cell(row, "English Name"),
            "korean": cell(row, "Korean Name"),
            "dims": dims,
            "item_no": item_no,
        }
    return xref


def match_xref(machines, xref_data):
    """Annotate each machine with its cross-reference status (in place).

    The machine's label is compared with every non-empty ``korean`` name in
    *xref_data*, in insertion order. The first entry where either string
    contains the other is the match. Each machine receives:

    * ``xref_match``: ``'NO_REFERENCE'`` (no label or no match),
      ``'REF_NO_DIMS'`` (matched entry has no dims), ``'MATCH'`` (worst-axis
      difference <= 10 %) or ``'DELTA_<n>%'``;
    * ``xref_delta``: worst-axis difference formatted like ``'4.8%'``, or '';
    * ``xref_item``: the matched entry's ``item_no``, or ''.

    Footprints are compared orientation-independently: long side against
    long side, short against short.
    """
    def pct_diff(measured, reference):
        return abs(measured - reference) / reference * 100 if reference > 0 else 0

    def find_reference(label):
        for ref in xref_data.values():
            korean = ref.get('korean', '')
            if korean and (korean in label or label in korean):
                return ref
        return None

    for machine in machines:
        machine.update(xref_match='NO_REFERENCE', xref_delta='', xref_item='')

        label = machine.get('label', '') or ''
        if not label:
            continue
        ref = find_reference(label)
        if ref is None:
            continue

        machine['xref_item'] = ref.get('item_no', '')
        dims = ref.get('dims')
        if not dims:
            machine['xref_match'] = 'REF_NO_DIMS'
            continue

        width, height = machine['width'], machine['height']
        ref_a, ref_b = dims[0], dims[1]
        worst = max(
            pct_diff(max(width, height), max(ref_a, ref_b)),
            pct_diff(min(width, height), min(ref_a, ref_b)),
        )
        machine['xref_match'] = 'MATCH' if worst <= 10 else f'DELTA_{worst:.0f}%'
        machine['xref_delta'] = f'{worst:.1f}%'


def build_grid_points(grid_x, grid_y, datum):
    """Return one labelled point per grid intersection.

    Rows follow ascending Y and are lettered ``A``..``Z``, then ``R26``,
    ``R27``, ...; columns follow ascending X and are numbered from 1. Each
    point is ``{'label', 'x', 'y', 'dx', 'dy'}``, where dx/dy are offsets
    from *datum*. Points are ordered row by row, then column by column.
    """
    columns = sorted(grid_x)
    datum_x, datum_y = datum[0], datum[1]

    def row_name(index):
        if index < 26:
            return chr(ord('A') + index)
        return f"R{index}"

    return [
        {
            'label': f"{row_name(row)}{col}",
            'x': x,
            'y': y,
            'dx': x - datum_x,
            'dy': y - datum_y,
        }
        for row, y in enumerate(sorted(grid_y))
        for col, x in enumerate(columns, start=1)
    ]


def write_machine_tsv(machines, datum, grid_points, path):
    """Write machine_positions.tsv (UTF-8, tab-separated, one header row).

    *machines* is sorted in place by centre X. Each row contains:
      * the machine name (label, else block name);
      * the grid point nearest to *datum* (``ORIGIN`` when *grid_points* is empty);
      * the centre offset from *datum*;
      * footprint and rotation;
      * the previous machine in X order and the centre-to-centre distance to it;
      * the cross-reference status.
    Tabs and line breaks inside names are replaced by spaces so they cannot
    shift or split columns.
    """
    def tsv_field(value):
        return str(value).replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')

    datum_label = "ORIGIN"
    if grid_points:
        nearest = min(grid_points,
                      key=lambda g: math.hypot(g['x'] - datum[0], g['y'] - datum[1]))
        datum_label = nearest['label']

    machines.sort(key=lambda m: m['center'][0])

    header = [
        'machine_name', 'datum_col', 'dx_from_datum_mm', 'dy_from_datum_mm',
        'width_mm', 'height_mm', 'rotation_deg',
        'prev_machine', 'gap_to_prev_mm',
        'xref_match', 'xref_delta_%',
    ]
    # Explicit UTF-8: labels may be Korean, which the platform default
    # encoding (e.g. cp1252 on Windows) cannot represent.
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\t'.join(header) + '\n')
        prev_name, prev_center = '-', None
        for m in machines:
            name = tsv_field(m.get('label') or m.get('block_name', 'UNKNOWN'))
            cx, cy = m['center']
            gap = '-'
            if prev_center is not None:
                gap = f'{math.hypot(cx - prev_center[0], cy - prev_center[1]):.0f}'
            f.write('\t'.join([
                name, datum_label, f'{cx - datum[0]:.0f}', f'{cy - datum[1]:.0f}',
                f'{m["width"]:.0f}', f'{m["height"]:.0f}', f'{m["rotation"]:.0f}',
                prev_name, gap,
                m.get('xref_match', 'NO_REFERENCE'), m.get('xref_delta', ''),
            ]) + '\n')
            prev_name, prev_center = name, m['center']

    print(f"  Wrote {len(machines)} machines to {path}")


def write_column_tsv(grid_points, path):
    """Write column_grid.tsv: one row per grid point with neighbour spacings.

    *grid_points* (dicts with ``label``, ``x``, ``y``, ``dx``, ``dy``) is
    sorted in place by (y, x). Two spacing columns are written:
      * ``spacing_to_next_col_x``: X distance to the next grid point in the
        same row (same y);
      * ``spacing_to_next_col_y``: Y distance to the next grid point in the
        same column (same x).
    Either is ``-`` when no such neighbour exists. An empty input writes a
    single comment line.
    """
    if not grid_points:
        with open(path, 'w', encoding='utf-8') as f:
            f.write("# No grid points found\n")
        return

    grid_points.sort(key=lambda g: (g['y'], g['x']))

    # Measure spacing within a row/column. Taking the next point in sort
    # order wraps from the end of one row to the start of the next, which
    # reports meaningless spacings at every row end.
    xs_in_row, ys_in_col = {}, {}
    for g in grid_points:
        xs_in_row.setdefault(g['y'], []).append(g['x'])
        ys_in_col.setdefault(g['x'], []).append(g['y'])

    def spacing_to_next(values, current):
        nxt = min((v for v in values if v > current), default=None)
        return '-' if nxt is None else f'{nxt - current:.0f}'

    with open(path, 'w', encoding='utf-8') as f:
        f.write('\t'.join([
            'col_label', 'x', 'y', 'dx_from_datum', 'dy_from_datum',
            'spacing_to_next_col_x', 'spacing_to_next_col_y'
        ]) + '\n')

        for g in grid_points:
            sx = spacing_to_next(xs_in_row[g['y']], g['x'])
            sy = spacing_to_next(ys_in_col[g['x']], g['y'])
            f.write('\t'.join([
                g['label'], f'{g["x"]:.0f}', f'{g["y"]:.0f}',
                f'{g["dx"]:.0f}', f'{g["dy"]:.0f}', sx, sy,
            ]) + '\n')

    print(f"  Wrote {len(grid_points)} grid points to {path}")


def write_unmatched(machines, path):
    """Write unmatched_machines.txt for machines whose xref_match is 'NO_REFERENCE'.

    The file starts with three ``#`` header lines and a blank line:
      * a title;
      * a count of unmatched vs. all machines;
      * a fixed list of known-missing cross-reference items.
    Then follows one tab-separated line per unmatched machine: name (label,
    else block name), ``<W>x<H>mm`` footprint and centre. Machines without
    an ``xref_match`` key are not listed.
    """
    unmatched = [m for m in machines if m.get('xref_match') == 'NO_REFERENCE']
    # Explicit UTF-8: machine labels may be Korean.
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Unmatched machines (no cross-reference match)\n")
        f.write(f"# Total: {len(unmatched)} of {len(machines)} extracted machines\n")
        f.write("# Known missing items from cross_reference_output.json: #15, #21, #27, #37, #42, #44, #53-54, #58-62, #64\n\n")
        for m in unmatched:
            name = m.get('label') or m.get('block_name', 'UNKNOWN')
            cx, cy = m['center']
            f.write(f"{name}\t{m['width']:.0f}x{m['height']:.0f}mm\tcenter=({cx:.0f},{cy:.0f})\n")
    print(f"  Wrote {len(unmatched)} unmatched machines to {path}")


def _resolve_color(entity, layer_colors, parent_color=7):
    """Resolve effective ACI color for an entity."""
    c = entity.dxf.get('color', 256)
    if c == 256:  # BYLAYER
        return abs(layer_colors.get(entity.dxf.layer, 7))
    elif c == 0:  # BYBLOCK
        return parent_color
    return abs(c)


def _entity_in_zone(entity):
    """Check if a geometry entity is in the motor line zone (approximate)."""
    et = entity.dxftype()
    try:
        if et == 'LINE':
            x = (entity.dxf.start.x + entity.dxf.end.x) / 2
            y = (entity.dxf.start.y + entity.dxf.end.y) / 2
        elif et in ('CIRCLE', 'ARC'):
            x, y = entity.dxf.center.x, entity.dxf.center.y
        elif et == 'LWPOLYLINE':
            pts = list(entity.get_points())
            if not pts:
                return False
            x = sum(p[0] for p in pts) / len(pts)
            y = sum(p[1] for p in pts) / len(pts)
        elif et == 'SPLINE':
            cps = list(entity.control_points)
            if not cps:
                return False
            x = sum(p.x for p in cps) / len(cps)
            y = sum(p.y for p in cps) / len(cps)
        elif et == 'ELLIPSE':
            x, y = entity.dxf.center.x, entity.dxf.center.y
        else:
            return False
        return (MOTOR_LINE_X_MIN <= x <= MOTOR_LINE_X_MAX and
                MOTOR_LINE_Y_MIN <= y <= MOTOR_LINE_Y_MAX)
    except Exception:
        return False


def _copy_entity_shifted(entity, new_msp, dx0, dy0, target_layer):
    """Copy a geometry entity into new modelspace, shifted by (dx0, dy0)."""
    et = entity.dxftype()
    attribs = {'layer': target_layer}

    try:
        if et == 'LINE':
            new_msp.add_line(
                (entity.dxf.start.x - dx0, entity.dxf.start.y - dy0),
                (entity.dxf.end.x - dx0, entity.dxf.end.y - dy0),
                dxfattribs=attribs)
        elif et == 'ARC':
            new_msp.add_arc(
                center=(entity.dxf.center.x - dx0, entity.dxf.center.y - dy0),
                radius=entity.dxf.radius,
                start_angle=entity.dxf.start_angle,
                end_angle=entity.dxf.end_angle,
                dxfattribs=attribs)
        elif et == 'CIRCLE':
            new_msp.add_circle(
                center=(entity.dxf.center.x - dx0, entity.dxf.center.y - dy0),
                radius=entity.dxf.radius,
                dxfattribs=attribs)
        elif et == 'LWPOLYLINE':
            pts = list(entity.get_points(format='xyseb'))
            shifted = [(p[0] - dx0, p[1] - dy0) + p[2:] for p in pts]
            poly = new_msp.add_lwpolyline(shifted, dxfattribs=attribs)
            poly.close(entity.closed)
        elif et == 'SPLINE':
            cps = [(p.x - dx0, p.y - dy0, p.z) for p in entity.control_points]
            new_msp.add_spline(cps, dxfattribs=attribs)
        elif et == 'ELLIPSE':
            new_msp.add_ellipse(
                center=(entity.dxf.center.x - dx0, entity.dxf.center.y - dy0),
                major_axis=entity.dxf.major_axis,
                ratio=entity.dxf.ratio,
                start_param=entity.dxf.start_param,
                end_param=entity.dxf.end_param,
                dxfattribs=attribs)
        else:
            return False
        return True
    except Exception:
        return False


def generate_simplified_dxf(machines, grid_points, grid_x, grid_y, datum, path,
                            source_doc, source_msp):
    """Generate simplified DXF: spatial crop + full explode + datum shift.

    Approach (F-18/F-20/F-21 combined):
    1. Spatial crop — everything inside expanded motor line zone
    2. Explode all INSERT blocks — walk all nesting via virtual_entities()
    3. No color filtering — include all geometry
    4. Preserve original colors — write resolved ACI color per entity
    5. Datum shift to (0,0)
    6. Exclude only: DIMENSION, LEADER, MLEADER, HATCH, VIEWPORT, ATTDEF
    7. H-beam columns from 5-level nested INSERT chain (separate pass)
    """
    dx0, dy0 = datum

    # Expanded spatial crop zone (wider margin for context)
    CROP_X_MIN, CROP_X_MAX = 205000, 320000
    CROP_Y_MIN, CROP_Y_MAX = 100000, 140000

    GEOM_TYPES = {'LINE', 'ARC', 'CIRCLE', 'LWPOLYLINE', 'SPLINE', 'ELLIPSE'}

    # Build layer color lookup
    layer_colors = {}
    for layer in source_doc.layers:
        layer_colors[layer.dxf.name] = abs(layer.dxf.color)

    new_doc = ezdxf.new('R2013')
    new_msp = new_doc.modelspace()

    # Single geometry layer — color per entity, not per layer
    new_doc.layers.add('GEOMETRY', color=7)
    new_doc.layers.add('COLUMNS', color=2)
    new_doc.layers.add('LABELS', color=3)
    new_doc.layers.add('GRID', color=8)
    new_doc.layers.add('DATUM', color=5)

    def in_crop(x, y):
        return CROP_X_MIN <= x <= CROP_X_MAX and CROP_Y_MIN <= y <= CROP_Y_MAX

    def entity_center(e):
        et = e.dxftype()
        if et == 'LINE':
            return (e.dxf.start.x + e.dxf.end.x) / 2, (e.dxf.start.y + e.dxf.end.y) / 2
        elif et in ('CIRCLE', 'ARC'):
            return e.dxf.center.x, e.dxf.center.y
        elif et == 'LWPOLYLINE':
            pts = list(e.get_points())
            if not pts:
                return None, None
            return sum(p[0] for p in pts) / len(pts), sum(p[1] for p in pts) / len(pts)
        elif et == 'SPLINE':
            try:
                cps = list(e.control_points)
                if not cps:
                    return None, None
                # control_points may be Vec3 objects or numpy arrays
                xs = [float(p[0]) for p in cps]
                ys = [float(p[1]) for p in cps]
                return sum(xs) / len(xs), sum(ys) / len(ys)
            except Exception:
                return None, None
        elif et == 'ELLIPSE':
            return e.dxf.center.x, e.dxf.center.y
        elif et in ('TEXT', 'MTEXT'):
            return e.dxf.insert.x, e.dxf.insert.y
        elif et == 'SOLID':
            return e.dxf.vtx0.x, e.dxf.vtx0.y
        elif et == 'POINT':
            return e.dxf.location.x, e.dxf.location.y
        return None, None

    def copy_shifted(entity, layer='GEOMETRY', color=None):
        """Copy entity shifted by datum, with explicit color."""
        et = entity.dxftype()
        attribs = {'layer': layer}
        if color is not None:
            attribs['color'] = color

        try:
            if et == 'LINE':
                new_msp.add_line(
                    (entity.dxf.start.x - dx0, entity.dxf.start.y - dy0),
                    (entity.dxf.end.x - dx0, entity.dxf.end.y - dy0),
                    dxfattribs=attribs)
            elif et == 'ARC':
                new_msp.add_arc(
                    center=(entity.dxf.center.x - dx0, entity.dxf.center.y - dy0),
                    radius=entity.dxf.radius,
                    start_angle=entity.dxf.start_angle,
                    end_angle=entity.dxf.end_angle,
                    dxfattribs=attribs)
            elif et == 'CIRCLE':
                new_msp.add_circle(
                    center=(entity.dxf.center.x - dx0, entity.dxf.center.y - dy0),
                    radius=entity.dxf.radius,
                    dxfattribs=attribs)
            elif et == 'LWPOLYLINE':
                pts = list(entity.get_points(format='xyseb'))
                shifted = [(p[0] - dx0, p[1] - dy0) + p[2:] for p in pts]
                poly = new_msp.add_lwpolyline(shifted, dxfattribs=attribs)
                poly.close(entity.closed)
            elif et == 'SPLINE':
                cps = [(float(p[0]) - dx0, float(p[1]) - dy0, float(p[2]) if len(p) > 2 else 0)
                       for p in entity.control_points]
                new_msp.add_spline(cps, dxfattribs=attribs)
            elif et == 'ELLIPSE':
                new_msp.add_ellipse(
                    center=(entity.dxf.center.x - dx0, entity.dxf.center.y - dy0),
                    major_axis=entity.dxf.major_axis,
                    ratio=entity.dxf.ratio,
                    start_param=entity.dxf.start_param,
                    end_param=entity.dxf.end_param,
                    dxfattribs=attribs)
            elif et == 'MTEXT':
                mt = new_msp.add_mtext(entity.text, dxfattribs=attribs)
                mt.dxf.insert = (entity.dxf.insert.x - dx0, entity.dxf.insert.y - dy0, 0)
                mt.dxf.char_height = entity.dxf.get('char_height', 100)
            elif et == 'TEXT':
                new_msp.add_text(entity.dxf.text, height=entity.dxf.get('height', 100),
                                 dxfattribs={**attribs,
                                             'insert': (entity.dxf.insert.x - dx0,
                                                        entity.dxf.insert.y - dy0)})
            elif et == 'SOLID':
                new_msp.add_solid([
                    (entity.dxf.vtx0.x - dx0, entity.dxf.vtx0.y - dy0),
                    (entity.dxf.vtx1.x - dx0, entity.dxf.vtx1.y - dy0),
                    (entity.dxf.vtx2.x - dx0, entity.dxf.vtx2.y - dy0),
                    (entity.dxf.vtx3.x - dx0, entity.dxf.vtx3.y - dy0),
                ], dxfattribs=attribs)
            elif et == 'POINT':
                new_msp.add_point(
                    (entity.dxf.location.x - dx0, entity.dxf.location.y - dy0),
                    dxfattribs=attribs)
            else:
                return False
            return True
        except Exception:
            return False

    copied = 0
    # Track types for verification
    from collections import Counter as _Counter
    type_counts = _Counter()

    # --- Pass 1: ALL direct (non-INSERT) entities in crop zone ---
    for e in source_msp:
        et = e.dxftype()
        if et == 'INSERT':
            continue  # handled in Pass 2
        # Try to get position and check crop zone
        x, y = entity_center(e)
        if x is not None and in_crop(x, y):
            rc = _resolve_color(e, layer_colors)
            if copy_shifted(e, 'GEOMETRY', color=rc):
                copied += 1
                type_counts[et] += 1
        # DIMENSION: also explode into primitives (they have virtual geometry)
        if et == 'DIMENSION':
            try:
                dx, dy = e.dxf.defpoint.x, e.dxf.defpoint.y
                if not in_crop(dx, dy):
                    continue
                for ve in e.virtual_entities():
                    vx, vy = entity_center(ve)
                    if vx is not None and in_crop(vx, vy):
                        rc = _resolve_color(ve, layer_colors)
                        if copy_shifted(ve, 'GEOMETRY', color=rc):
                            copied += 1
                            type_counts['DIM_' + ve.dxftype()] += 1
            except Exception:
                continue

    # --- Pass 2: Explode ALL INSERT blocks (no spatial pre-filter) ---
    # virtual_entities() handles full recursive nesting automatically.
    for e in source_msp:
        if e.dxftype() != 'INSERT':
            continue
        name = e.dxf.name
        if name.startswith('*'):
            continue

        parent_c = _resolve_color(e, layer_colors)
        try:
            for ve in e.virtual_entities():
                vet = ve.dxftype()
                if vet == 'INSERT':
                    continue  # virtual_entities already recursed
                x, y = entity_center(ve)
                if x is None or not in_crop(x, y):
                    continue
                rc = _resolve_color(ve, layer_colors, parent_color=parent_c)
                if copy_shifted(ve, 'GEOMETRY', color=rc):
                    copied += 1
                    type_counts['BLK_' + vet] += 1
        except Exception:
            continue

    print(f"  Entities by type:")
    for t, c in type_counts.most_common():
        print(f"    {t}: {c}")

    # --- Pass 3: Building XREF — walk all entities with manual transform ---
    # virtual_entities() can't reach into the deeply nested XREF chain, so we
    # walk the building block manually, applying the parent INSERT offset.
    xref_count = 0
    try:
        l1_insert = None
        for e in source_msp:
            if e.dxftype() == 'INSERT' and e.dxf.name == BUILDING_XREF_PARENT:
                l1_insert = e
                break
        if l1_insert:
            l1x, l1y = l1_insert.dxf.insert.x, l1_insert.dxf.insert.y
            b_building = source_doc.blocks.get(BUILDING_XREF_BLOCK)
            if b_building:
                # Walk all entities in building block
                for e2 in b_building:
                    et2 = e2.dxftype()

                    # Direct geometry: offset and crop
                    if et2 in ('LINE', 'ARC', 'CIRCLE', 'LWPOLYLINE', 'SPLINE', 'ELLIPSE'):
                        # Create a shifted copy by temporarily adjusting coords
                        # We need to add l1x,l1y to entity coords then subtract datum
                        x, y = entity_center(e2)
                        if x is None:
                            continue
                        wx, wy = x + l1x, y + l1y
                        if not in_crop(wx, wy):
                            continue
                        rc = _resolve_color(e2, layer_colors)
                        # Manual coord shift: entity local + l1 offset - datum
                        total_dx = dx0 - l1x
                        total_dy = dy0 - l1y
                        if _copy_entity_shifted(e2, new_msp, total_dx, total_dy, 'GEOMETRY'):
                            xref_count += 1

                    # Sub-INSERT blocks: explode recursively
                    elif et2 == 'INSERT':
                        sub_name = e2.dxf.name
                        if sub_name.startswith('*'):
                            continue
                        sub_x = e2.dxf.insert.x + l1x
                        sub_y = e2.dxf.insert.y + l1y
                        sub_rot = e2.dxf.get('rotation', 0)

                        sub_block = source_doc.blocks.get(sub_name)
                        if not sub_block:
                            continue

                        # Walk sub-block entities (handles depth 2-3)
                        def walk_block(block, parent_x, parent_y, parent_rot, depth=0):
                            nonlocal xref_count
                            if depth > 10:
                                return
                            rad = math.radians(parent_rot)
                            for ent in block:
                                ent_type = ent.dxftype()
                                if ent_type == 'INSERT':
                                    if ent.dxf.name.startswith('*'):
                                        continue
                                    child_block = source_doc.blocks.get(ent.dxf.name)
                                    if not child_block:
                                        continue
                                    # Transform child insert point
                                    cx, cy = ent.dxf.insert.x, ent.dxf.insert.y
                                    rx = cx * math.cos(rad) - cy * math.sin(rad)
                                    ry = cx * math.sin(rad) + cy * math.cos(rad)
                                    child_rot = parent_rot + ent.dxf.get('rotation', 0)
                                    walk_block(child_block, parent_x + rx, parent_y + ry,
                                               child_rot, depth + 1)
                                elif ent_type in ('LINE', 'ARC', 'CIRCLE', 'LWPOLYLINE',
                                                  'SPLINE', 'ELLIPSE', 'MTEXT', 'TEXT',
                                                  'SOLID', 'POINT'):
                                    x, y = entity_center(ent)
                                    if x is None:
                                        continue
                                    # Apply parent rotation + translation
                                    rx = x * math.cos(rad) - y * math.sin(rad)
                                    ry = x * math.sin(rad) + y * math.cos(rad)
                                    wx = rx + parent_x
                                    wy = ry + parent_y
                                    if not in_crop(wx, wy):
                                        continue
                                    # For rotated blocks, we can't simply shift coords —
                                    # use a synthetic shifted entity approach
                                    rc = _resolve_color(ent, layer_colors)
                                    # Write with world offset
                                    shift_x = dx0 - parent_x
                                    shift_y = dy0 - parent_y
                                    # For rotated entities this is approximate, but close enough
                                    # for non-rotated (rot=0) it's exact
                                    if parent_rot == 0:
                                        if _copy_entity_shifted(ent, new_msp, shift_x, shift_y, 'COLUMNS'):
                                            xref_count += 1
                                    else:
                                        # Rotated: write individual primitives manually
                                        if ent_type == 'LWPOLYLINE':
                                            pts = list(ent.get_points())
                                            world_pts = []
                                            for p in pts:
                                                prx = p[0]*math.cos(rad) - p[1]*math.sin(rad)
                                                pry = p[0]*math.sin(rad) + p[1]*math.cos(rad)
                                                world_pts.append((prx+parent_x-dx0, pry+parent_y-dy0))
                                            poly = new_msp.add_lwpolyline(world_pts, dxfattribs={'layer':'COLUMNS','color':rc})
                                            poly.close(ent.closed)
                                            xref_count += 1
                                        elif ent_type == 'LINE':
                                            s = ent.dxf.start
                                            e_ = ent.dxf.end
                                            sx = s.x*math.cos(rad)-s.y*math.sin(rad)+parent_x-dx0
                                            sy = s.x*math.sin(rad)+s.y*math.cos(rad)+parent_y-dy0
                                            ex = e_.x*math.cos(rad)-e_.y*math.sin(rad)+parent_x-dx0
                                            ey = e_.x*math.sin(rad)+e_.y*math.cos(rad)+parent_y-dy0
                                            new_msp.add_line((sx,sy),(ex,ey),dxfattribs={'layer':'COLUMNS','color':rc})
                                            xref_count += 1
                                        elif ent_type == 'CIRCLE':
                                            cx_ = ent.dxf.center.x*math.cos(rad)-ent.dxf.center.y*math.sin(rad)+parent_x-dx0
                                            cy_ = ent.dxf.center.x*math.sin(rad)+ent.dxf.center.y*math.cos(rad)+parent_y-dy0
                                            new_msp.add_circle((cx_,cy_),ent.dxf.radius,dxfattribs={'layer':'COLUMNS','color':rc})
                                            xref_count += 1
                                        elif ent_type == 'ARC':
                                            cx_ = ent.dxf.center.x*math.cos(rad)-ent.dxf.center.y*math.sin(rad)+parent_x-dx0
                                            cy_ = ent.dxf.center.x*math.sin(rad)+ent.dxf.center.y*math.cos(rad)+parent_y-dy0
                                            new_msp.add_arc((cx_,cy_),ent.dxf.radius,
                                                ent.dxf.start_angle+parent_rot,ent.dxf.end_angle+parent_rot,
                                                dxfattribs={'layer':'COLUMNS','color':rc})
                                            xref_count += 1

                        walk_block(sub_block, sub_x, sub_y, sub_rot)
    except Exception as ex:
        print(f"  WARNING: Building XREF walk failed: {ex}")

    print(f"  Building XREF entities: {xref_count}")

    # --- Pass 4: Grid lines + labels ---
    if grid_x and grid_y:
        y_min = min(grid_y) - dy0 - 2000
        y_max = max(grid_y) - dy0 + 2000
        x_min = min(grid_x) - dx0 - 2000
        x_max = max(grid_x) - dx0 + 2000
        for x in grid_x:
            new_msp.add_line((x - dx0, y_min), (x - dx0, y_max), dxfattribs={'layer': 'GRID'})
        for y in grid_y:
            new_msp.add_line((x_min, y - dy0), (x_max, y - dy0), dxfattribs={'layer': 'GRID'})

    for g in grid_points:
        gx, gy = g['x'] - dx0, g['y'] - dy0
        new_msp.add_circle((gx, gy), radius=200, dxfattribs={'layer': 'GRID'})
        new_msp.add_text(g['label'], height=150,
                         dxfattribs={'layer': 'LABELS', 'insert': (gx + 250, gy + 250)})

    new_msp.add_circle((0, 0), radius=400, dxfattribs={'layer': 'DATUM'})
    new_msp.add_text('DATUM A1 (0,0)', height=200,
                     dxfattribs={'layer': 'DATUM', 'insert': (500, 500)})

    # Station name labels
    for e in source_msp:
        if e.dxftype() in ('MTEXT', 'TEXT'):
            px, py = e.dxf.insert.x, e.dxf.insert.y
            if not in_crop(px, py):
                continue
            if e.dxftype() == 'MTEXT':
                raw = e.text
                text = raw
                if '{\\' in text:
                    idx = text.find(';')
                    if idx >= 0:
                        text = text[idx + 1:]
                    if text.endswith('}'):
                        text = text[:-1]
                text = text.replace('\\P', ' ').replace('\\p', ' ').strip()
            else:
                text = e.dxf.text.strip()
            if text:
                new_msp.add_text(text, height=150,
                                 dxfattribs={'layer': 'LABELS', 'insert': (px - dx0, py - dy0)})

    new_doc.saveas(str(path))
    total_out = sum(1 for _ in new_doc.modelspace())
    print(f"  Wrote simplified DXF to {path}")
    print(f"  Geometry: {copied} from spatial crop, {xref_count} building XREF")
    print(f"  Crop zone: X [{CROP_X_MIN}-{CROP_X_MAX}], Y [{CROP_Y_MIN}-{CROP_Y_MAX}]")
    print(f"  Datum A1 shifted to (0, 0)")

    # Round-trip verification
    try:
        verify = ezdxf.readfile(str(path))
        count = sum(1 for _ in verify.modelspace())
        print(f"  Round-trip OK: {count} entities readable")
    except Exception as ex:
        print(f"  WARNING: Round-trip failed: {ex}")


def _inspect_anchor_point(e):
    """Return a representative (x, y) for entity *e*, or None if it has no usable anchor.

    INSERT, TEXT and MTEXT use their insertion point; LINE uses its midpoint;
    DIMENSION uses its definition point; CIRCLE and ARC use their centre;
    LWPOLYLINE uses the mean of its vertices. Any other entity type, or an
    LWPOLYLINE without vertices, yields None. Reading a DXF attribute may raise
    for malformed entities; the caller treats that as "skip this entity".
    """
    et = e.dxftype()
    if et in ('INSERT', 'TEXT', 'MTEXT'):
        return e.dxf.insert.x, e.dxf.insert.y
    if et == 'LINE':
        start, end = e.dxf.start, e.dxf.end
        return (start.x + end.x) / 2, (start.y + end.y) / 2
    if et == 'DIMENSION':
        return e.dxf.defpoint.x, e.dxf.defpoint.y
    if et in ('CIRCLE', 'ARC'):
        return e.dxf.center.x, e.dxf.center.y
    if et == 'LWPOLYLINE':
        pts = list(e.get_points())
        if not pts:
            return None
        return (sum(p[0] for p in pts) / len(pts),
                sum(p[1] for p in pts) / len(pts))
    return None


def inspect_layers(doc, msp, offset_x, offset_y):
    """Print a per-layer inventory of the motor-line zone with a guessed role per layer.

    Modelspace entities whose anchor point (see ``_inspect_anchor_point``) lies
    inside the MOTOR_LINE_* rectangle are counted by layer and entity type. Each
    layer is then classified as MACHINE (contains INSERT), ANNOTATION (contains
    DIMENSION) and/or LABEL (contains TEXT/MTEXT), or GEOMETRY if none apply.
    Layers of the building XREF block whose names look grid/column related are
    listed separately; those counts cover the whole block and are not cropped to
    the zone.

    Args:
        doc: ezdxf document; used to look up BUILDING_XREF_BLOCK.
        msp: modelspace of *doc*.
        offset_x, offset_y: building XREF offset. Currently unused; zone
            filtering compares raw modelspace coordinates.
    """
    from collections import defaultdict

    print("\n=== LAYER INSPECTION (Motor Line Zone) ===\n")

    # Count entities per layer/type whose anchor point lies in the motor zone
    layer_entities = defaultdict(Counter)
    for e in msp:
        et = e.dxftype()
        layer = e.dxf.layer
        try:
            anchor = _inspect_anchor_point(e)
        except Exception:
            # Best-effort survey: an entity whose geometry can't be read is skipped
            continue
        if anchor is None:
            continue
        x, y = anchor
        if MOTOR_LINE_X_MIN <= x <= MOTOR_LINE_X_MAX and MOTOR_LINE_Y_MIN <= y <= MOTOR_LINE_Y_MAX:
            layer_entities[layer][et] += 1

    # Tally the whole building XREF block (uncropped) to surface grid/column layers
    xref_layer_entities = defaultdict(Counter)
    xref_note = None
    try:
        big_block = doc.blocks.get(BUILDING_XREF_BLOCK)
        if big_block is None:
            xref_note = f"block {BUILDING_XREF_BLOCK} not found"
        else:
            for e in big_block:
                xref_layer_entities[e.dxf.layer][e.dxftype()] += 1
    except Exception as ex:
        # Keep the modelspace report usable, but say why the XREF section may be empty
        xref_note = f"building XREF block inspection failed: {ex}"

    # Per-layer table, largest layers first
    print(f"{'Layer':<50} {'Total':>6} {'Types':<40} {'Classification'}")
    print("-" * 120)
    by_size = sorted(layer_entities.items(), key=lambda item: -sum(item[1].values()))
    for layer, types in by_size:
        total = sum(types.values())
        type_str = ', '.join(f'{t}:{c}' for t, c in types.most_common(5))

        classification = []
        if 'INSERT' in types:
            classification.append('MACHINE')
        if 'DIMENSION' in types:
            classification.append('ANNOTATION')
        if 'TEXT' in types or 'MTEXT' in types:
            classification.append('LABEL')
        cls_str = '+'.join(classification) or 'GEOMETRY'

        print(f"  {layer:<48} {total:>6} {type_str:<40} {cls_str}")

    # XREF layers whose names suggest grid/column content
    print("\n--- XREF Building Block Layers (grid/column relevant) ---")
    if xref_note:
        print(f"  WARNING: {xref_note}")
    # 'col' already matches 'stlcol'; both are kept to mirror the documented keywords
    keywords = ('grid', 'stlcol', 'col', 'center')
    for layer, types in sorted(xref_layer_entities.items()):
        if any(kw in layer.lower() for kw in keywords):
            total = sum(types.values())
            type_str = ', '.join(f'{t}:{c}' for t, c in types.most_common(5))
            print(f"  {layer:<48} {total:>6} {type_str}")

    # Summary lists, in first-seen layer order
    machine_layers = [name for name, t in layer_entities.items() if 'INSERT' in t]
    annotation_layers = [name for name, t in layer_entities.items() if 'DIMENSION' in t]
    label_layers = [name for name, t in layer_entities.items() if 'TEXT' in t or 'MTEXT' in t]
    print("\n--- Auto-Classification Summary ---")
    print(f"  Machine layers (contain INSERT): {machine_layers}")
    print(f"  Annotation layers (contain DIMENSION): {annotation_layers}")
    print(f"  Label layers (contain TEXT/MTEXT): {label_layers}")
    # Sorted so the report is identical between runs (set order depends on hash seed)
    print(f"\nAnnotation entity types always excluded: {sorted(ANNOTATION_ENTITY_TYPES)}")
    print("\nTo override, use: --machine-layers 0,2,4")


def main():
    """Run the motor-line layout extraction from the command line.

    Pipeline:
    1. Load the DXF.
    2. Locate the building XREF offset.
    3. Extract the column grid and crop it to the motor zone.
    4. Pick a datum.
    5. Extract machines and labels and associate them.
    6. Cross-reference dimensions against XREF_JSON.
    7. Write machine_positions.tsv, column_grid.tsv, unmatched_machines.txt
       and TMED2_marking_simplified.dxf next to the input DXF.

    With ``--inspect`` only the layer report is printed.

    The datum defaults to the lowest grid X / lowest grid Y in the motor zone
    (or the zone corner when no grid is found). ``--datum LABEL`` selects the
    grid point with that label instead; the match ignores case and whitespace.

    Exits with status 1 in these cases:
    - the DXF is missing or unreadable;
    - the ``--datum`` label matches no motor-zone grid point.

    Exits with status 2 (argparse) when ``--machine-layers`` names no layer.
    """
    parser = argparse.ArgumentParser(description="Extract TMED-II motor line layout from DXF")
    parser.add_argument('--dxf', default=str(DEFAULT_DXF), help='Path to DXF file')
    parser.add_argument('--datum', default=None,
                        help='Grid-point label to use as datum (default: lowest X/Y grid intersection)')
    parser.add_argument('--inspect', action='store_true', help='Layer inspection only (no extraction)')
    parser.add_argument('--machine-layers', default=None,
                        help='Comma-separated layers for machine extraction (default: 0,2,4)')
    parser.add_argument('--column-layers', default=None,
                        help='Comma-separated keywords for column layers (default: stlcol)')
    args = parser.parse_args()

    # Parse layer overrides; tolerate spaces around commas ("0, 2, 4")
    if args.machine_layers:
        machine_layers = [name.strip() for name in args.machine_layers.split(',') if name.strip()]
        if not machine_layers:
            parser.error('--machine-layers needs at least one layer name')
    else:
        machine_layers = DEFAULT_MACHINE_LAYERS
    if args.column_layers:
        # Accepted for CLI compatibility, but extract_grid_from_xref() below takes no layer list
        print("WARNING: --column-layers is currently ignored by grid extraction")

    dxf_path = Path(args.dxf)
    if not dxf_path.exists():
        print(f"ERROR: DXF file not found: {dxf_path}")
        sys.exit(1)

    # All outputs are written next to the input DXF
    out_dir = dxf_path.parent
    machine_tsv = out_dir / 'machine_positions.tsv'
    column_tsv = out_dir / 'column_grid.tsv'
    unmatched_txt = out_dir / 'unmatched_machines.txt'
    simplified_dxf = out_dir / 'TMED2_marking_simplified.dxf'
    output_paths = [machine_tsv, column_tsv, unmatched_txt, simplified_dxf]

    print("=== TMED-II Motor Line Layout Extraction ===")
    print(f"DXF: {dxf_path}")
    print(f"Motor line zone: X [{MOTOR_LINE_X_MIN}-{MOTOR_LINE_X_MAX}], Y [{MOTOR_LINE_Y_MIN}-{MOTOR_LINE_Y_MAX}]")

    # Load DXF; report unreadable/corrupt files instead of dumping a traceback
    print("\n[1/7] Loading DXF...")
    try:
        doc = ezdxf.readfile(str(dxf_path))
    except OSError as ex:
        print(f"ERROR: not a DXF file or I/O error reading {dxf_path}: {ex}")
        sys.exit(1)
    except ezdxf.DXFStructureError as ex:
        print(f"ERROR: invalid or corrupted DXF file {dxf_path}: {ex}")
        sys.exit(1)
    msp = doc.modelspace()
    total = sum(1 for _ in msp)
    print(f"  {total} entities in modelspace")

    # Find building XREF offset
    print("\n[2/7] Locating building XREF...")
    offset_x, offset_y = find_xref_offset(doc, msp)
    print(f"  XREF offset: ({offset_x:.0f}, {offset_y:.0f})")

    if args.inspect:
        inspect_layers(doc, msp, offset_x, offset_y)
        return

    # Extract grid from XREF
    print("\n[3/7] Extracting grid from nested XREF block...")
    all_grid_x, all_grid_y, all_columns = extract_grid_from_xref(doc, offset_x, offset_y)
    print(f"  Full building grid: {len(all_grid_x)} vertical, {len(all_grid_y)} horizontal")
    print(f"  Column positions: {len(all_columns)}")

    grid_x, grid_y = filter_grid_to_motor_zone(all_grid_x, all_grid_y)
    print(f"  Motor zone grid: {len(grid_x)} vertical, {len(grid_y)} horizontal")
    if grid_x:
        print(f"  Grid X positions: {[f'{x:.0f}' for x in grid_x]}")
    if grid_y:
        print(f"  Grid Y positions: {[f'{y:.0f}' for y in grid_y]}")

    # Default datum: min X / min Y of the motor-zone grid, else the zone corner
    if grid_x and grid_y:
        datum = (min(grid_x), min(grid_y))
    else:
        datum = (MOTOR_LINE_X_MIN, MOTOR_LINE_Y_MIN)
    grid_points = build_grid_points(grid_x, grid_y, datum)

    # --datum LABEL: use that grid point's world coordinates as the datum instead
    if args.datum:
        wanted = args.datum.strip().upper()
        chosen = next((g for g in grid_points if str(g['label']).strip().upper() == wanted), None)
        if chosen is None:
            print(f"ERROR: datum label {args.datum!r} is not a motor-zone grid point")
            sys.exit(1)
        datum = (chosen['x'], chosen['y'])
        # build_grid_points takes the datum as input, so rebuild with the override
        grid_points = build_grid_points(grid_x, grid_y, datum)
    print(f"  Datum: ({datum[0]:.0f}, {datum[1]:.0f})")
    print(f"  Grid intersection points: {len(grid_points)}")

    # Extract machines (layer-first approach)
    print(f"\n[4/7] Extracting machine footprints (layers: {machine_layers})...")
    machines = extract_machines(doc, msp, machine_layers)
    print(f"  Machines in motor zone: {len(machines)}")

    print("\n[5/7] Extracting labels...")
    labels = extract_labels(msp)
    print(f"  Labels in motor zone: {len(labels)}")

    associate_labels(machines, labels)
    labeled = sum(1 for m in machines if m.get('label'))
    print(f"  Machines with labels: {labeled}")

    # Cross-reference; tally match statuses once (a missing/None status counts as '')
    print("\n[6/7] Cross-referencing dimensions...")
    xref_data = load_xref_data(XREF_JSON)
    print(f"  {len(xref_data)} Motor Line items in cross-reference")
    match_xref(machines, xref_data)
    status_counts = Counter(m.get('xref_match') or '' for m in machines)
    matched = status_counts['MATCH']
    delta = sum(n for status, n in status_counts.items() if status.startswith('DELTA'))
    no_ref = status_counts['NO_REFERENCE']
    ref_no_dims = status_counts['REF_NO_DIMS']
    print(f"  MATCH: {matched}, DELTA>10%: {delta}, REF_NO_DIMS: {ref_no_dims}, NO_REFERENCE: {no_ref}")

    print("\n[7/7] Writing outputs...")
    write_machine_tsv(machines, datum, grid_points, machine_tsv)
    write_column_tsv(grid_points, column_tsv)
    write_unmatched(machines, unmatched_txt)
    generate_simplified_dxf(machines, grid_points, grid_x, grid_y, datum,
                            simplified_dxf,
                            source_doc=doc, source_msp=msp)

    print(f"\n{'=' * 50}")
    print("SUMMARY")
    print(f"{'=' * 50}")
    print(f"  Machines extracted: {len(machines)}")
    print(f"  Machines with labels: {labeled}")
    print(f"  Grid points: {len(grid_points)}")
    print(f"  Datum: ({datum[0]:.0f}, {datum[1]:.0f})")
    print(f"  XRef: {matched} match, {delta} delta>10%, {ref_no_dims} ref_no_dims, {no_ref} no_reference")
    print("\nOutput files:")
    for path in output_paths:
        print(f"  {path}")


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported as a module
    main()
