#!/usr/bin/env python3
"""nanocode - minimal claude code alternative"""
import glob as globlib
import hashlib
import json
import os
import random
import re
import readline
import select
import ssl
import subprocess
import sys
import termios
import time
import tty
import unicodedata
import urllib.request
import urllib.parse
from datetime import datetime

OPENROUTER_KEY = os.environ.get("OPENROUTER_API_KEY")
LOCAL_API_KEY = os.environ.get("LOCAL_API_KEY")

# Backend selection, in priority order: a local endpoint on 127.0.0.1:8990
# (LOCAL_API_KEY set), then OpenRouter (OPENROUTER_API_KEY set), then the
# Anthropic API. Each backend has its own default model id; $MODEL overrides it.
if LOCAL_API_KEY:
    API_URL, MODEL = "http://127.0.0.1:8990/v1/messages", "anthropic/claude-sonnet-4.5"
elif OPENROUTER_KEY:
    API_URL, MODEL = "https://openrouter.ai/api/v1/messages", "anthropic/claude-opus-4.5"
else:
    API_URL, MODEL = "https://api.anthropic.com/v1/messages", "claude-opus-4-5"
MODEL = os.environ.get("MODEL", MODEL)

# ANSI escape sequences used to style terminal output.
RESET = "\033[0m"
BOLD = "\033[1m"
DIM = "\033[2m"
BLUE = "\033[34m"
CYAN = "\033[36m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
RED = "\033[31m"
# Set to True when the user presses ESC to interrupt streaming or a running command.
stop_flag = False

def detect_clipboard_image():
    """Read a PNG image from the system clipboard and wrap it as an API image block.

    Supported platforms: macOS (``pngpaste``, falling back to ``osascript`` when
    pngpaste is missing or fails), Linux (``xclip``) and Windows (PowerShell).

    Returns:
        dict or None: An ``image`` content block carrying base64 PNG data plus a
        private ``_metadata`` entry ``{"size": (width, height), "resized": None}``,
        or None when there is no image, the platform is unsupported, the image is
        larger than 5 MB, or any step fails.
    """
    import base64

    try:
        image_data = _read_clipboard_png()
        if not image_data:
            return None

        size = _png_dimensions(image_data)

        # Warn (but continue) when either dimension exceeds 1568 px.
        max_size = 1568
        if size[0] > max_size or size[1] > max_size:
            print(f"\r{YELLOW}⚠ Image {size[0]}x{size[1]} exceeds {max_size}x{max_size}, may be slow{RESET}", flush=True)
            time.sleep(1)

        # Reject payloads larger than 5 MB.
        if len(image_data) > 5 * 1024 * 1024:
            print(f"{RED}✗ Image too large: {len(image_data)/1024/1024:.2f} MB (max 5MB){RESET}", flush=True)
            return None

        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",
                "data": base64.b64encode(image_data).decode('utf-8'),
            },
            "_metadata": {
                "size": size,
                "resized": None,
            },
        }
    except subprocess.TimeoutExpired:
        print(f"\r{RED}✗ Timeout: Image too large or system command hung{RESET}", flush=True)
        time.sleep(1)
        return None
    except Exception:
        # Best effort: a missing clipboard tool or unreadable image means "no image".
        return None


def _png_dimensions(data):
    """Return (width, height) from a PNG's IHDR chunk, or (0, 0) if *data* is not a PNG."""
    # Layout: 8-byte signature, 4-byte chunk length, b"IHDR", then width and height (big-endian).
    if len(data) >= 24 and data[:8] == b'\x89PNG\r\n\x1a\n' and data[12:16] == b'IHDR':
        return (int.from_bytes(data[16:20], 'big'), int.from_bytes(data[20:24], 'big'))
    return (0, 0)


def _read_clipboard_png():
    """Return raw PNG bytes from the clipboard, or None if there is no image or the OS is unsupported.

    Raises subprocess.TimeoutExpired when a helper command runs longer than 5 s.
    """
    import tempfile

    if sys.platform.startswith('linux'):
        result = subprocess.run(
            ['xclip', '-selection', 'clipboard', '-t', 'image/png', '-o'],
            capture_output=True, timeout=5
        )
        return result.stdout if result.returncode == 0 and result.stdout else None

    if sys.platform == 'darwin':
        dump = _dump_clipboard_png_macos
    elif sys.platform == 'win32':
        dump = _dump_clipboard_png_windows
    else:
        return None

    # The macOS/Windows helpers write into a file. The file is always removed,
    # including when a helper raises (timeout, missing binary, ...).
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        path = tmp.name
    try:
        if not dump(path):
            return None
        with open(path, 'rb') as f:
            return f.read()
    finally:
        try:
            os.unlink(path)
        except OSError:
            pass


def _dump_clipboard_png_macos(path):
    """Write the macOS clipboard image to *path* as PNG; return True on success."""
    try:
        # pngpaste is faster but optional. A missing binary must still fall through to osascript.
        if subprocess.run(['pngpaste', path], capture_output=True, timeout=5).returncode == 0:
            return True
    except FileNotFoundError:
        pass

    script = f'''
    set theFile to POSIX file "{path}"
    try
        set theImage to the clipboard as «class PNGf»
        set fileRef to open for access theFile with write permission
        write theImage to fileRef
        close access fileRef
        return "success"
    on error errMsg
        return "error: " & errMsg
    end try
    '''
    result = subprocess.run(
        ['osascript', '-e', script],
        capture_output=True, timeout=5, text=True
    )
    if not result.stdout.strip().startswith('success'):
        print(f"{RED}✗ osascript error: {result.stdout.strip()}{RESET}", flush=True)
        return False
    return True


def _dump_clipboard_png_windows(path):
    """Write the Windows clipboard image to *path* as PNG via PowerShell; return True on success."""
    ps_script = f'''
    Add-Type -AssemblyName System.Windows.Forms
    $img = [System.Windows.Forms.Clipboard]::GetImage()
    if ($img) {{
        $img.Save("{path}", [System.Drawing.Imaging.ImageFormat]::Png)
        exit 0
    }} else {{
        exit 1
    }}
    '''
    result = subprocess.run(
        ['powershell', '-Command', ps_script],
        capture_output=True, timeout=5
    )
    return result.returncode == 0

def create_opener():
    """Create a URL opener with TLS and optional proxy support.

    Certificates are verified by default. The CA bundle can be overridden with
    the standard SSL_CERT_FILE / SSL_CERT_DIR variables. Set INSECURE_SSL=1
    (or "true"/"yes") to disable verification, e.g. behind an intercepting proxy.

    Proxies come from ``http_proxy`` / ``https_proxy``. Each scheme prefers its
    own variable and falls back to the other. With neither set, urllib's
    default environment-based proxy handling applies.
    """
    ssl_ctx = ssl.create_default_context()
    if os.environ.get("INSECURE_SSL", "").lower() in ("1", "true", "yes"):
        # Opt-in only: without verification, API keys and the code that load()
        # executes can be intercepted or tampered with in transit.
        ssl_ctx.check_hostname = False
        ssl_ctx.verify_mode = ssl.CERT_NONE

    handlers = [urllib.request.HTTPSHandler(context=ssl_ctx)]
    http_proxy = os.environ.get("http_proxy")
    https_proxy = os.environ.get("https_proxy")
    if http_proxy or https_proxy:
        handlers.insert(0, urllib.request.ProxyHandler({
            "http": http_proxy or https_proxy,
            "https": https_proxy or http_proxy,
        }))
    return urllib.request.build_opener(*handlers)

def register_tool(name, desc, params):
    """Return a decorator that adds the decorated function to TOOLS.

    Intended for extension code run by ``load``. The function is stored as
    ``TOOLS[name] = (desc, params, func)``, replacing any tool with that name,
    and is returned unchanged.
    """
    def _register(fn):
        entry = (desc, params, fn)
        TOOLS[name] = entry
        return fn

    return _register

def search_extension(args):
    """Search extensions on gist.kitchain.cn by topic keywords.

    Each whitespace-separated keyword (lower-cased, duplicates ignored) is looked
    up as a topic page. Gists are ranked by how many keywords they matched.

    Args:
        args: dict with ``query`` (str).

    Returns:
        str: A list of up to 10 matches plus a ready-to-use ``load(...)`` call
        for the top result. Otherwise "No extensions found: <query>" when nothing
        matched, or "error: ..." when the query is empty or no topic page
        could be fetched.
    """
    import html as htmllib

    query = (args.get("query") or "").strip()
    if not query: return "error: query required"
    try:
        # dict.fromkeys dedupes while keeping order. Otherwise a repeated keyword
        # would be fetched twice and count as two hits.
        keywords = list(dict.fromkeys(query.lower().split()))
        gist_info = {}  # {gist_path: {"hits", "title", "desc", "topics", "filename"}}
        errors = []
        opener = create_opener()

        # Search each keyword as a topic
        for keyword in keywords:
            url = f"https://gist.kitchain.cn/topics/{urllib.parse.quote(keyword)}"
            try:
                page = opener.open(urllib.request.Request(url), timeout=10).read().decode()
            except (OSError, ValueError) as e:
                # A missing topic or flaky fetch for one keyword must not abort the others.
                errors.append(e)
                continue

            # Extract gist URLs and titles
            gist_matches = re.findall(
                r'<a class="font-bold" href="https://gist\.kitchain\.cn/([^/]+/[a-f0-9]+)">([^<]+)</a>',
                page
            )

            for gist_path, title in gist_matches:
                if gist_path not in gist_info:
                    # Description and topics follow the gist link in the page markup
                    gist_section = re.search(
                        rf'{re.escape(gist_path)}.*?'
                        r'<h6 class="text-xs[^"]*">([^<]+)</h6>(.*?)</div>\s*</div>',
                        page, re.DOTALL
                    )
                    desc = ""
                    topics = []
                    if gist_section:
                        desc = htmllib.unescape(gist_section.group(1).strip())
                        topics = [
                            htmllib.unescape(topic_name)
                            for _, topic_name in re.findall(r'topics/([^"]+)"[^>]*>([^<]+)<', gist_section.group(2))
                        ]

                    clean_title = htmllib.unescape(title.strip())
                    gist_info[gist_path] = {
                        "hits": 0,
                        "title": clean_title,
                        "desc": desc,
                        "topics": topics,
                        "filename": clean_title,
                    }
                gist_info[gist_path]["hits"] += 1

        if not gist_info:
            # Only report an error when no topic page could be fetched at all.
            if errors and len(errors) == len(keywords):
                return f"error: {errors[-1]}"
            return f"No extensions found: {query}"

        # Sort by hit count (descending); sorted() is stable, so ties keep discovery order
        sorted_gists = sorted(gist_info.items(), key=lambda x: x[1]["hits"], reverse=True)[:10]

        result = f"Found {len(sorted_gists)} extensions:\n\n"
        for gist_path, info in sorted_gists:
            result += f"• {info['title']}\n"
            if info['desc']:
                result += f"  {info['desc']}\n"
            if info['topics']:
                result += f"  Topics: {', '.join(info['topics'])}\n"
            result += f"  Matched: {info['hits']} keyword(s)\n\n"

        # Suggest the top result's raw URL; the filename is URL-quoted so spaces etc. survive
        first_gist = sorted_gists[0][0]
        first_filename = urllib.parse.quote(sorted_gists[0][1]['filename'])
        result += f"To load the top result:\nload({{\"url\": \"https://gist.kitchain.cn/{first_gist}/raw/HEAD/{first_filename}\"}})"
        return result
    except Exception as e:
        return f"error: {e}"

def load(args):
    """Download Python source from a URL and exec it so it can register extra tools.

    The code runs with a fresh globals dict exposing ``register_tool``, ``TOOLS``,
    ``urllib``, ``json``, ``re`` and ``subprocess``.

    SECURITY: this executes remote code with the user's privileges (builtins are
    not restricted). Only load URLs you trust.

    Args:
        args: dict with ``url`` (str).

    Returns:
        str: "Loaded. New tools: ..." naming the tools added or replaced by this
        load, or "error: ..." on failure.
    """
    url = args.get("url")
    if not url: return "error: url required"
    try:
        opener = create_opener()
        code = opener.open(urllib.request.Request(url), timeout=10).read().decode()
        before = dict(TOOLS)
        exec(code, {"register_tool": register_tool, "TOOLS": TOOLS, "urllib": urllib, "json": json, "re": re, "subprocess": subprocess})
        # Compare entries by identity. Tools from earlier loads are not listed
        # again, but tools this code (re-)registered are.
        new = [k for k, v in TOOLS.items() if before.get(k) is not v]
        return f"Loaded. New tools: {', '.join(new)}"
    except Exception as e:
        return f"error: {e}"

# --- Tools ---
def read(args):
    """Return a file's lines, each prefixed with its 1-based line number (``"   N| "``).

    Args:
        args: dict with ``path`` (str) and optional ``offset`` (0-based index of
            the first line, default 0) and ``limit`` (max number of lines,
            default all). A ``None`` value counts as "not given", and a negative
            offset is clamped to 0 so line numbers stay correct.

    Returns:
        str: The numbered lines concatenated, with original line endings kept.
    """
    with open(args["path"]) as f:
        lines = f.readlines()
    offset = max(int(args.get("offset") or 0), 0)
    limit = args.get("limit")
    limit = len(lines) if limit is None else int(limit)
    return "".join(f"{offset+i+1:4}| {l}" for i, l in enumerate(lines[offset:offset+limit]))

def write(args):
    """Write ``content`` to ``path``, replacing any existing file.

    Missing parent directories are created first. Progress is logged to the
    terminal (the logged size is ``len(content)``, i.e. characters).

    Args:
        args: dict with ``path`` (str) and ``content`` (str).

    Returns:
        str: "ok".
    """
    filepath = args["path"]
    content = args["content"]
    print(f"{DIM}[LOG] write: {filepath} ({len(content)} bytes){RESET}", flush=True)
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, "w") as f:
        f.write(content)
    print(f"{DIM}[LOG] write completed: {filepath}{RESET}", flush=True)
    return "ok"

def edit(args):
    """Replace ``old`` with ``new`` in the file at ``path``.

    Args:
        args: dict with ``path`` (str), ``old`` (non-empty str), ``new`` (str)
            and optional ``all`` (bool). When ``all`` is true every occurrence
            is replaced; otherwise ``old`` must occur exactly once.

    Returns:
        str: "ok", or an "error: ..." message when ``old`` is empty, not found,
        or ambiguous.
    """
    filepath = args["path"]
    old, new = args["old"], args["new"]
    # An empty needle "matches" between every character; replacing it would corrupt the file.
    if not old: return "error: old_string is empty"
    print(f"{DIM}[LOG] edit: {filepath}{RESET}", flush=True)
    with open(filepath) as f:
        text = f.read()
    print(f"{DIM}[LOG] edit read: {len(text)} bytes{RESET}", flush=True)
    count = text.count(old)
    if count == 0: return "error: old_string not found"
    if not args.get("all") and count > 1:
        return f"error: old_string appears {count} times (use all=true)"
    result = text.replace(old, new) if args.get("all") else text.replace(old, new, 1)
    print(f"{DIM}[LOG] edit writing: {len(result)} bytes{RESET}", flush=True)
    with open(filepath, "w") as f:
        f.write(result)
    print(f"{DIM}[LOG] edit completed: {filepath}{RESET}", flush=True)
    return "ok"

def glob(args):
    """List paths matching ``pat`` under ``path`` (default "."), newest files first.

    ``**`` matches recursively. Directories sort as if their mtime were 0.
    Returns the paths joined by newlines, or "none" when nothing matches.
    """
    base = args.get("path", ".")
    pattern = f"{base}/{args['pat']}".replace("//", "/")

    def _mtime(p):
        return os.path.getmtime(p) if os.path.isfile(p) else 0

    matches = globlib.glob(pattern, recursive=True)
    matches.sort(key=_mtime, reverse=True)
    return "\n".join(matches) or "none"

def grep(args):
    """Search files for a regex and return up to 50 ``path:lineno:text`` hits.

    Args:
        args: dict with ``pat`` (regex) and optional ``path`` (default ".").
            ``path`` may be a directory, searched recursively via ``<path>/**``
            (hidden entries are skipped by glob), or a single file.

    Returns:
        str: Newline-joined hits (at most 50), or "none". Non-regular files and
        files that cannot be opened or decoded are skipped.
    """
    pattern = re.compile(args["pat"])
    root = args.get("path", ".")
    max_hits = 50
    candidates = [root] if os.path.isfile(root) else globlib.glob(root + "/**", recursive=True)
    hits = []
    for fp in candidates:
        # Skip directories, FIFOs etc.; opening a FIFO would block forever.
        if not os.path.isfile(fp):
            continue
        try:
            with open(fp) as f:
                for n, l in enumerate(f, 1):
                    if pattern.search(l):
                        hits.append(f"{fp}:{n}:{l.rstrip()}")
                        if len(hits) >= max_hits:
                            # Enough results; don't keep scanning the rest of the tree.
                            return "\n".join(hits)
        except (OSError, UnicodeDecodeError):
            pass  # unreadable or binary file
    return "\n".join(hits) or "none"

def bash(args):
    """Run a shell command, echoing its combined stdout/stderr as it arrives.

    If stdin is a terminal, it is switched to cbreak mode while the command runs.
    A single ESC keypress then kills the command, sets the global ``stop_flag``
    and appends "(stopped)" to the output. If stdin is not a terminal (piped or
    redirected), ESC handling is skipped and the command runs to completion.
    There is no time limit.

    Args:
        args: dict with ``cmd`` (str), executed through the shell.

    Returns:
        str: The collected output, stripped, or "(empty)".
    """
    global stop_flag
    import fcntl

    interactive = sys.stdin.isatty()
    # Snapshot the terminal state *before* spawning. A termios failure then
    # cannot leave an already-started child running unattended.
    old_settings = termios.tcgetattr(sys.stdin) if interactive else None
    proc = subprocess.Popen(args["cmd"], shell=True, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, text=True)
    lines = []
    try:
        if interactive:
            tty.setcbreak(sys.stdin.fileno())
        fd = proc.stdout.fileno()
        fcntl.fcntl(fd, fcntl.F_SETFL, fcntl.fcntl(fd, fcntl.F_GETFL) | os.O_NONBLOCK)

        while True:
            # ESC from the user aborts the command
            if interactive and select.select([sys.stdin], [], [], 0)[0]:
                if sys.stdin.read(1) == '\x1b':
                    stop_flag = True
                    proc.kill()
                    lines.append("\n(stopped)")
                    print(f"\n{YELLOW}⏸ Stopped{RESET}")
                    break

            # Echo one line of output (waits up to 0.1 s for data)
            if select.select([proc.stdout], [], [], 0.1)[0]:
                line = proc.stdout.readline()
                if line:
                    print(f"  {DIM}│ {line.rstrip()}{RESET}", flush=True)
                    lines.append(line)

            # Once the process has exited, drain whatever output is left
            if proc.poll() is not None:
                remaining = proc.stdout.read()
                if remaining:
                    for line in remaining.split('\n'):
                        if line:
                            print(f"  {DIM}│ {line.rstrip()}{RESET}", flush=True)
                            lines.append(line + '\n')
                break
    finally:
        if old_settings is not None:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
        # Kill the child if we are unwinding from an error, then reap it so
        # a stopped command does not linger as a zombie.
        if proc.poll() is None:
            proc.kill()
        proc.wait()
        proc.stdout.close()

    return "".join(lines).strip() or "(empty)"

def web_search(args):
    """Search the web using DuckDuckGo's HTML endpoint.

    Args:
        args: dict with ``query`` (str) and optional ``max_results`` (int, default 5).

    Returns:
        str: Numbered results ("N. title / URL: ... / snippet"), "No results
        found", or "error: ...".
    """
    import html as htmllib

    def clean(fragment):
        # Titles/snippets contain inline tags (e.g. <b> around matched terms) and HTML entities.
        return htmllib.unescape(re.sub(r"<[^>]+>", "", fragment)).strip()

    def destination(href):
        # Result links are usually redirects ("//duckduckgo.com/l/?uddg=<url>&...").
        # Return the real target when present.
        href = htmllib.unescape(href)
        uddg = urllib.parse.parse_qs(urllib.parse.urlparse(href).query).get("uddg")
        if uddg:
            return uddg[0]
        return "https:" + href if href.startswith("//") else href

    query = args["query"]
    max_results = args.get("max_results")
    max_results = 5 if max_results is None else int(max_results)
    try:
        url = f"https://html.duckduckgo.com/html/?q={urllib.parse.quote_plus(query)}"
        headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"}
        opener = create_opener()
        page = opener.open(urllib.request.Request(url, headers=headers), timeout=30).read().decode()

        # Titles and URLs; capture the full anchor body so inline tags don't truncate titles
        links = re.findall(r'class="result__a"[^>]+href="([^"]+)"[^>]*>(.*?)</a>', page, re.DOTALL)
        # Snippets, likewise captured up to the closing tag
        snippets = re.findall(r'class="result__snippet"[^>]*>(.*?)</(?:a|div|td)>', page, re.DOTALL)
        if not links: return "No results found"

        results = []
        # Pad snippets so results without a snippet still get printed
        padded = snippets[:max_results] + [""] * max_results
        for i, ((href, title), snippet) in enumerate(zip(links[:max_results], padded), 1):
            results.append(f"{i}. {clean(title)}\n   URL: {destination(href)}\n   {clean(snippet)}\n")
        return "\n".join(results)
    except Exception as e:
        return f"error: {e}"


# Tool registry: name -> (description sent to the model, parameter spec, handler).
# In the parameter spec a trailing "?" marks an optional parameter, and "number"
# is advertised as an integer by make_schema(). Extensions add entries through
# register_tool(). Insertion order is the order in which tools are advertised.
TOOLS = dict(
    read=("Read file with line numbers", {"path": "string", "offset": "number?", "limit": "number?"}, read),
    write=("Write content to file", {"path": "string", "content": "string"}, write),
    edit=("Replace old with new in file", {"path": "string", "old": "string", "new": "string", "all": "boolean?"}, edit),
    glob=("Find files by pattern", {"pat": "string", "path": "string?"}, glob),
    grep=("Search files for regex", {"pat": "string", "path": "string?"}, grep),
    bash=("Run shell command", {"cmd": "string"}, bash),
    web_search=("Search the web using DuckDuckGo", {"query": "string", "max_results": "number?"}, web_search),
    search_extension=("Search for extensions to add new capabilities (GitHub docs, web scraping, APIs, etc)", {"query": "string"}, search_extension),
    load=("Load extension from URL to add new tools", {"url": "string"}, load),
)

def run_tool(name, args):
    """Dispatch a tool call to its handler in TOOLS.

    Returns:
        The handler's result, or an "error: ..." string if the tool is unknown
        or the handler raises.
    """
    entry = TOOLS.get(name)
    if entry is None:
        # Previously surfaced as the opaque KeyError text "error: '<name>'".
        return f"error: unknown tool: {name}"
    try: return entry[2](args)
    except Exception as e: return f"error: {e}"

def make_schema():
    """Build the API ``tools`` list (name, description, JSON input_schema) from TOOLS.

    A parameter type ending in "?" is optional; every other parameter is listed
    under "required". The type "number" is emitted as JSON-schema "integer".
    """
    def _json_type(ptype):
        base = ptype.rstrip("?")
        return {"type": "integer" if base == "number" else base}

    schema = []
    for tool_name, (description, spec, _handler) in TOOLS.items():
        properties = {pname: _json_type(ptype) for pname, ptype in spec.items()}
        required = [pname for pname, ptype in spec.items() if not ptype.endswith("?")]
        schema.append({
            "name": tool_name,
            "description": description,
            "input_schema": {"type": "object", "properties": properties, "required": required},
        })
    return schema

def call_api(messages, system_prompt, stream=True, enable_thinking=True, use_tools=True):
    """POST a Messages API request to API_URL and return the open HTTP response.

    Auth: LOCAL_API_KEY or OPENROUTER_API_KEY is sent as a Bearer token;
    otherwise ANTHROPIC_API_KEY is sent via ``x-api-key``.

    Args:
        messages: Conversation in Messages API format.
        system_prompt: System prompt string.
        stream: Request a server-sent-events stream.
        enable_thinking: Allow extended thinking. It is only turned on when
            $THINKING is set; the budget comes from $THINKING_BUDGET (default 10000).
        use_tools: Include the tool schema from make_schema().

    Returns:
        The response object from the opener (the caller reads or iterates it).
    """
    headers = {"Content-Type": "application/json", "anthropic-version": "2023-06-01"}
    if LOCAL_API_KEY: headers["Authorization"] = f"Bearer {LOCAL_API_KEY}"
    elif OPENROUTER_KEY: headers["Authorization"] = f"Bearer {OPENROUTER_KEY}"
    else: headers["x-api-key"] = os.environ.get("ANTHROPIC_API_KEY", "")

    answer_tokens = 8192
    data = {"model": MODEL, "max_tokens": answer_tokens, "system": system_prompt,
            "messages": messages, "stream": stream}

    if use_tools:
        data["tools"] = make_schema()

    if enable_thinking and os.environ.get("THINKING"):
        budget = int(os.environ.get("THINKING_BUDGET", "10000"))
        data["thinking"] = {"type": "enabled", "budget_tokens": budget}
        # The API requires max_tokens > budget_tokens. The default budget (10000)
        # already exceeded 8192, so keep 8192 tokens of answer room on top of it.
        data["max_tokens"] = budget + answer_tokens

    req = urllib.request.Request(API_URL, json.dumps(data).encode(), headers, method="POST")
    # Per-operation socket timeout, so a dead connection cannot hang the session forever.
    return create_opener().open(req, timeout=600)

def summarize_changes(user_input, files_modified, checkpoint_manager, checkpoint_id):
    """Use the LLM to summarize the changes made in this turn.

    The diff of ``checkpoint_id`` is fetched via ``checkpoint_manager._git_command``
    (``git show --format= <id>`` on its bare repo), truncated to 3000 chars, and
    sent to call_api() with a prompt asking for a short Chinese summary.

    Args:
        user_input: User's request.
        files_modified: Set of modified file paths.
        checkpoint_manager: CheckpointManager instance.
        checkpoint_id: Checkpoint hash to get the diff from.

    Returns:
        str: A one-line summary. When there are no modified files, no
        checkpoint, no diff, or the API call fails or gives an unusable reply,
        falls back to the user input collapsed onto one line and truncated
        to 50 characters.
    """
    # Collapse newlines/whitespace runs so the fallback is really one line.
    fallback = " ".join(user_input.split())[:50]
    if not files_modified or not checkpoint_id:
        return fallback

    try:
        # Get diff from git
        diff_output = checkpoint_manager._git_command(
            "--git-dir", checkpoint_manager.bare_repo,
            "show", "--format=", checkpoint_id
        )

        # Empty diff or error output: nothing to summarize
        if not diff_output or diff_output.startswith("error") or not diff_output.strip():
            return fallback

        # Limit diff size to avoid token overflow (max ~3000 chars)
        if len(diff_output) > 3000:
            diff_output = diff_output[:3000] + "\n... (truncated)"

        summary_prompt = f"""Based on the actual code changes (diff), generate a brief Chinese summary (max 30 Chinese characters).

IMPORTANT: Must be based on the actual code changes, not the user's description.

Code changes (diff):
{diff_output}

User description (for reference only): {user_input}

Requirements:
1. Describe what code/functionality was actually modified
2. Reply in Chinese only, no explanation
3. No quotes
4. Max 30 Chinese characters

Good examples:
- 在 auth.py 添加 JWT 验证
- 修复 parser.py 空指针异常
- 重构 database.py 连接池
- 更新 README 添加安装说明
"""

        messages = [{"role": "user", "content": summary_prompt}]
        # Non-streaming call; the response is closed once its body is read.
        with call_api(messages, "You are a code change analyzer, skilled at extracting key information from diffs. Reply in Chinese.",
                      stream=False, enable_thinking=False, use_tools=False) as response:
            data = json.loads(response.read().decode())

        for block in data.get("content", []):
            if block.get("type") != "text":
                continue
            summary = block.get("text", "").strip()

            # Drop any inline <thinking>...</thinking> preamble
            if "<thinking>" in summary:
                parts = summary.split("</thinking>")
                if len(parts) > 1:
                    summary = parts[-1].strip()

            # Keep only the first non-empty line, then strip quotes
            summary = next((ln.strip() for ln in summary.splitlines() if ln.strip()), "")
            summary = summary.replace('"', '').replace("'", "")
            if summary and len(summary) <= 80:
                return summary

        return fallback
    except Exception:
        # Best effort: any failure falls back to the user's request
        return fallback

def process_stream(response):
    """Consume a Messages API SSE stream, echoing text and thinking as they arrive.

    If stdin is a terminal, it is put in cbreak mode while streaming. ESC then
    stops reading, sets the global ``stop_flag`` and returns the blocks completed
    so far. Stream ``error`` events are printed.

    Args:
        response: Iterable of raw SSE lines (bytes), e.g. the object returned by call_api().

    Returns:
        list[dict]: Finished content blocks in stream order:
          text:              {"type", "id", "text"}
          thinking:          {"type", "id", "thinking", "signature"}
          redacted_thinking: {"type", "id", "data"}
          tool_use:          {"type", "id", "name", "input"}  (input is {} if the JSON is invalid)
    """
    global stop_flag
    blocks, current = [], None
    text_buf, json_buf, think_buf, sig_buf = "", "", "", ""

    interactive = sys.stdin.isatty()
    # Save terminal settings (only meaningful when stdin is a terminal)
    old_settings = termios.tcgetattr(sys.stdin) if interactive else None
    try:
        if interactive:
            tty.setcbreak(sys.stdin.fileno())

        for line in response:
            if interactive and select.select([sys.stdin], [], [], 0)[0]:
                ch = sys.stdin.read(1)
                if ch == '\x1b':  # ESC key
                    stop_flag = True
                    print(f"\n{YELLOW}⏸ Stopped{RESET}")
                    break

            line = line.decode("utf-8").strip()
            if not line.startswith("data: "): continue
            if line == "data: [DONE]": continue

            try:
                data = json.loads(line[6:])
                etype = data.get("type")

                if etype == "content_block_start":
                    block = data.get("content_block", {})
                    current = {"type": block.get("type"), "id": block.get("id")}
                    if current["type"] == "text":
                        text_buf = ""
                        print(f"\n{CYAN}⏺{RESET} ", end="", flush=True)
                    elif current["type"] == "thinking":
                        think_buf = ""
                        sig_buf = block.get("signature") or ""
                        print(f"\n{YELLOW}💭{RESET} {DIM}", end="", flush=True)
                    elif current["type"] == "redacted_thinking":
                        # Opaque payload that must be passed back unchanged
                        current["data"] = block.get("data", "")
                    elif current["type"] == "tool_use":
                        current["name"] = block.get("name")
                        json_buf = ""

                elif etype == "content_block_delta":
                    delta = data.get("delta", {})
                    dtype = delta.get("type")
                    if dtype == "text_delta":
                        text = delta.get("text", "")
                        text_buf += text
                        print(text, end="", flush=True)
                    elif dtype == "thinking_delta":
                        text = delta.get("thinking", "")
                        think_buf += text
                        print(text, end="", flush=True)
                    elif dtype == "signature_delta":
                        sig_buf += delta.get("signature", "")
                    elif dtype == "input_json_delta" and current:
                        json_buf += delta.get("partial_json", "")

                elif etype == "content_block_stop" and current:
                    if current["type"] == "text":
                        current["text"] = text_buf
                        print()
                    elif current["type"] == "thinking":
                        # Keep the reasoning and its signature so the block can be sent back
                        current["thinking"] = think_buf
                        current["signature"] = sig_buf
                        print(RESET)
                    elif current["type"] == "tool_use":
                        try: current["input"] = json.loads(json_buf)
                        except ValueError: current["input"] = {}
                    blocks.append(current)
                    current = None

                elif etype == "error":
                    err = data.get("error") or {}
                    print(f"\n{RED}✗ API error: {err.get('message', err)}{RESET}", flush=True)
            except Exception:
                pass  # best effort: skip malformed or unexpected events
    finally:
        if old_settings is not None:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)

    return blocks

def get_display_width(text):
    """Return the terminal column width of *text*.

    Characters whose East Asian Width is Fullwidth ("F") or Wide ("W") count
    as 2 columns; every other character counts as 1.
    """
    wide_classes = ('F', 'W')
    return sum(2 if unicodedata.east_asian_width(ch) in wide_classes else 1 for ch in text)

def format_tool_preview(name, args):
    """Return a one-line, human-readable summary of a tool call's arguments.

    Built-in tools get a tailored summary. Any other tool shows the string form
    of its first argument value, cut to 50 characters ("" when it has no arguments).
    """
    def clip(s):
        # Edit previews show at most 20 characters, with "..." when cut.
        return s[:20] + "..." if len(s) > 20 else s

    def preview_read():
        path, start, count = args.get("path", ""), args.get("offset"), args.get("limit")
        if not (start or count):
            return path
        return f"{path}, offset={start or 0}, limit={count or 'all'}"

    def preview_edit():
        suffix = ", all=true" if args.get("all", False) else ""
        return f"{args.get('path', '')}: '{clip(args.get('old', ''))}' → '{clip(args.get('new', ''))}'{suffix}"

    def preview_load():
        link = args.get("url", "")
        tail = link.split("/")[-1] if link else ""
        return tail or link

    previews = {
        "read": preview_read,
        "write": lambda: f"{args.get('path', '')}, {len(args.get('content', ''))} bytes",
        "edit": preview_edit,
        "bash": lambda: args.get("cmd", ""),
        "glob": lambda: f"{args.get('pat', '')} in {args.get('path', '.')}",
        "grep": lambda: f"/{args.get('pat', '')}/ in {args.get('path', '.')}",
        "web_search": lambda: f"{args.get('query', '')}, max={args.get('max_results', 5)}",
        "search_extension": lambda: args.get("query", ""),
        "load": preview_load,
    }
    build = previews.get(name)
    if build is not None:
        return build()
    # Unknown (e.g. extension) tools: first argument value, capped at 50 chars
    return str(next(iter(args.values())))[:50] if args else ""

def format_tool_result(name, result):
    """Return a short preview of a tool's result.

    A falsy result yields "(empty)". A single line shorter than 80 characters is
    returned unchanged. Otherwise the first line is cut to 60 characters ("..."
    marks the cut), followed by " +K lines" when K more lines follow.
    """
    if not result:
        return "(empty)"

    head, *rest = result.split("\n")
    if not rest and len(result) < 80:
        return result

    preview = head if len(head) <= 60 else head[:60] + "..."
    return f"{preview} +{len(rest)} lines" if rest else preview

def is_file_in_project(filepath, project_path):
    """Return True if *filepath* is *project_path* itself or lies beneath it.

    Both paths are made absolute (relative to the current directory); symlinks
    are not resolved. Returns False for invalid input (e.g. None), or on Windows
    when the paths are on different drives.
    """
    try:
        abs_file = os.path.abspath(filepath)
        abs_project = os.path.abspath(project_path)
        # commonpath compares whole path components, so "/a/bc" is not inside
        # "/a/b". It also handles a project at the root ("/"), which the old
        # "project + os.sep" prefix check ("//") could never match.
        return os.path.commonpath([abs_file, abs_project]) == abs_project
    except (TypeError, ValueError):
        return False

def read_multiline_input(prefill=""):
    """Read multiline input. Enter to submit, Alt+Enter for newline.
    Supports pasting images from clipboard.
    
    Args:
        prefill: Text to prefill in the input box
    
    Returns:
        tuple: (text: str, images: list[dict])
    """
    lines = []
    current = prefill
    cursor_pos = len(prefill)  # Cursor at end of prefill text
    images = []  # Store pasted images
    
    # Enable bracketed paste mode
    print("\033[?2004h", end="", flush=True)
    
    old_settings = termios.tcgetattr(sys.stdin)
    try:
        tty.setcbreak(sys.stdin.fileno())
        
        # Disable LNEXT (Ctrl+V) so it can be used for paste
        attrs = termios.tcgetattr(sys.stdin)
        attrs[6][termios.VLNEXT] = 0  # Disable literal-next
        termios.tcsetattr(sys.stdin, termios.TCSANOW, attrs)
        
        print(f"{BOLD}{BLUE}❯{RESET} {current}", end="", flush=True)
        
        while True:
            ch = sys.stdin.read(1)
            
            if ch == '\x03':  # Ctrl+C - clear input
                lines.clear()
                current = ""
                cursor_pos = 0
                print("\r\033[K", end="", flush=True)
                print(f"{BOLD}{BLUE}❯{RESET} ", end="", flush=True)
                continue
            
            if ch == '\x05':  # Ctrl+E
                raise EOFError
            
            if ch == '\x16':  # Ctrl+V - paste image
                img_block = detect_clipboard_image()
                
                if img_block:
                    # Image detected!
                    images.append(img_block)
                    size = img_block.get('_metadata', {}).get('size', (0, 0))
                    resized = img_block.get('_metadata', {}).get('resized')
                    
                    # Create image marker
                    if resized:
                        img_marker = f"[📷 Image {len(images)}: {size[0]}x{size[1]} → {resized[0]}x{resized[1]}]"
                    else:
                        img_marker = f"[📷 Image {len(images)}: {size[0]}x{size[1]}]"
                    
                    # Insert marker at cursor position
                    current = current[:cursor_pos] + img_marker + current[cursor_pos:]
                    cursor_pos += len(img_marker)
                    
                    # Redraw current line
                    prefix = f"{BOLD}{BLUE}{'│' if lines else '❯'}{RESET} "
                    print(f"\r\033[K{prefix}{current}", end="", flush=True)
                else:
                    # No image in clipboard
                    print(f"\r{YELLOW}⚠ No image in clipboard{RESET}", flush=True)
                    import time
                    time.sleep(1)
                    # Redraw current line
                    prefix = f"{BOLD}{BLUE}{'│' if lines else '❯'}{RESET} "
                    print(f"\r\033[K{prefix}{current}", end="", flush=True)
                
                continue
            
            if ch == '\x1b':  # Escape sequence
                next_ch = sys.stdin.read(1)
                if next_ch in ('\r', '\n'):  # Alt+Enter
                    lines.append(current)
                    current = ""
                    cursor_pos = 0
                    print(f"\n{BOLD}{BLUE}│{RESET} ", end="", flush=True)
                elif next_ch == '[':  # Escape sequence
                    seq = sys.stdin.read(1)
                    if seq == 'C':  # Right arrow
                        if cursor_pos < len(current):
                            # Get display width of character at cursor
                            char_width = get_display_width(current[cursor_pos])
                            cursor_pos += 1
                            # Move cursor by actual display width
                            if char_width == 2:
                                print("\033[2C", end="", flush=True)
                            else:
                                print("\033[C", end="", flush=True)
                    elif seq == 'D':  # Left arrow
                        if cursor_pos > 0:
                            cursor_pos -= 1
                            # Get display width of character before cursor
                            char_width = get_display_width(current[cursor_pos])
                            # Move cursor by actual display width
                            if char_width == 2:
                                print("\033[2D", end="", flush=True)
                            else:
                                print("\033[D", end="", flush=True)
                    elif seq == '2':  # Bracketed paste start: ESC[200~
                        rest = sys.stdin.read(3)  # Read "00~"
                        if rest == '00~':
                            # Read pasted content until ESC[201~
                            paste_buf = ""
                            while True:
                                c = sys.stdin.read(1)
                                if c == '\x1b':
                                    # Check for [201~
                                    peek = sys.stdin.read(5)
                                    if peek == '[201~':
                                        break
                                    else:
                                        paste_buf += c + peek
                                else:
                                    paste_buf += c
                            
                            # Process pasted content
                            paste_lines = paste_buf.split('\n')
                            
                            if len(paste_lines) == 1:
                                # Single line paste
                                current = current[:cursor_pos] + paste_lines[0] + current[cursor_pos:]
                                cursor_pos += len(paste_lines[0])
                                prefix = f"{BOLD}{BLUE}{'│' if lines else '❯'}{RESET} "
                                print(f"\r\033[K{prefix}{current}", end="", flush=True)
                            else:
                                # Multi-line paste
                                # First line appends to current
                                first_line = current[:cursor_pos] + paste_lines[0]
                                print(paste_lines[0], end="", flush=True)
                                if first_line:
                                    lines.append(first_line)
                                
                                # Middle lines
                                for line in paste_lines[1:-1]:
                                    print(f"\n{BOLD}{BLUE}│{RESET} {line}", end="", flush=True)
                                    lines.append(line)
                                
                                # Last line becomes new current
                                current = paste_lines[-1]
                                cursor_pos = len(current)
                                print(f"\n{BOLD}{BLUE}│{RESET} {current}", end="", flush=True)
                continue
            
            if ch in ('\r', '\n'):  # Enter - submit
                if current:
                    lines.append(current)
                print()
                break
            
            if ch in ('\x7f', '\x08'):  # Backspace
                if cursor_pos > 0:
                    # Delete character before cursor
                    current = current[:cursor_pos-1] + current[cursor_pos:]
                    cursor_pos -= 1
                    # Redraw current line
                    prefix = f"{BOLD}{BLUE}{'│' if lines else '❯'}{RESET} "
                    print(f"\r\033[K{prefix}{current}", end="", flush=True)
                    # Move cursor back to position
                    if cursor_pos < len(current):
                        # Calculate display width from cursor to end
                        remaining_width = get_display_width(current[cursor_pos:])
                        print(f"\033[{remaining_width}D", end="", flush=True)
                elif lines:
                    # Merge with previous line
                    prev_line = lines.pop()
                    cursor_pos = len(prev_line)  # Cursor at end of previous line
                    current = prev_line + current
                    # Move up and redraw
                    print("\033[A\033[K", end="", flush=True)
                    prefix = f"{BOLD}{BLUE}{'│' if lines else '❯'}{RESET} "
                    print(f"\r{prefix}{current}", end="", flush=True)
                    if cursor_pos < len(current):
                        # Calculate display width from cursor to end
                        remaining_width = get_display_width(current[cursor_pos:])
                        print(f"\033[{remaining_width}D", end="", flush=True)
                continue
            
            if ch.isprintable() or ch == '\t':
                # Insert character at cursor position
                current = current[:cursor_pos] + ch + current[cursor_pos:]
                cursor_pos += 1
                # Redraw from cursor position
                print(f"{ch}{current[cursor_pos:]}", end="", flush=True)
                # Move cursor back if needed
                if cursor_pos < len(current):
                    # Calculate display width from cursor to end
                    remaining_width = get_display_width(current[cursor_pos:])
                    print(f"\033[{remaining_width}D", end="", flush=True)
        
    finally:
        # Disable bracketed paste mode
        print("\033[?2004l", end="", flush=True)
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
    
    text = "\n".join(lines).strip()
    return text, images

def main():
    """Command-line entry point for nanocode.

    Flags:
        -c / --continue  resume the most recent session
        -l / --list      pick one of the recent sessions interactively

    Terminal signal generation (ISIG) is switched off for the whole run so
    Ctrl+C reaches the input loop as a key press (it is documented as
    "clear" in the shortcuts line) instead of raising KeyboardInterrupt.
    The original terminal attributes are always restored on the way out.
    """
    continue_session = "-c" in sys.argv or "--continue" in sys.argv
    list_sessions = "-l" in sys.argv or "--list" in sys.argv

    # Disable signal generation: index 3 of the termios attribute list is lflag.
    old_settings = termios.tcgetattr(sys.stdin)
    new_settings = termios.tcgetattr(sys.stdin)
    new_settings[3] = new_settings[3] & ~termios.ISIG
    termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)

    try:
        proxy = os.environ.get("http_proxy") or os.environ.get("https_proxy")
        proxy_info = f" | {DIM}🌐 {proxy}{RESET}" if proxy else ""
        thinking_info = f" | {YELLOW}💭{RESET}" if os.environ.get("THINKING") else ""

        if list_sessions:
            session_mode = f" | {YELLOW}Select{RESET}"
        elif continue_session:
            session_mode = f" | {GREEN}Continue{RESET}"
        else:
            session_mode = f" | {CYAN}New{RESET}"

        print(f"{BOLD}nanocode{RESET} | {DIM}{MODEL} | {os.getcwd()}{proxy_info}{thinking_info}{session_mode}{RESET}")
        print(f"{DIM}Shortcuts: Enter=submit | Alt+Enter=newline | Ctrl+V=paste image | Ctrl+C=clear | Ctrl+E=exit | ESC=stop{RESET}")
        print(f"{DIM}Commands: /c [all|baseline|<id>] | /ca | /t | /clear{RESET}")
        print(f"{DIM}Usage: nanocode (new) | nanocode -c (continue) | nanocode -l (select){RESET}\n")

        selected_session_id = None
        if list_sessions:
            # The picker returns None when there are no sessions, when the
            # user just presses Enter, or on invalid/aborted input. It tells
            # the user "press Enter for new session" / "Starting new
            # session...", so fall through and let run_main_loop start one
            # (or resume the last one when -c is also given) rather than
            # exiting.
            selected_session_id = select_session_interactive()

        run_main_loop(continue_session, selected_session_id)
    finally:
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)

def select_session_interactive():
    """List up to ten recent sessions of the current directory and ask for one.

    Returns:
        str | None: the chosen session's ID, or None when no sessions exist,
        the user submits an empty line, the answer is not a valid number in
        range, or input is aborted (EOF / KeyboardInterrupt).
    """
    recent = SessionManager(os.getcwd()).list_sessions()[:10]

    if not recent:
        print(f"{YELLOW}⚠ No previous sessions found{RESET}")
        print(f"{DIM}Starting new session...{RESET}\n")
        return None

    print(f"{BOLD}📂 Recent Sessions:{RESET}\n")

    def stamp(epoch_seconds):
        # Render a stored UNIX timestamp as local "YYYY-MM-DD HH:MM".
        return datetime.fromtimestamp(epoch_seconds).strftime('%Y-%m-%d %H:%M')

    for number, entry in enumerate(recent, start=1):
        meta = entry['metadata']
        created_str = stamp(meta['created_at'])
        active_str = stamp(meta['last_active'])
        description = meta.get('description', '(no description)')
        commit = meta.get('git_commit')
        branch = meta.get('git_branch')
        dirty = meta.get('git_dirty', False)

        print(f"{CYAN}{number}.{RESET} {BOLD}{entry['session_id']}{RESET}")
        print(f"   {description}")

        if commit and branch:
            mark = f"{YELLOW}*{RESET}" if dirty else ""
            git_suffix = f" | Git: {branch}@{commit}{mark}"
        else:
            git_suffix = ""

        print(f"   Created: {created_str} | Last: {active_str} | {entry['message_count']} messages{git_suffix}\n")

    print(f"{DIM}Enter session number (1-{len(recent)}), or press Enter for new session:{RESET}")

    try:
        answer = input(f"{BOLD}{BLUE}❯{RESET} ").strip()
        if not answer:
            return None  # blank answer means "start a new session"

        try:
            index = int(answer) - 1
        except ValueError:
            print(f"{RED}✗ Invalid input{RESET}")
            return None

        if index not in range(len(recent)):
            print(f"{RED}✗ Invalid number{RESET}")
            return None
        return recent[index]['session_id']
    except (EOFError, KeyboardInterrupt):
        return None


def _repl_git_suffix(metadata):
    """Return the " | Git: branch@commit[*]" suffix for session *metadata*.

    Empty unless both ``git_commit`` and ``git_branch`` are set; a yellow
    ``*`` is appended when ``git_dirty`` is truthy.
    """
    git_commit = metadata.get('git_commit')
    git_branch = metadata.get('git_branch')
    if not (git_commit and git_branch):
        return ""
    dirty_mark = f"{YELLOW}*{RESET}" if metadata.get('git_dirty', False) else ""
    return f" | Git: {git_branch}@{git_commit}{dirty_mark}"


def _repl_resolve_conflicts(session_manager, session):
    """Warn about files changed outside *session* and ask how to proceed.

    'u' refreshes the recorded file states and saves the session, 'y' keeps
    the session unchanged, anything else (including an empty answer) starts
    a brand-new session. Prints a blank separator when there is no conflict.
    """
    conflicts = session.detect_conflicts()
    if not conflicts:
        print()
        return

    print(f"\n{YELLOW}⚠ File conflicts detected:{RESET}")
    for filepath in conflicts[:5]:
        print(f"  - {filepath}")
    if len(conflicts) > 5:
        print(f"  ... and {len(conflicts)-5} more")
    print(f"\n{DIM}These files have been modified outside this session.{RESET}")
    confirm = input(f"{BOLD}Continue anyway? (y/N/u=update): {RESET}").strip().lower()

    if confirm == 'u':
        session.update_file_states()
        session_manager.save_session()
        print(f"{GREEN}✓ Updated file states{RESET}\n")
    elif confirm != 'y':
        print(f"{DIM}Creating new session instead...{RESET}\n")
        session_manager.create_session()
    else:
        print()


def _repl_open_session(session_manager, continue_session, selected_session_id):
    """Load, resume or create the session the REPL starts with.

    Precedence: an explicitly selected session, then "continue last
    session", otherwise a new session branched from the last session's
    latest checkpoint (when there is one).
    """
    if selected_session_id:
        session = session_manager.load_session(selected_session_id)
        if session:
            git_info = _repl_git_suffix(session.metadata)
            print(f"{GREEN}✓ Loaded session: {session.session_id}{RESET}")
            print(f"{DIM}  └─ {len(session.messages)} messages{git_info}{RESET}")
            _repl_resolve_conflicts(session_manager, session)
        else:
            print(f"{RED}✗ Failed to load session{RESET}")
            print(f"{GREEN}✓ Creating new session instead{RESET}\n")
            session_manager.create_session()
        return

    if continue_session:
        last_session = session_manager.load_last_session()
        if last_session:
            git_info = _repl_git_suffix(last_session.metadata)
            print(f"{GREEN}✓ Continued session: {last_session.session_id}{RESET}")
            print(f"{DIM}  └─ {len(last_session.messages)} messages{git_info}{RESET}")
            _repl_resolve_conflicts(session_manager, last_session)
        else:
            session_manager.create_session()
            print(f"{YELLOW}⚠ No previous session found{RESET}")
            print(f"{GREEN}✓ Created new session: {session_manager.current_session.session_id}{RESET}\n")
        return

    # Default: always a new session, branched from the last session's newest
    # checkpoint when one exists.
    parent_checkpoint = None
    parent_session = None
    last_session = session_manager.load_last_session()
    if last_session:
        checkpoints = session_manager.checkpoint_manager.list_checkpoints(show_all=False)
        if checkpoints:
            parent_checkpoint = checkpoints[0][0]  # latest checkpoint hash
            parent_session = last_session.session_id

    session_manager.create_session(
        parent_checkpoint=parent_checkpoint,
        parent_session=parent_session
    )

    git_info = _repl_git_suffix(session_manager.current_session.metadata)
    print(f"{GREEN}✓ Created new session: {session_manager.current_session.session_id}{RESET}")
    if parent_checkpoint:
        print(f"{DIM}  └─ Branched from {parent_session[:8]}... @ {parent_checkpoint}{git_info}{RESET}\n")
    elif git_info:
        print(f"{DIM}  └─{git_info}{RESET}\n")
    else:
        print()


def _repl_user_message(user_input, images):
    """Build the user message for *user_input* plus any pasted *images*.

    Text-only input keeps the plain-string content form. With images the
    content is a list: an optional text block followed by one image block
    per image, reduced to ``type`` and ``source`` (local metadata is not
    sent to the API).
    """
    if not images:
        return {"role": "user", "content": user_input}

    content = [{"type": "text", "text": user_input}] if user_input else []
    content.extend({"type": "image", "source": img_block["source"]} for img_block in images)
    print(f"{GREEN}📷 Attached {len(images)} image(s){RESET}")
    return {"role": "user", "content": content}


def _repl_tool_loop(session_manager, system_prompt, files_modified, files_modified_this_turn):
    """Call the model and run its tool requests until it stops asking or ESC.

    Appends assistant replies and tool results to the current session.
    Project files targeted by ``write``/``edit`` get a baseline saved before
    the tool runs, and afterwards are added to *files_modified* (whole REPL
    run) and *files_modified_this_turn*. Reads the module-level
    ``stop_flag``.
    """
    session = session_manager.current_session
    while True:
        response = call_api(session.messages, system_prompt)
        blocks = process_stream(response)
        if stop_flag:
            break

        tool_results = []
        for block in blocks:
            if block["type"] != "tool_use":
                continue

            if stop_flag:
                # Stopped by the user: skip the remaining tools, but still
                # answer every tool_use so the history sent next time has a
                # tool_result for each tool_use.
                tool_results.append({"type": "tool_result", "tool_use_id": block["id"], "content": "Interrupted by user"})
                continue

            name, args = block["name"], block["input"]

            # Save baseline BEFORE executing write/edit
            if name in ('write', 'edit'):
                filepath = args.get('path')
                if filepath and is_file_in_project(filepath, session_manager.project_path):
                    session_manager.save_baseline_if_needed(filepath)

            preview = format_tool_preview(name, args)
            print(f"\n{GREEN}⏺ {name}{RESET}({DIM}{preview}{RESET})")

            result = run_tool(name, args)

            result_preview = format_tool_result(name, result)
            print(f"  {DIM}⎿ {result_preview}{RESET}")

            # Track file modifications (only project files)
            if name in ('write', 'edit'):
                filepath = args.get('path')
                if filepath and is_file_in_project(filepath, session_manager.project_path):
                    files_modified.add(filepath)
                    files_modified_this_turn.add(filepath)
                    session.track_file_state(filepath)

            tool_results.append({"type": "tool_result", "tool_use_id": block["id"], "content": result})

            if stop_flag:
                print(f"{YELLOW}⚠ Tool execution stopped{RESET}")

        session.messages.append({"role": "assistant", "content": blocks})
        if not tool_results:
            break
        session.messages.append({"role": "user", "content": tool_results})
        if stop_flag:
            break


def _repl_auto_checkpoint(session_manager, user_input, files_modified_this_turn):
    """Checkpoint the project files changed this turn.

    Returns the checkpoint ID, or whatever falsy value ``create_checkpoint``
    returned when nothing was committed.
    """
    checkpoint_manager = session_manager.checkpoint_manager
    temp_message = f"Auto: {user_input[:50]}"

    # The first checkpoint of a branched session is parented on the commit
    # it branched from; later ones just extend the session branch.
    parent_commit = session_manager.parent_commit_for_next_checkpoint
    checkpoint_id = checkpoint_manager.create_checkpoint(
        temp_message,
        list(files_modified_this_turn),
        conversation_snapshot=session_manager.current_session.messages.copy(),
        parent_commit=parent_commit
    )
    if parent_commit:
        session_manager.parent_commit_for_next_checkpoint = None

    if not checkpoint_id:
        # Checkpoint creation failed (e.g., no actual diff)
        print(f"\n{DIM}(No project file changes to checkpoint){RESET}")
        return checkpoint_id

    # Generate summary using LLM with actual diff
    print(f"{DIM}Generating checkpoint summary...{RESET}", end="", flush=True)
    summary = summarize_changes(
        user_input,
        files_modified_this_turn,
        checkpoint_manager,
        checkpoint_id
    )
    print(f"\r{' ' * 40}\r", end="", flush=True)  # Clear the line

    if summary != user_input[:50] and summary != temp_message:
        # NOTE(review): amending rewrites the commit, so the branch tip's hash
        # no longer starts with checkpoint_id, which is the ID printed below,
        # stored with the turn, and used as the conversation-snapshot key.
        # Lookups by this ID depend on the pre-amend commit object surviving;
        # confirm.
        checkpoint_manager._git_command(
            "--git-dir", checkpoint_manager.bare_repo,
            "commit", "--amend", "-m", summary
        )

    print(f"\n{YELLOW}📍 {checkpoint_id}: {summary}{RESET}")
    return checkpoint_id


def run_main_loop(continue_session=False, selected_session_id=None):
    """Run the interactive REPL until the user quits.

    Args:
        continue_session: resume the most recent session.
        selected_session_id: resume this specific session; takes precedence
            over *continue_session*.

    Each iteration reads input and either runs a slash command (``/clear``,
    ``/t``/``/turns``, ``/c``/``/ca``/``/checkpoint``) or sends the message
    to the model, executing tools until the model stops requesting them or
    the user presses ESC. Project files changed during a turn are
    auto-checkpointed. The session is saved after every turn and on exit
    (``/q``, ``exit`` or EOF); other exceptions are printed and the loop
    continues.
    """
    # stop_flag is the module-level flag set when the user presses ESC.
    # Without this declaration the per-turn reset below makes it a local and
    # the stop checks never see the user's request.
    global stop_flag

    session_manager = SessionManager(os.getcwd())
    _repl_open_session(session_manager, continue_session, selected_session_id)

    files_modified = set()
    auto_checkpoint = True

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    system_prompt = f"""Concise coding assistant. cwd: {os.getcwd()} Current time: {current_time}
IMPORTANT: When you don't have a tool for the task, ALWAYS try search_extension first before saying you can't do it.
Examples:
- User asks about GitHub repo → search_extension({{"query": "github documentation"}})
- User needs web data → search_extension({{"query": "web scraping"}})
- User needs API → search_extension({{"query": "api client"}})"""

    prefill_next_input = ""  # text pre-filled into the next prompt (set by /t)

    while True:
        try:
            print(f"{DIM}{'─'*80}{RESET}")
            user_input, images = read_multiline_input(prefill_next_input)
            prefill_next_input = ""
            print(f"{DIM}{'─'*80}{RESET}")

            if not user_input and not images: continue
            if user_input in ("/q", "exit"):
                session_manager.save_session()
                break

            # Commands are recognised by their exact first word, so ordinary
            # messages such as "/tmp/foo.py crashes" still reach the model.
            # Image-only input has no text, hence the empty fallback.
            parts = user_input.split()
            command = parts[0] if parts else ""

            if user_input == "/clear":
                session_manager.save_session()

                # Branch the new session from the current one's latest checkpoint.
                checkpoints = session_manager.checkpoint_manager.list_checkpoints(show_all=False)
                parent_checkpoint = checkpoints[0][0] if checkpoints else None
                parent_session = session_manager.current_session.session_id

                session_manager.create_session(
                    parent_checkpoint=parent_checkpoint,
                    parent_session=parent_session
                )
                files_modified.clear()

                print(f"{GREEN}✓ Started new session: {session_manager.current_session.session_id}{RESET}")
                if parent_checkpoint:
                    print(f"{DIM}  └─ Branched from {parent_session[:8]}... @ {parent_checkpoint}{RESET}")
                continue

            if command in ("/t", "/turns"):
                # isdecimal (not isdigit) so int() can always parse the argument.
                turn_number = int(parts[1]) if len(parts) >= 2 and parts[1].isdecimal() else None

                success, prefill_input = handle_turns_command(session_manager, turn_number)
                if success:
                    # handle_turns_command already modified the session.
                    session_manager.save_session()
                    if prefill_input:
                        prefill_next_input = prefill_input
                continue

            if command in ("/c", "/ca", "/checkpoint"):
                if command == "/ca":  # shortcut for "/c all"
                    parts = ["/c", "all"]
                if len(parts) == 1 and parts[0] in ["/c", "/checkpoint"]:
                    parts.append("list")  # bare /c lists checkpoints

                restored_messages = handle_checkpoint_command(parts, session_manager, files_modified)
                if restored_messages is not None:
                    session_manager.current_session.messages = restored_messages
                    session_manager.save_session()
                continue

            session_manager.current_session.messages.append(_repl_user_message(user_input, images))

            stop_flag = False  # fresh turn: forget any earlier ESC
            files_modified_this_turn = set()
            _repl_tool_loop(session_manager, system_prompt, files_modified, files_modified_this_turn)

            # Auto checkpoint after AI work (only project files are tracked).
            checkpoint_id = None
            if auto_checkpoint and files_modified_this_turn:
                checkpoint_id = _repl_auto_checkpoint(session_manager, user_input, files_modified_this_turn)

            session_manager.current_session.add_turn(
                user_input,
                files_modified_this_turn,
                checkpoint_id=checkpoint_id
            )
            session_manager.save_session()

            print()
        except EOFError:
            session_manager.save_session()
            break
        except Exception as e:
            print(f"{RED}⏺ Error: {e}{RESET}")

# ============================================================================
# Checkpoint & Session Management (Phase 1+2)
# ============================================================================

class CheckpointManager:
    """Manage checkpoints using shadow bare git repository with session isolation"""
    
    def __init__(self, project_path, session_id=None):
        self.project_path = project_path
        self.session_id = session_id
        self.nanocode_dir = os.path.join(project_path, ".nanocode")
        self.bare_repo = os.path.join(self.nanocode_dir, "checkpoint.git")
        self._init_bare_repo()
    
    def set_session(self, session_id):
        """Set current session for checkpoint operations"""
        self.session_id = session_id
    
    def _get_branch_name(self):
        """Get git branch name for current session"""
        if not self.session_id:
            return "main"
        return f"session_{self.session_id}"
    
    def _init_bare_repo(self):
        """Initialize shadow bare repository"""
        if not os.path.exists(self.bare_repo):
            os.makedirs(self.bare_repo, exist_ok=True)
            try:
                subprocess.run(
                    ["git", "init", "--bare", self.bare_repo],
                    capture_output=True, check=True
                )
            except (subprocess.CalledProcessError, FileNotFoundError):
                # Git not available, will handle gracefully
                pass
    
    def _git_command(self, *args, cwd=None):
        """Execute git command"""
        try:
            result = subprocess.run(
                ["git"] + list(args),
                cwd=cwd or self.project_path,
                capture_output=True,
                text=True,
                check=True
            )
            return result.stdout.strip()
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            return f"error: {e}"
    
    def save_file_to_blob(self, filepath):
        """Save file to git blob storage
        
        Returns:
            str: blob hash
        """
        try:
            result = subprocess.run(
                ["git", "--git-dir", self.bare_repo, "hash-object", "-w", filepath],
                capture_output=True, text=True, check=True
            )
            return result.stdout.strip()
        except Exception as e:
            return None
    
    def restore_file_from_blob(self, blob_hash, filepath):
        """Restore file from git blob storage"""
        try:
            content = subprocess.run(
                ["git", "--git-dir", self.bare_repo, "cat-file", "-p", blob_hash],
                capture_output=True, check=True
            ).stdout
            
            os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
            with open(filepath, 'wb') as f:
                f.write(content)
            return True
        except Exception as e:
            return False
    
    def get_file_git_info(self, filepath):
        """Get file's git info from project .git
        
        Returns:
            dict: {"commit": "abc123", "has_changes": True/False} or None
        """
        try:
            # Check if file is tracked
            result = subprocess.run(
                ["git", "ls-files", "--error-unmatch", filepath],
                cwd=self.project_path,
                capture_output=True, text=True
            )
            
            if result.returncode != 0:
                return None  # Not tracked
            
            # Get last commit for this file
            commit = subprocess.run(
                ["git", "log", "-1", "--format=%H", "--", filepath],
                cwd=self.project_path,
                capture_output=True, text=True, check=True
            ).stdout.strip()
            
            # Check if file has local changes
            diff = subprocess.run(
                ["git", "diff", "HEAD", "--", filepath],
                cwd=self.project_path,
                capture_output=True, text=True, check=True
            ).stdout.strip()
            
            return {
                "commit": commit,
                "has_changes": bool(diff)
            }
        except Exception as e:
            return None
    
    def create_checkpoint(self, message, files_changed, conversation_snapshot=None, parent_commit=None):
        """Create a checkpoint on current session's branch
        
        Args:
            message: Commit message
            files_changed: List of modified files
            conversation_snapshot: Conversation state to save
            parent_commit: Parent commit hash to branch from (for new sessions)
        """
        print(f"{DIM}[LOG] create_checkpoint: files_changed={files_changed}{RESET}", flush=True)
        if not files_changed or not self.session_id:
            return None
        
        branch_name = self._get_branch_name()
        
        # Save conversation snapshot
        if conversation_snapshot:
            snapshot_file = os.path.join(self.nanocode_dir, "conversation_snapshots.json")
            snapshots = {}
            if os.path.exists(snapshot_file):
                with open(snapshot_file, 'r') as f:
                    snapshots = json.load(f)
        
        # Create temp worktree for this session
        temp_worktree = os.path.join(self.nanocode_dir, f"temp_worktree_{self.session_id}")
        
        try:
            # Check if branch exists
            branch_exists = self._git_command("--git-dir", self.bare_repo, "rev-parse", "--verify", branch_name)
            
            if not branch_exists or branch_exists.startswith("error"):
                # Create new branch
                os.makedirs(temp_worktree, exist_ok=True)
                self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, "config", "core.bare", "false")
                
                # If parent_commit specified, branch from it
                if parent_commit:
                    # Create branch from parent commit
                    self._git_command("--git-dir", self.bare_repo, "branch", branch_name, parent_commit)
                    self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, "checkout", branch_name, "-f")
                else:
                    # Create orphan branch (no parent)
                    self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, "checkout", "--orphan", branch_name)
                
                # Copy files to temp worktree
                for filepath in files_changed:
                    print(f"{DIM}[LOG] checkpoint copying: {filepath}{RESET}", flush=True)
                    if os.path.exists(filepath):
                        file_size = os.path.getsize(filepath)
                        print(f"{DIM}[LOG] source file exists: {filepath} ({file_size} bytes){RESET}", flush=True)
                        # Convert absolute path to relative path
                        if os.path.isabs(filepath):
                            rel_filepath = os.path.relpath(filepath, self.project_path)
                        else:
                            rel_filepath = filepath
                        dest = os.path.join(temp_worktree, rel_filepath)
                        os.makedirs(os.path.dirname(dest), exist_ok=True)
                        with open(filepath, 'rb') as src, open(dest, 'wb') as dst:
                            content = src.read()
                            dst.write(content)
                            print(f"{DIM}[LOG] copied to temp_worktree: {dest} ({len(content)} bytes){RESET}", flush=True)
                    else:
                        print(f"{DIM}[LOG] source file NOT exists: {filepath}{RESET}", flush=True)
                
                # Commit
                self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, "add", "-A")
                self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, 
                                "commit", "-m", message, "--allow-empty")
                
                commit_hash = self._git_command("--git-dir", self.bare_repo, "rev-parse", "HEAD")
                checkpoint_id = commit_hash[:8] if commit_hash and not commit_hash.startswith("error") else None
                
                # Save conversation snapshot with checkpoint_id
                if checkpoint_id and conversation_snapshot:
                    snapshots[checkpoint_id] = conversation_snapshot
                    with open(snapshot_file, 'w') as f:
                        json.dump(snapshots, f, indent=2)
                
                return checkpoint_id
            else:
                # Branch exists, checkout and commit
                self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, "checkout", branch_name, "-f")
                
                # Update temp worktree
                for filepath in files_changed:
                    print(f"{DIM}[LOG] checkpoint updating: {filepath}{RESET}", flush=True)
                    if os.path.exists(filepath):
                        file_size = os.path.getsize(filepath)
                        print(f"{DIM}[LOG] source file exists: {filepath} ({file_size} bytes){RESET}", flush=True)
                        # Convert absolute path to relative path
                        if os.path.isabs(filepath):
                            rel_filepath = os.path.relpath(filepath, self.project_path)
                        else:
                            rel_filepath = filepath
                        dest = os.path.join(temp_worktree, rel_filepath)
                        os.makedirs(os.path.dirname(dest), exist_ok=True)
                        with open(filepath, 'rb') as src, open(dest, 'wb') as dst:
                            content = src.read()
                            dst.write(content)
                            print(f"{DIM}[LOG] copied to temp_worktree: {dest} ({len(content)} bytes){RESET}", flush=True)
                    else:
                        print(f"{DIM}[LOG] source file NOT exists: {filepath}{RESET}", flush=True)
                
                self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, "add", "-A")
                self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, 
                                "commit", "-m", message, "--allow-empty")
                
                commit_hash = self._git_command("--git-dir", self.bare_repo, "rev-parse", "HEAD")
                checkpoint_id = commit_hash[:8] if commit_hash and not commit_hash.startswith("error") else None
                
                # Save conversation snapshot with checkpoint_id
                if checkpoint_id and conversation_snapshot:
                    snapshots[checkpoint_id] = conversation_snapshot
                    with open(snapshot_file, 'w') as f:
                        json.dump(snapshots, f, indent=2)
                
                return checkpoint_id
        except Exception as e:
            return None
    
    def list_checkpoints(self, limit=10, show_all=False):
        """List recent checkpoints from the checkpoint repository.

        Args:
            limit: Maximum number of checkpoints to return (passed to
                ``git log --max-count``).
            show_all: If True, list commits across all branches (``--all``);
                if False, list only the current session's branch.

        Returns:
            list: One entry per ``git log --oneline`` line, split on the first
            space into ``[short_hash, subject]``. Empty when there is no
            session (and ``show_all`` is False), when the git helper returns
            nothing or an "error..." string, or when it raises.
        """
        if not self.session_id and not show_all:
            return []

        try:
            args = ["--git-dir", self.bare_repo, "log", f"--max-count={limit}", "--oneline"]
            # --all spans every session's branch; otherwise restrict to ours.
            args.append("--all" if show_all else self._get_branch_name())

            log = self._git_command(*args)
            if not log or log.startswith("error"):
                return []
            return [line.split(" ", 1) for line in log.split("\n") if line]
        except Exception:
            # Best effort: listing must not break the caller. A bare ``except:``
            # here would also swallow KeyboardInterrupt / SystemExit.
            return []
    
    def restore_checkpoint(self, checkpoint_id, session_baseline_files):
        """Restore session files to a checkpoint and rewind the session branch.

        Each file recorded in ``session_baseline_files`` is handled as follows:

        1. If it exists in the checkpoint tree, it is copied from the temp
           worktree (checked out at ``checkpoint_id``) into the project.
        2. Otherwise it is returned to its session baseline: ``git`` entries
           are checked out from the project's own repository, ``blob``
           entries are restored via ``restore_file_from_blob``, and ``new``
           entries (files created during the session) are deleted.

        Args:
            checkpoint_id: Checkpoint commit hash (or unique prefix) to restore to.
            session_baseline_files: Dict mapping file path -> baseline source
                dict whose ``"type"`` is ``"git"`` (with ``"commit"``),
                ``"blob"`` (with ``"hash"``) or ``"new"``.

        Returns:
            tuple: (success: bool, conversation_snapshot: dict or None). The
            snapshot is looked up by ``checkpoint_id`` in
            ``conversation_snapshots.json``. Any exception is reported and
            yields ``(False, None)``.
        """
        if not self.session_id:
            return False, None
        
        branch_name = self._get_branch_name()
        temp_worktree = os.path.join(self.nanocode_dir, f"temp_worktree_{self.session_id}")
        
        try:
            # Checkout branch first so HEAD points at the session branch
            self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree, "checkout", branch_name, "-f")
            
            # Reset branch to checkpoint (discards future commits on this branch).
            # --work-tree is passed explicitly, as for every other command here:
            # the repo has core.bare=false, and without a work tree git uses the
            # git process's current directory, so `reset --hard` could rewrite
            # files outside the temp worktree.
            self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree,
                              "reset", "--hard", checkpoint_id)
            
            # Checkout to temp worktree
            self._git_command("--git-dir", self.bare_repo, "--work-tree", temp_worktree,
                            "checkout", checkpoint_id, "-f")
            
            # Get list of files in checkpoint
            files_in_checkpoint_str = self._git_command(
                "--git-dir", self.bare_repo,
                "ls-tree", "-r", "--name-only", checkpoint_id
            )
            
            files_in_checkpoint = set()
            if files_in_checkpoint_str and not files_in_checkpoint_str.startswith("error"):
                files_in_checkpoint = set(f for f in files_in_checkpoint_str.split('\n') if f.strip())
            
            print(f"\n{DIM}Files in checkpoint: {len(files_in_checkpoint)}{RESET}")
            print(f"{DIM}Files modified in session: {len(session_baseline_files)}{RESET}\n")
            
            # Process each file that was modified in this session
            for filepath, baseline_source in session_baseline_files.items():
                # Convert to relative path for comparison
                if os.path.isabs(filepath):
                    rel_filepath = os.path.relpath(filepath, self.project_path)
                else:
                    rel_filepath = filepath
                
                # Normalize to the form `git ls-tree` prints ("dir/file.py"):
                # drop leading "./" and redundant separators while leaving
                # dotfile names such as ".gitignore" intact.
                normalized_rel_path = os.path.normpath(rel_filepath).replace(os.sep, "/")
                
                if normalized_rel_path in files_in_checkpoint:
                    # File exists in checkpoint - restore from checkpoint
                    src = os.path.join(temp_worktree, normalized_rel_path)
                    dest = os.path.join(self.project_path, normalized_rel_path)
                    
                    dest_dir = os.path.dirname(dest)
                    if dest_dir:
                        os.makedirs(dest_dir, exist_ok=True)
                    
                    with open(src, 'rb') as s, open(dest, 'wb') as d:
                        d.write(s.read())
                    
                    print(f"  {GREEN}✓{RESET} {filepath} {DIM}(from checkpoint){RESET}")
                else:
                    # File doesn't exist in checkpoint - restore to baseline
                    abs_filepath = filepath if os.path.isabs(filepath) else os.path.join(self.project_path, filepath)
                    
                    if baseline_source["type"] == "git":
                        # Restore from project .git (use normalized path for git)
                        result = subprocess.run(
                            ["git", "checkout", baseline_source["commit"], "--", normalized_rel_path],
                            cwd=self.project_path,
                            capture_output=True, text=True
                        )
                        if result.returncode == 0:
                            print(f"  {CYAN}↺{RESET} {filepath} {DIM}(to baseline: git {baseline_source['commit'][:8]}){RESET}")
                    
                    elif baseline_source["type"] == "blob":
                        # Restore from blob
                        if self.restore_file_from_blob(baseline_source["hash"], abs_filepath):
                            print(f"  {CYAN}↺{RESET} {filepath} {DIM}(to baseline: blob {baseline_source['hash'][:8]}){RESET}")
                    
                    elif baseline_source["type"] == "new":
                        # File was created during the session: delete it
                        if os.path.exists(abs_filepath):
                            os.remove(abs_filepath)
                            print(f"  {YELLOW}✗{RESET} {filepath} {DIM}(deleted: was added after checkpoint){RESET}")
            
            # Load conversation snapshot
            snapshot_file = os.path.join(self.nanocode_dir, "conversation_snapshots.json")
            conversation_snapshot = None
            if os.path.exists(snapshot_file):
                with open(snapshot_file, 'r') as f:
                    snapshots = json.load(f)
                    conversation_snapshot = snapshots.get(checkpoint_id)
            
            return True, conversation_snapshot
        except Exception as e:
            print(f"{RED}Error during restore: {e}{RESET}")
            return False, None


class Session:
    """Represents a conversation session.

    Holds the message history, per-file snapshots used for conflict
    detection (``file_states``), the baseline sources used for rollback
    (``baseline_files``), the recorded conversation turns, and metadata.
    ``to_dict`` / ``from_dict`` convert to and from a JSON-friendly dict.
    """
    
    def __init__(self, session_id=None):
        self.session_id = session_id or self._generate_session_id()
        self.messages = []
        self.file_states = {}  # filepath -> {'hash', 'mtime', 'size'}
        self.baseline_files = {}  # Track original file versions for rollback
        self.turns = []  # Track conversation turns
        self.metadata = {
            'created_at': time.time(),
            'last_active': time.time(),
            'description': '',
            'cwd': os.getcwd(),
            'parent_checkpoint': None,  # Track where this session branched from
            'parent_session': None,     # Track which session it branched from
            'git_commit': None,         # Project .git commit hash when session started
            'git_branch': None,         # Project .git branch when session started
            'git_dirty': False,         # Whether project had uncommitted changes
        }
    
    def _generate_session_id(self):
        """Generate a session ID of the form ``YYYYmmdd_HHMMSS_xxxx``.

        The four-hex-digit suffix separates sessions created in the same
        second. It comes from ``random`` and is not meant to be unguessable.
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        random_suffix = ''.join(random.choices('0123456789abcdef', k=4))
        return f"{timestamp}_{random_suffix}"
    
    def _get_project_git_info(self):
        """Return the project's current git state, or None if unavailable.

        Runs git in ``metadata['cwd']`` (falling back to the process cwd).

        Returns:
            dict or None: ``git_commit`` (first 8 chars of HEAD),
            ``git_branch`` (``rev-parse --abbrev-ref HEAD``) and ``git_dirty``
            (True when ``git status --porcelain`` prints anything). None when
            the directory is not a git repository, git is not installed, or
            the directory does not exist.
        """
        try:
            cwd = self.metadata.get('cwd', os.getcwd())
            
            # Get current commit
            commit = subprocess.run(
                ["git", "rev-parse", "HEAD"],
                cwd=cwd,
                capture_output=True, text=True, check=True
            ).stdout.strip()
            
            # Get current branch
            branch = subprocess.run(
                ["git", "rev-parse", "--abbrev-ref", "HEAD"],
                cwd=cwd,
                capture_output=True, text=True, check=True
            ).stdout.strip()
            
            # Check if dirty
            status = subprocess.run(
                ["git", "status", "--porcelain"],
                cwd=cwd,
                capture_output=True, text=True, check=True
            ).stdout.strip()
            
            return {
                'git_commit': commit[:8],
                'git_branch': branch,
                'git_dirty': bool(status)
            }
        except (subprocess.CalledProcessError, OSError, UnicodeDecodeError):
            # Expected failures only: git exited non-zero (not a repo, no
            # commits), git is missing or cwd is gone, or the output is not
            # decodable in the locale encoding. Anything else should surface
            # instead of being swallowed (the former bare `except:` also ate
            # KeyboardInterrupt).
            return None
    
    def capture_git_state(self):
        """Capture current project git state into metadata (no-op if unavailable)"""
        git_info = self._get_project_git_info()
        if git_info:
            self.metadata.update(git_info)
    
    def track_file_state(self, filepath):
        """Record the MD5, mtime and size of ``filepath`` for conflict detection.

        Missing files are ignored (no entry is created or removed).
        """
        if os.path.exists(filepath):
            with open(filepath, 'rb') as f:
                content = f.read()
                file_hash = hashlib.md5(content).hexdigest()
                self.file_states[filepath] = {
                    'hash': file_hash,
                    'mtime': os.path.getmtime(filepath),
                    'size': len(content)
                }
    
    def detect_conflicts(self):
        """Detect if tracked files have been modified outside this session
        
        A file counts as conflicted when either its content hash or its
        mtime differs from the recorded state, so a touch without a content
        change is also reported.

        Returns:
            list: Conflicted file paths; deleted files are reported as
            ``"<path> (deleted)"``.
        """
        conflicts = []
        
        for filepath, saved_state in self.file_states.items():
            if os.path.exists(filepath):
                with open(filepath, 'rb') as f:
                    content = f.read()
                    current_hash = hashlib.md5(content).hexdigest()
                    current_mtime = os.path.getmtime(filepath)
                    
                    # Check if file has changed
                    if (current_hash != saved_state['hash'] or 
                        current_mtime != saved_state['mtime']):
                        conflicts.append(filepath)
            else:
                # File was deleted
                conflicts.append(f"{filepath} (deleted)")
        
        return conflicts
    
    def update_file_states(self):
        """Refresh all tracked file states; drop entries for deleted files"""
        for filepath in list(self.file_states.keys()):
            if os.path.exists(filepath):
                self.track_file_state(filepath)
            else:
                # Remove deleted files from tracking
                del self.file_states[filepath]
    
    def add_turn(self, user_input, files_modified, checkpoint_id=None):
        """Record a conversation turn
        
        Args:
            user_input: User's input for this turn (stored truncated to 100 chars)
            files_modified: Iterable of files modified in this turn
            checkpoint_id: Associated checkpoint ID if files were modified

        Returns:
            int: The 1-indexed number of the new turn.
        """
        turn_number = len(self.turns) + 1
        turn_data = {
            'turn_number': turn_number,
            'timestamp': time.time(),
            'user_input': user_input[:100],  # Truncate for preview
            'files_modified': list(files_modified),
            'checkpoint_id': checkpoint_id,
            'message_count': len(self.messages)  # Message count when the turn is recorded
        }
        self.turns.append(turn_data)
        return turn_number
    
    def to_dict(self):
        """Serialize to a JSON-friendly dict"""
        return {
            'session_id': self.session_id,
            'messages': self.messages,
            'file_states': self.file_states,
            'baseline_files': self.baseline_files,
            'turns': self.turns,
            'metadata': self.metadata
        }
    
    @staticmethod
    def from_dict(data):
        """Deserialize from a dict produced by ``to_dict``.

        ``session_id`` is required; every other key defaults to empty.
        """
        session = Session(session_id=data['session_id'])
        session.messages = data.get('messages', [])
        session.file_states = data.get('file_states', {})
        session.baseline_files = data.get('baseline_files', {})
        session.turns = data.get('turns', [])
        session.metadata = data.get('metadata', {})
        return session


class SessionManager:
    """Manage multiple sessions.

    Sessions are stored as ``<session_id>.json`` files under
    ``<project_path>/.nanocode/sessions``. The manager tracks the current
    session and keeps its ``CheckpointManager`` pointed at that session.
    """
    
    def __init__(self, project_path):
        self.project_path = project_path
        self.sessions_dir = os.path.join(project_path, ".nanocode", "sessions")
        self.current_session = None
        self.checkpoint_manager = CheckpointManager(project_path)
        self.parent_commit_for_next_checkpoint = None  # Track parent for first checkpoint
        os.makedirs(self.sessions_dir, exist_ok=True)
    
    def save_baseline_if_needed(self, filepath):
        """Save file's baseline version before first modification
        
        This is called before write/edit operations to preserve the original
        state. The baseline is recorded as one of:

        - ``{"type": "git", "commit": ...}``: tracked and clean in project git
        - ``{"type": "blob", "hash": ...}``: dirty tracked file or existing
          untracked file, saved into the checkpoint repo
        - ``{"type": "new"}``: the file does not exist yet

        If saving a blob fails (falsy hash), no baseline is recorded.
        """
        if not self.current_session:
            return
        
        # Already saved
        if filepath in self.current_session.baseline_files:
            return
        
        print(f"{DIM}[LOG] Saving baseline for: {filepath}{RESET}", flush=True)
        
        # Get file info from project .git
        git_info = self.checkpoint_manager.get_file_git_info(filepath)
        
        if git_info:
            # File is tracked by project .git
            if git_info["has_changes"]:
                # Has local changes - save to blob
                blob_hash = self.checkpoint_manager.save_file_to_blob(filepath)
                if blob_hash:
                    self.current_session.baseline_files[filepath] = {
                        "type": "blob",
                        "hash": blob_hash
                    }
                    print(f"{DIM}[LOG] Saved dirty file to blob: {blob_hash[:8]}{RESET}", flush=True)
            else:
                # Clean - just record commit
                self.current_session.baseline_files[filepath] = {
                    "type": "git",
                    "commit": git_info["commit"]
                }
                print(f"{DIM}[LOG] Recorded git commit: {git_info['commit'][:8]}{RESET}", flush=True)
        else:
            # Untracked file or no .git
            if os.path.exists(filepath):
                # Save existing untracked file to blob
                blob_hash = self.checkpoint_manager.save_file_to_blob(filepath)
                if blob_hash:
                    self.current_session.baseline_files[filepath] = {
                        "type": "blob",
                        "hash": blob_hash
                    }
                    print(f"{DIM}[LOG] Saved untracked file to blob: {blob_hash[:8]}{RESET}", flush=True)
            else:
                # New file - mark as new
                self.current_session.baseline_files[filepath] = {
                    "type": "new"
                }
                print(f"{DIM}[LOG] Marked as new file{RESET}", flush=True)
        
        # Auto-save session to persist baseline_files
        self.save_session()
    
    def restore_baseline(self):
        """Restore all files to their baseline state
        
        Returns:
            bool: True when there is nothing to restore or at least one file
            was restored; False when there is no session or every restore
            failed. Partial failures are printed but still return True.
        """
        if not self.current_session:
            return False
        
        if not self.current_session.baseline_files:
            # No files modified = already at baseline state = success
            print(f"{DIM}No files to restore (already at baseline){RESET}")
            return True
        
        print(f"\n{BOLD}Restoring {len(self.current_session.baseline_files)} files to baseline...{RESET}\n")
        
        success_count = 0
        for filepath, source in self.current_session.baseline_files.items():
            try:
                # Normalize to absolute path
                abs_filepath = filepath if os.path.isabs(filepath) else os.path.join(self.project_path, filepath)
                # Get relative path for git operations
                rel_filepath = os.path.relpath(abs_filepath, self.project_path)
                
                if source["type"] == "git":
                    # Restore from project .git (use relative path)
                    result = subprocess.run(
                        ["git", "checkout", source["commit"], "--", rel_filepath],
                        cwd=self.project_path,
                        capture_output=True, text=True
                    )
                    if result.returncode == 0:
                        print(f"  {GREEN}✓{RESET} {filepath} {DIM}(from git {source['commit'][:8]}){RESET}")
                        success_count += 1
                    else:
                        print(f"  {RED}✗{RESET} {filepath} {DIM}(git checkout failed){RESET}")
                
                elif source["type"] == "blob":
                    # Restore from checkpoint.git blob (use absolute path)
                    if self.checkpoint_manager.restore_file_from_blob(source["hash"], abs_filepath):
                        print(f"  {GREEN}✓{RESET} {filepath} {DIM}(from blob {source['hash'][:8]}){RESET}")
                        success_count += 1
                    else:
                        print(f"  {RED}✗{RESET} {filepath} {DIM}(blob restore failed){RESET}")
                
                elif source["type"] == "new":
                    # Delete new file (use absolute path)
                    if os.path.exists(abs_filepath):
                        os.remove(abs_filepath)
                        print(f"  {GREEN}✓{RESET} {filepath} {DIM}(deleted new file){RESET}")
                        success_count += 1
                    else:
                        print(f"  {DIM}○{RESET} {filepath} {DIM}(already deleted){RESET}")
                        success_count += 1
            
            except Exception as e:
                print(f"  {RED}✗{RESET} {filepath} {DIM}(error: {e}){RESET}")
        
        print(f"\n{GREEN}✓ Restored {success_count}/{len(self.current_session.baseline_files)} files{RESET}")
        return success_count > 0
    
    def create_session(self, description="", parent_checkpoint=None, parent_session=None):
        """Create a new session, make it current, and save it.
        
        Args:
            description: Session description
            parent_checkpoint: Checkpoint ID this session branches from
            parent_session: Session ID this session branches from

        Returns:
            Session: The newly created session.
        """
        session = Session()
        session.metadata['description'] = description
        session.metadata['parent_checkpoint'] = parent_checkpoint
        session.metadata['parent_session'] = parent_session
        # Capture project git state
        session.capture_git_state()
        self.current_session = session
        # Set checkpoint manager to use this session
        self.checkpoint_manager.set_session(session.session_id)
        # Store parent commit for first checkpoint
        self.parent_commit_for_next_checkpoint = parent_checkpoint
        self.save_session()
        return session
    
    def save_session(self):
        """Save current session to disk (no-op without a current session).

        The JSON is written to a sibling ``.tmp`` file and moved into place
        with ``os.replace``. An interrupted or failing save therefore leaves
        the previous session file intact instead of a truncated one.
        """
        if not self.current_session:
            return
        
        self.current_session.metadata['last_active'] = time.time()
        session_file = os.path.join(
            self.sessions_dir,
            f"{self.current_session.session_id}.json"
        )
        tmp_file = session_file + ".tmp"
        
        try:
            with open(tmp_file, 'w') as f:
                json.dump(self.current_session.to_dict(), f, indent=2)
            os.replace(tmp_file, session_file)
        except BaseException:
            # Don't leave a half-written temp file behind; re-raise the error.
            try:
                os.remove(tmp_file)
            except OSError:
                pass
            raise
    
    def load_session(self, session_id):
        """Load session from disk and make it current.

        Returns:
            Session or None: None when no file exists for ``session_id``.
        """
        session_file = os.path.join(self.sessions_dir, f"{session_id}.json")
        
        if not os.path.exists(session_file):
            return None
        
        with open(session_file, 'r') as f:
            data = json.load(f)
        
        session = Session.from_dict(data)
        self.current_session = session
        # Set checkpoint manager to use this session
        self.checkpoint_manager.set_session(session.session_id)
        return session
    
    def list_sessions(self):
        """List all sessions, most recently active first.

        Returns:
            list: Dicts with ``session_id``, ``metadata`` and
            ``message_count``. Unreadable or malformed ``.json`` files (bad
            JSON, missing keys, non-dict metadata) are skipped.
        """
        sessions = []
        
        if not os.path.exists(self.sessions_dir):
            return sessions
        
        for filename in os.listdir(self.sessions_dir):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(self.sessions_dir, filename)
            try:
                with open(filepath, 'r') as f:
                    data = json.load(f)
                session_id = data['session_id']
                metadata = data['metadata']
                if not isinstance(metadata, dict):
                    # Would break the last_active sort below
                    continue
                sessions.append({
                    'session_id': session_id,
                    'metadata': metadata,
                    'message_count': len(data.get('messages', [])),
                })
            except (OSError, ValueError, KeyError, TypeError, AttributeError):
                # Skip unreadable, corrupt or foreign .json files
                continue
        
        return sorted(sessions, key=lambda x: x['metadata'].get('last_active', 0), reverse=True)
    
    def load_last_session(self):
        """Load the most recently active session, or return None if there is none"""
        sessions = self.list_sessions()
        if sessions:
            return self.load_session(sessions[0]['session_id'])
        return None




def _format_time_ago(elapsed):
    """Render a number of elapsed seconds as a coarse "<n><unit> ago" label."""
    if elapsed < 60:
        return f"{int(elapsed)}s ago"
    if elapsed < 3600:
        return f"{int(elapsed/60)}m ago"
    if elapsed < 86400:
        return f"{int(elapsed/3600)}h ago"
    return f"{int(elapsed/86400)}d ago"


def handle_turns_command(session_manager, turn_number=None):
    """Handle the /t (/turns) command: list turns, or undo one.

    Without ``turn_number`` every recorded turn is printed with its age,
    checkpoint tag, a truncated prompt and up to three modified files. With
    ``turn_number`` the work is delegated to ``restore_to_turn``.

    Args:
        session_manager: SessionManager instance
        turn_number: Turn to undo; None means just list the turns

    Returns:
        tuple: (success: bool, prefill_input: str or None). Listing always
        yields ``(False, None)`` because nothing was restored.
    """
    session = session_manager.current_session
    if not session:
        print(f"{YELLOW}⚠ No active session{RESET}")
        return False, None

    if not session.turns:
        print(f"{DIM}No conversation turns yet{RESET}")
        return False, None

    if turn_number is not None:
        return restore_to_turn(session_manager, turn_number)

    print(f"\n{BOLD}💬 Conversation Turns:{RESET}\n")

    for entry in session.turns:
        num, stamp, prompt = entry['turn_number'], entry['timestamp'], entry['user_input']
        changed = entry.get('files_modified', [])
        cp = entry.get('checkpoint_id')
        age = _format_time_ago(time.time() - stamp)

        cp_tag = f" {YELLOW}[{cp}]{RESET}" if cp else ""
        changed_tag = f" {GREEN}✓{RESET}" if changed else ""

        print(f"  {CYAN}turn_{num}{RESET}{cp_tag}{changed_tag} {DIM}({age}){RESET}")
        print(f"  {DIM}└─{RESET} {prompt}")

        if changed:
            shown = ", ".join(changed[:3])
            hidden = len(changed) - 3
            if hidden > 0:
                shown += f" +{hidden} more"
            print(f"     {DIM}Files: {shown}{RESET}")

        print()

    print(f"{DIM}Tip: Use '/t <number>' to undo that turn and restart from previous{RESET}")
    return False, None  # Listing only: no restore happened


def _find_turn_user_input(messages, start_idx, end_count):
    """Return the full text the user typed to start a turn, or None.

    ``start_idx`` is where the turn's messages begin (the previous turn's
    ``message_count``, or 0 for the first turn) and ``end_count`` is the
    turn's own ``message_count``. The message at ``start_idx`` is tried
    first because a turn may span more than two messages (e.g. assistant
    tool calls and their results). ``end_count - 2`` (user message plus a
    single assistant reply) is kept as a fallback. Only plain-string user
    messages qualify.

    NOTE(review): assumes ``message_count`` is recorded once the turn's
    messages are complete, matching how it is used to truncate below.
    """
    for idx in (start_idx, end_count - 2):
        if 0 <= idx < len(messages):
            msg = messages[idx]
            if msg.get('role') == 'user' and isinstance(msg.get('content'), str):
                return msg['content']
    return None


def restore_to_turn(session_manager, turn_number):
    """Undo a turn: restore files and conversation to the state before it.

    Undoing turn N restores the state at the end of turn N-1. Files are
    restored from the latest checkpoint at or before turn N-1, or from the
    session baseline if there is none. Undoing turn 1 resets to the session
    start (files to baseline, all messages and turns cleared). The user must
    confirm before anything changes, and the session is saved afterwards.
    
    Args:
        session_manager: SessionManager instance with an active session
        turn_number: 1-indexed turn to undo; must be within 1..len(turns)
    
    Returns:
        tuple: (success: bool, prefill_input: str or None). ``prefill_input``
        is the undone turn's original user input, offered for re-editing.
    """
    session = session_manager.current_session
    turns = session.turns
    
    # Special case: /t 1 means clear all turns
    if turn_number == 1:
        print(f"{YELLOW}⚠ This will undo turn_1 and reset to session start{RESET}")
        confirm = input(f"\n{BOLD}Continue? (y/N): {RESET}").strip().lower()
        if confirm != 'y':
            print(f"{DIM}Cancelled{RESET}")
            return False, None
        
        # Look up the full turn_1 input before the messages are cleared;
        # fall back to the 100-char preview stored on the turn.
        original_input = None
        if turns:
            original_input = (
                _find_turn_user_input(session.messages, 0, turns[0]['message_count'])
                or turns[0]['user_input']
            )
        
        # Restore to baseline
        success = session_manager.restore_baseline()
        if not success:
            return False, None
        
        # Clear all turns and messages
        session.turns = []
        session.messages = []
        session.update_file_states()
        session_manager.parent_commit_for_next_checkpoint = None
        session_manager.save_session()
        
        print(f"{GREEN}✓ Reset to session start{RESET}")
        return True, original_input
    
    # Normal case: undo turn_N by restoring to the end of turn_(N-1).
    # The turn being undone must exist, so the valid range is 1..len(turns).
    if turn_number < 1 or turn_number > len(turns):
        print(f"{RED}Invalid turn number. Valid range: 1-{len(turns)}{RESET}")
        return False, None
    
    restore_to = turn_number - 1
    turn = turns[restore_to - 1]
    turn_to_undo = turns[turn_number - 1]
    
    # Find the most recent checkpoint at or before restore_to
    checkpoint_id = None
    checkpoint_turn = None
    for i in range(restore_to - 1, -1, -1):
        if turns[i].get('checkpoint_id'):
            checkpoint_id = turns[i]['checkpoint_id']
            checkpoint_turn = i + 1
            break
    
    # turn['user_input'] is truncated to 100 chars, so recover the full text
    # from the messages: the undone turn starts where turn_(N-1) ended.
    original_input = _find_turn_user_input(
        session.messages, turn['message_count'], turn_to_undo['message_count']
    )
    
    # Show warning
    future_turns = len(turns) - restore_to
    print(f"\n{YELLOW}⚠ This will undo turn_{turn_number} (restore to turn_{restore_to}){RESET}")
    
    if checkpoint_id:
        if checkpoint_turn == restore_to:
            print(f"{YELLOW}⚠ Files: restored to turn_{restore_to} checkpoint ({checkpoint_id}){RESET}")
        else:
            print(f"{YELLOW}⚠ Files: restored to turn_{checkpoint_turn} checkpoint ({checkpoint_id}){RESET}")
    else:
        print(f"{YELLOW}⚠ Files: restored to baseline (no checkpoints before turn_{restore_to}){RESET}")
    
    print(f"{YELLOW}⚠ Conversation: restored to turn_{restore_to} ({turn['message_count']} messages){RESET}")
    
    if future_turns > 0:
        print(f"{YELLOW}⚠ Future turns ({restore_to + 1}-{len(turns)}) will be discarded{RESET}")
    
    if original_input:
        print(f"\n{CYAN}Original turn_{turn_number} input will be prefilled:{RESET}")
        print(f"  {DIM}{original_input[:100]}{'...' if len(original_input) > 100 else ''}{RESET}")
    
    confirm = input(f"\n{BOLD}Continue? (y/N): {RESET}").strip().lower()
    
    if confirm != 'y':
        print(f"{DIM}Cancelled{RESET}")
        return False, None
    
    # Restore files
    if checkpoint_id:
        success, _ = session_manager.checkpoint_manager.restore_checkpoint(
            checkpoint_id,
            session.baseline_files
        )
        if not success:
            print(f"{RED}✗ Failed to restore files{RESET}")
            return False, None
        print(f"{GREEN}✓ Restored files to checkpoint {checkpoint_id}{RESET}")
    else:
        # No checkpoint before this turn, restore to baseline
        success = session_manager.restore_baseline()
        if not success:
            print(f"{RED}✗ Failed to restore to baseline{RESET}")
            return False, None
        print(f"{GREEN}✓ Restored files to baseline{RESET}")
    
    # Restore conversation by truncating messages (modify session directly)
    session.messages = session.messages[:turn['message_count']]
    
    # Truncate turns list to restore_to
    session.turns = turns[:restore_to]
    
    # Update file states to match restored files
    session.update_file_states()
    
    # Next checkpoint branches from the restored checkpoint (or from scratch)
    session_manager.parent_commit_for_next_checkpoint = checkpoint_id if checkpoint_id else None
    
    # Persist the rewound session
    session_manager.save_session()
    
    print(f"{GREEN}✓ Restored to turn_{restore_to} ({len(session.messages)} messages){RESET}")
    
    return True, original_input  # Return success and prefill input


def _checkpoint_cmd_git(session_manager, *args):
    """Run a git command against the checkpoint bare repository.

    Returns:
        str: The command output, or "" when the output is empty/None or
        starts with "error" (the error convention used by ``_git_command``
        callers in this file).
    """
    manager = session_manager.checkpoint_manager
    output = manager._git_command("--git-dir", manager.bare_repo, *args)
    if not output or output.startswith("error"):
        return ""
    return output


def _checkpoint_cmd_is_hash(token):
    """Return True if *token* looks like a git commit hash (7-40 hex chars).

    Git may lengthen ``%h`` abbreviations beyond 7 chars when a short hash
    would be ambiguous, so anything from the default abbreviation length up
    to a full SHA-1 is accepted.
    """
    return 7 <= len(token) <= 40 and all(c in "0123456789abcdef" for c in token.lower())


def _checkpoint_cmd_baseline(session_manager):
    """Confirm, then restore every file modified this session to its baseline.

    Returns:
        [] on success (caller clears the conversation), None otherwise.
    """
    baseline_files = session_manager.current_session.baseline_files
    if not baseline_files:
        print(f"{YELLOW}⚠ No baseline: no files have been modified in this session{RESET}")
        return None

    print(f"{YELLOW}⚠ This will restore all modified files to their original state{RESET}")
    print(f"{YELLOW}⚠ Files to restore: {len(baseline_files)}{RESET}")

    # Preview at most 10 files so a large session doesn't flood the terminal.
    print(f"\n{DIM}Files:{RESET}")
    for filepath in list(baseline_files.keys())[:10]:
        print(f"  {DIM}• {filepath}{RESET}")
    if len(baseline_files) > 10:
        print(f"  {DIM}... and {len(baseline_files) - 10} more{RESET}")

    confirm = input(f"\n{BOLD}Continue? (y/N): {RESET}").strip().lower()
    if confirm != 'y':
        print(f"{DIM}Cancelled{RESET}")
        return None

    if not session_manager.restore_baseline():
        print(f"{RED}✗ Failed to restore to baseline{RESET}")
        return None

    print(f"{GREEN}✓ Conversation cleared{RESET}")
    return []


def _checkpoint_cmd_graph_line(session_manager, line):
    """Colorize the commit hash in one ``git log --graph`` line and tag it
    with (up to two) session branches that contain the commit."""
    match = re.search(r'\b([0-9a-f]{7,40})\b', line)
    if not match:
        # Pure graph connector lines ("|\", "|/") carry no commit.
        return line

    commit_hash = match.group(1)
    branch_info = _checkpoint_cmd_git(session_manager, "branch", "-a", "--contains", commit_hash)

    session_names = []
    for branch_line in branch_info.split('\n'):
        # git branch marks the current branch with "*" and branches checked
        # out in linked worktrees with "+"; drop those markers.
        branch_line = branch_line.strip().lstrip('*+ ')
        if branch_line.startswith('session_'):
            # session_20260130_103323_f7 -> s:20260130_103323_f7
            session_names.append('s:' + branch_line[len('session_'):])

    decorated = f"{CYAN}{commit_hash}{RESET}"
    if session_names:
        decorated += f" {GREEN}[{', '.join(session_names[:2])}]{RESET}"

    # Splice at the matched position only, so a hash repeated in the subject
    # is left untouched.
    return line[:match.start(1)] + decorated + line[match.end(1):]


def _checkpoint_cmd_graph(session_manager):
    """Print the last 20 commits of all checkpoint branches as a git graph."""
    print(f"\n{BOLD}📍 Checkpoint Graph:{RESET}\n")

    # %h = short hash, %s = subject, %ar = relative date
    graph_output = _checkpoint_cmd_git(
        session_manager,
        "log", "--graph", "--all", "--oneline",
        "--format=%h %s (%ar)", "-20"
    )

    if graph_output:
        for line in graph_output.split('\n'):
            if line.strip():
                print(f"  {_checkpoint_cmd_graph_line(session_manager, line)}")
        print()
    else:
        print(f"{DIM}No checkpoints yet{RESET}\n")

    print(f"{DIM}Restore: /c <hash>{RESET}")
    return None


def _checkpoint_cmd_list(session_manager):
    """Print up to 10 of the current session's checkpoints with age and files."""
    checkpoints = session_manager.checkpoint_manager.list_checkpoints(show_all=False)
    if not checkpoints:
        print(f"{DIM}No checkpoints yet{RESET}")
        return None

    print(f"\n{BOLD}📍 Checkpoints:{RESET}\n")

    # list_checkpoints already yields newest first (git log order).
    for commit_hash, message in checkpoints[:10]:
        timestamp_str = _checkpoint_cmd_git(
            session_manager, "log", "-1", "--format=%ar", commit_hash
        ).strip()

        # --root makes diff-tree list files for a parentless (first) commit,
        # which it would otherwise report as empty.
        files_str = _checkpoint_cmd_git(
            session_manager,
            "diff-tree", "--root", "--no-commit-id", "--name-only", "-r", commit_hash
        )
        files = [f.strip() for f in files_str.split('\n') if f.strip()]

        # Format: hash | time ago, then message, then up to 3 files.
        time_part = f"{DIM}{timestamp_str}{RESET}" if timestamp_str else ""
        print(f"  {CYAN}{commit_hash}{RESET} {time_part}")
        print(f"  {DIM}└─{RESET} {message}")
        if files:
            files_display = ", ".join(files[:3])
            if len(files) > 3:
                files_display += f" +{len(files)-3} more"
            print(f"  {DIM}   Files: {files_display}{RESET}")
        print()

    print(f"{DIM}Tip: Use '/c all' or '/ca' to see git graph{RESET}")
    print(f"{DIM}Restore: /c <hash>{RESET}")
    return None


def _checkpoint_cmd_restore(session_manager, checkpoint_id):
    """Confirm, then restore files (and the saved conversation) to a checkpoint.

    Returns:
        The checkpoint's conversation snapshot if one was restored, else None.
    """
    print(f"{YELLOW}⚠ This will restore files AND conversation to checkpoint {checkpoint_id}{RESET}")
    print(f"{YELLOW}⚠ Future checkpoints will be discarded from history{RESET}")
    confirm = input(f"{BOLD}Continue? (y/N): {RESET}").strip().lower()

    if confirm != 'y':
        print(f"{DIM}Cancelled{RESET}")
        return None

    success, conversation_snapshot = session_manager.checkpoint_manager.restore_checkpoint(
        checkpoint_id,
        session_manager.current_session.baseline_files
    )
    if not success:
        print(f"{RED}✗ Failed to restore checkpoint{RESET}")
        return None

    print(f"{GREEN}✓ Restored files to checkpoint {checkpoint_id}{RESET}")
    if not conversation_snapshot:
        print(f"{YELLOW}⚠ No conversation snapshot found for this checkpoint{RESET}")
        return None

    print(f"{GREEN}✓ Restored conversation ({len(conversation_snapshot)} messages){RESET}")
    return conversation_snapshot


def handle_checkpoint_command(parts, session_manager, files_modified):
    """Handle /checkpoint or /c commands.

    Subcommands:
        (none) | list          list this session's checkpoints
        all | --all            show a git graph of all checkpoint branches
        baseline | base        restore all modified files to session start
        restore <hash> | <hash>  restore files and conversation to a checkpoint

    Args:
        parts: The split command line, e.g. ["/c", "restore", "abc1234"].
            Not modified.
        session_manager: Owner of the current session and checkpoint manager.
        files_modified: Accepted for interface compatibility; not used here.

    Returns:
        messages: New messages list if conversation was restored ([] after a
        baseline restore, meaning "clear the conversation"), None otherwise.
    """
    # Default to "list" without mutating the caller's list.
    cmd = parts[1] if len(parts) > 1 else "list"

    checkpoint_id = None
    if _checkpoint_cmd_is_hash(cmd):
        # "/c <hash>" is shorthand for "/c restore <hash>".
        checkpoint_id = cmd
        cmd = "restore"
    elif cmd == "restore" and len(parts) > 2:
        checkpoint_id = parts[2]

    if cmd in ("baseline", "base"):
        return _checkpoint_cmd_baseline(session_manager)

    if cmd in ("list", "all", "--all"):
        if cmd == "all" or "--all" in parts:
            return _checkpoint_cmd_graph(session_manager)
        return _checkpoint_cmd_list(session_manager)

    if cmd == "restore":
        if not checkpoint_id:
            print(f"{RED}Usage: /c <checkpoint_id>{RESET}")
            return None
        return _checkpoint_cmd_restore(session_manager, checkpoint_id)

    print(f"{RED}Unknown command: {cmd}{RESET}")
    return None


if __name__ == "__main__":
    # Run main() only when this file is executed as a script, not on import.
    main()
