#!/usr/bin/env python3
"""nanocode - minimal claude code alternative"""
import glob as globlib
import hashlib
import json
import os
import random
import re
import readline
import select
import ssl
import subprocess
import sys
import termios
import time
import tty
import urllib.request
import urllib.parse
from datetime import datetime

# Backend selection: a local proxy key wins, then OpenRouter, then the
# Anthropic API directly (which authenticates with ANTHROPIC_API_KEY).
OPENROUTER_KEY = os.environ.get("OPENROUTER_API_KEY")
LOCAL_API_KEY = os.environ.get("LOCAL_API_KEY")
if LOCAL_API_KEY:
    API_URL = "http://127.0.0.1:8990/v1/messages"
    _DEFAULT_MODEL = "anthropic/claude-sonnet-4.5"
elif OPENROUTER_KEY:
    API_URL = "https://openrouter.ai/api/v1/messages"
    _DEFAULT_MODEL = "anthropic/claude-opus-4.5"
else:
    API_URL = "https://api.anthropic.com/v1/messages"
    _DEFAULT_MODEL = "claude-opus-4-5"
# The MODEL environment variable overrides the backend-specific default.
MODEL = os.environ.get("MODEL", _DEFAULT_MODEL)

# ANSI escape sequences used for terminal styling
RESET = "\033[0m"
BOLD = "\033[1m"
DIM = "\033[2m"
BLUE = "\033[34m"
CYAN = "\033[36m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
RED = "\033[31m"

# Set to True when the user presses ESC during streaming or a bash command.
# NOTE(review): presumably reset by the main loop before each turn — confirm.
stop_flag = False

def create_opener():
    """Create a urllib opener with TLS verification and optional proxy support.

    TLS certificates are verified by default. Set ``NANOCODE_INSECURE_SSL=1``
    (or "true"/"yes") to disable verification, e.g. behind an intercepting
    corporate proxy. This exposes API keys to anyone who can intercept the
    traffic, so it is opt-in rather than the default.

    If ``http_proxy`` (or, failing that, ``https_proxy``) is set, both HTTP and
    HTTPS requests are routed through it.

    Returns:
        urllib.request.OpenerDirector
    """
    proxy = os.environ.get("http_proxy") or os.environ.get("https_proxy")
    ssl_ctx = ssl.create_default_context()
    if os.environ.get("NANOCODE_INSECURE_SSL", "").lower() in ("1", "true", "yes"):
        ssl_ctx.check_hostname = False  # must be cleared before CERT_NONE is allowed
        ssl_ctx.verify_mode = ssl.CERT_NONE

    handlers = [urllib.request.HTTPSHandler(context=ssl_ctx)]
    if proxy:
        handlers.insert(0, urllib.request.ProxyHandler({"http": proxy, "https": proxy}))
    return urllib.request.build_opener(*handlers)

def register_tool(name, desc, params):
    """Return a decorator that records a function in the global TOOLS registry.

    The entry is stored as ``TOOLS[name] = (desc, params, func)``; the decorated
    function itself is returned unchanged. Extension code executed by ``load``
    receives this function in its globals.
    """
    def _register(handler):
        TOOLS[name] = (desc, params, handler)
        return handler

    return _register

def search_extension(args):
    """Search extensions from gist.kitchain.cn"""
    query = args.get("query", "")
    if not query: return "error: query required"
    try:
        # Split query into keywords
        keywords = query.lower().split()
        gist_info = {}  # {gist_path: {"hits": count, "title": str, "desc": str, "topics": []}}
        opener = create_opener()

        # Search each keyword as a topic
        for keyword in keywords:
            url = f"https://gist.kitchain.cn/topics/{urllib.parse.quote(keyword)}"
            html = opener.open(urllib.request.Request(url), timeout=10).read().decode()

            # Extract gist URLs and titles
            gist_matches = re.findall(
                r'<a class="font-bold" href="https://gist\.kitchain\.cn/([^/]+/[a-f0-9]+)">([^<]+)</a>',
                html
            )

            for gist_path, title in gist_matches:
                if gist_path not in gist_info:
                    # Extract description and topics for this gist
                    gist_section = re.search(
                        rf'{re.escape(gist_path)}.*?'
                        r'<h6 class="text-xs[^"]*">([^<]+)</h6>(.*?)</div>\s*</div>',
                        html, re.DOTALL
                    )
                    desc = ""
                    topics = []
                    if gist_section:
                        desc = gist_section.group(1).strip()
                        topics_section = gist_section.group(2)
                        topics = re.findall(r'topics/([^"]+)"[^>]*>([^<]+)<', topics_section)
                        topics = [t[1] for t in topics]  # Extract topic names

                    gist_info[gist_path] = {
                        "hits": 0,
                        "title": title.strip(),
                        "desc": desc,
                        "topics": topics,
                        "filename": title.strip()
                    }
                gist_info[gist_path]["hits"] += 1

        if not gist_info: return f"No extensions found: {query}"

        # Sort by hit count (descending)
        sorted_gists = sorted(gist_info.items(), key=lambda x: x[1]["hits"], reverse=True)[:10]

        result = f"Found {len(sorted_gists)} extensions:\n\n"
        for gist_path, info in sorted_gists:
            result += f"• {info['title']}\n"
            if info['desc']:
                result += f"  {info['desc']}\n"
            if info['topics']:
                result += f"  Topics: {', '.join(info['topics'])}\n"
            result += f"  Matched: {info['hits']} keyword(s)\n\n"

        # Return first gist's load URL
        first_gist = sorted_gists[0][0]
        first_filename = sorted_gists[0][1]['filename']
        result += f"To load the top result:\nload({{\"url\": \"https://gist.kitchain.cn/{first_gist}/raw/HEAD/{first_filename}\"}})"
        return result
    except Exception as e:
        return f"error: {e}"

def load(args):
    """Download Python source from ``args["url"]`` and exec it to add tools.

    The code runs with ``register_tool``, ``TOOLS``, ``urllib``, ``json``,
    ``re`` and ``subprocess`` in its globals.

    SECURITY: this executes downloaded code with the user's full privileges,
    with no sandbox or integrity check. Only load URLs you trust.

    Returns:
        "Loaded. New tools: ..." listing the tool names this call added, or an
        "error: ..." string.
    """
    url = args.get("url")
    if not url: return "error: url required"
    try:
        opener = create_opener()
        code = opener.open(urllib.request.Request(url), timeout=10).read().decode()
        # Snapshot the registry so only tools added by *this* extension are reported.
        before = set(TOOLS)
        exec(code, {"register_tool": register_tool, "TOOLS": TOOLS, "urllib": urllib, "json": json, "re": re, "subprocess": subprocess})
        new = [k for k in TOOLS if k not in before]
        return f"Loaded. New tools: {', '.join(new)}"
    except Exception as e:
        return f"error: {e}"

# --- Tools ---
def read(args):
    """Return a file's lines prefixed with 1-based, width-4 line numbers.

    Args:
        args: dict with ``"path"``, plus optional ``"offset"`` (0-based index
            of the first line, default 0) and ``"limit"`` (maximum number of
            lines, default all). JSON may deliver these as floats or null, so
            they are coerced to int. Negative offsets are clamped to 0.

    Returns:
        The selected lines as ``"   N| text"`` rows (original line endings kept).
    """
    with open(args["path"]) as f:
        lines = f.readlines()
    offset = max(0, int(args.get("offset") or 0))
    limit = args.get("limit")
    limit = len(lines) if limit is None else int(limit)
    return "".join(f"{offset+i+1:4}| {l}" for i, l in enumerate(lines[offset:offset+limit]))

def write(args):
    """Write ``args["content"]`` to ``args["path"]``, replacing any existing file.

    Missing parent directories are created first. Returns "ok". I/O errors
    propagate to the caller.
    """
    filepath = args["path"]
    content = args["content"]
    print(f"{DIM}[LOG] write: {filepath} ({len(content)} bytes){RESET}", flush=True)
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # `with` guarantees the data is flushed and the handle closed before returning.
    with open(filepath, "w") as f:
        f.write(content)
    print(f"{DIM}[LOG] write completed: {filepath}{RESET}", flush=True)
    return "ok"

def edit(args):
    """Replace ``args["old"]`` with ``args["new"]`` in the file at ``args["path"]``.

    Unless ``args["all"]`` is truthy, ``old`` must occur exactly once.

    Returns:
        "ok" on success, otherwise an "error: ..." string. The file is only
        rewritten on success.
    """
    filepath = args["path"]
    print(f"{DIM}[LOG] edit: {filepath}{RESET}", flush=True)
    with open(filepath) as f:
        text = f.read()
    print(f"{DIM}[LOG] edit read: {len(text)} bytes{RESET}", flush=True)
    old, new = args["old"], args["new"]
    # An empty needle "matches" between every character; reject it explicitly.
    if not old: return "error: old_string must not be empty"
    count = text.count(old)
    if count == 0: return "error: old_string not found"
    replace_all = bool(args.get("all"))
    if not replace_all and count > 1:
        return f"error: old_string appears {count} times (use all=true)"
    result = text.replace(old, new) if replace_all else text.replace(old, new, 1)
    print(f"{DIM}[LOG] edit writing: {len(result)} bytes{RESET}", flush=True)
    with open(filepath, "w") as f:
        f.write(result)
    print(f"{DIM}[LOG] edit completed: {filepath}{RESET}", flush=True)
    return "ok"

def glob(args):
    """List paths matching ``args["pat"]`` under ``args["path"]`` (default ".").

    ``**`` in the pattern matches across directories. Results are ordered by
    modification time, newest first. Directories are ranked with mtime 0, so
    they come after files. Returns newline-joined paths, or "none".
    """
    pattern = (args.get("path", ".") + "/" + args["pat"]).replace("//", "/")

    def _mtime(candidate):
        if os.path.isfile(candidate):
            return os.path.getmtime(candidate)
        return 0

    matches = globlib.glob(pattern, recursive=True)
    matches.sort(key=_mtime, reverse=True)
    return "\n".join(matches) or "none"

def grep(args):
    """Search regular files under ``args["path"]`` (default ".") for regex ``args["pat"]``.

    Returns:
        Up to 50 ``"path:lineno:line"`` hits joined by newlines, or "none".
        Files that cannot be opened or decoded as text are skipped. Scanning
        stops as soon as 50 hits have been collected.
    """
    max_hits = 50
    pattern, hits = re.compile(args["pat"]), []
    for fp in globlib.glob(args.get("path", ".") + "/**", recursive=True):
        # Skip directories, FIFOs, sockets, etc. Opening a FIFO would block forever.
        if not os.path.isfile(fp):
            continue
        try:
            with open(fp) as f:
                for n, line in enumerate(f, 1):
                    if pattern.search(line):
                        hits.append(f"{fp}:{n}:{line.rstrip()}")
                        if len(hits) >= max_hits:
                            return "\n".join(hits)
        except (OSError, UnicodeDecodeError):
            # Best effort: binary, unreadable, or vanished files are ignored.
            continue
    return "\n".join(hits) or "none"

def bash(args):
    """Run ``args["cmd"]`` through the shell, echoing output live; ESC aborts it.

    stdout and stderr are merged into one stream. While the command runs, the
    terminal is in cbreak mode so a single ESC keypress can be detected. On
    ESC the process is killed, the module-level ``stop_flag`` is set, and
    "(stopped)" is appended to the output.

    Returns:
        The collected output, stripped, or "(empty)" if there was none.
    """
    global stop_flag
    proc = subprocess.Popen(args["cmd"], shell=True, stdout=subprocess.PIPE, 
                           stderr=subprocess.STDOUT, text=True)
    lines = []
    # tcgetattr requires stdin to be a terminal; it raises otherwise.
    old_settings = termios.tcgetattr(sys.stdin)
    try:
        tty.setcbreak(sys.stdin.fileno())
        if proc.stdout:
            import fcntl
            fd = proc.stdout.fileno()
            # Make the pipe non-blocking so reads cannot stall the ESC polling loop.
            fcntl.fcntl(fd, fcntl.F_SETFL, fcntl.fcntl(fd, fcntl.F_GETFL) | os.O_NONBLOCK)
            
            while True:
                # Check ESC key (zero-timeout poll; any other keypress is consumed and ignored)
                if select.select([sys.stdin], [], [], 0)[0]:
                    if sys.stdin.read(1) == '\x1b':
                        stop_flag = True
                        proc.kill()
                        lines.append("\n(stopped)")
                        print(f"\n{YELLOW}⏸ Stopped{RESET}")
                        break
                
                # Read output: wait up to 0.1s for the pipe to become readable
                if select.select([proc.stdout], [], [], 0.1)[0]:
                    line = proc.stdout.readline()
                    if line:
                        print(f"  {DIM}│ {line.rstrip()}{RESET}", flush=True)
                        lines.append(line)
                
                # Check if done; drain whatever is still buffered in the pipe
                if proc.poll() is not None:
                    remaining = proc.stdout.read()
                    if remaining:
                        for line in remaining.split('\n'):
                            if line:
                                print(f"  {DIM}│ {line.rstrip()}{RESET}", flush=True)
                                lines.append(line + '\n')
                    break
        
        # NOTE(review): the loop above only exits once the process has ended or
        # been killed, and has no time limit itself. So this 30s timeout cannot
        # stop a hung command, and the "(timeout)" branch is effectively
        # unreachable. This function never clears stop_flag. If it is still
        # True from an earlier interrupt (presumably reset by the caller —
        # confirm), this wait() is skipped. After ESC the killed process is
        # not wait()ed here, so it may linger until reaped.
        if not stop_flag:
            proc.wait(timeout=30)
    except subprocess.TimeoutExpired:
        proc.kill()
        lines.append("\n(timeout)")
    finally:
        # Always restore the terminal mode saved above.
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
    
    return "".join(lines).strip() or "(empty)"

def web_search(args):
    """Search web using DuckDuckGo"""
    query, max_results = args["query"], args.get("max_results", 5)
    try:
        url = f"https://html.duckduckgo.com/html/?q={urllib.parse.quote_plus(query)}"
        headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"}
        opener = create_opener()
        html = opener.open(urllib.request.Request(url, headers=headers), timeout=30).read().decode()
        
        # Extract titles and URLs
        links = re.findall(r'class="result__a"[^>]+href="([^"]+)"[^>]*>([^<]+)<', html)
        # Extract snippets
        snippets = re.findall(r'class="result__snippet"[^>]*>([^<]*)<', html)
        if not links: return "No results found"
        
        results = []
        for i, ((link, title), snippet) in enumerate(zip(links[:max_results], snippets[:max_results] + [""] * max_results), 1):
            results.append(f"{i}. {title.strip()}\n   URL: {link}\n   {snippet.strip()}\n")
        return "\n".join(results)
    except Exception as e:
        return f"error: {e}"


# Built-in tool registry: name -> (description, parameter spec, handler).
# A parameter spec maps each argument name to a type name. A trailing "?"
# marks the argument optional, and "number" is advertised as "integer" (see
# make_schema). Extensions add entries at runtime via register_tool.
# Insertion order is the order tools appear in the API schema.
TOOLS = {
    "read": ("Read file with line numbers", {"path": "string", "offset": "number?", "limit": "number?"}, read),
    "write": ("Write content to file", {"path": "string", "content": "string"}, write),
    "edit": ("Replace old with new in file", {"path": "string", "old": "string", "new": "string", "all": "boolean?"}, edit),
    "glob": ("Find files by pattern", {"pat": "string", "path": "string?"}, glob),
    "grep": ("Search files for regex", {"pat": "string", "path": "string?"}, grep),
    "bash": ("Run shell command", {"cmd": "string"}, bash),
    "web_search": ("Search the web using DuckDuckGo", {"query": "string", "max_results": "number?"}, web_search),
    "search_extension": ("Search for extensions to add new capabilities (GitHub docs, web scraping, APIs, etc)", {"query": "string"}, search_extension),
    "load": ("Load extension from URL to add new tools", {"url": "string"}, load),
}

def run_tool(name, args):
    """Invoke the registered tool ``name`` with ``args``.

    Any exception, including an unknown tool name (KeyError), is turned into
    an ``"error: <message>"`` string instead of propagating.
    """
    try:
        handler = TOOLS[name][2]
        return handler(args)
    except Exception as exc:
        return f"error: {exc}"

def make_schema():
    """Build the Messages API ``tools`` list from the TOOLS registry.

    Every parameter spec is a type name. Trailing "?" characters mark the
    parameter optional and are stripped. "number" is advertised as JSON
    Schema "integer".
    """
    def _json_type(spec):
        base = spec.rstrip("?")
        return "integer" if base == "number" else base

    schema = []
    for tool_name, (description, params, _handler) in TOOLS.items():
        properties = {pname: {"type": _json_type(spec)} for pname, spec in params.items()}
        required = [pname for pname, spec in params.items() if not spec.endswith("?")]
        schema.append({
            "name": tool_name,
            "description": description,
            "input_schema": {"type": "object", "properties": properties, "required": required},
        })
    return schema

def call_api(messages, system_prompt, stream=True, enable_thinking=True, use_tools=True):
    headers = {"Content-Type": "application/json", "anthropic-version": "2023-06-01"}
    if LOCAL_API_KEY: headers["Authorization"] = f"Bearer {LOCAL_API_KEY}"
    elif OPENROUTER_KEY: headers["Authorization"] = f"Bearer {OPENROUTER_KEY}"
    else: headers["x-api-key"] = os.environ.get("ANTHROPIC_API_KEY", "")
    
    data = {"model": MODEL, "max_tokens": 8192, "system": system_prompt,
            "messages": messages, "stream": stream}
    
    if use_tools:
        data["tools"] = make_schema()

    if enable_thinking and os.environ.get("THINKING"):
        data["thinking"] = {"type": "enabled", "budget_tokens": int(os.environ.get("THINKING_BUDGET", "10000"))}

    req = urllib.request.Request(API_URL, json.dumps(data).encode(), headers, method="POST")
    return create_opener().open(req)

def summarize_changes(user_input, files_modified, checkpoint_manager, checkpoint_id):
    """Use LLM to summarize the changes made in this turn
    
    Args:
        user_input: User's request
        files_modified: Set of modified file paths
        checkpoint_manager: CheckpointManager instance
        checkpoint_id: Checkpoint hash to get diff from
        
    Returns:
        str: One-line summary of changes
    """
    if not files_modified or not checkpoint_id:
        return user_input[:50]
    
    try:
        # Get diff from git
        diff_output = checkpoint_manager._git_command(
            "--git-dir", checkpoint_manager.bare_repo,
            "show", "--format=", checkpoint_id
        )
        
        # Check if diff is empty or error - no actual changes
        if not diff_output or diff_output.startswith("error") or len(diff_output.strip()) == 0:
            # No diff available, just use user input
            return user_input[:50]
        
        # Limit diff size to avoid token overflow (max ~3000 chars)
        if len(diff_output) > 3000:
            diff_output = diff_output[:3000] + "\n... (truncated)"
        
        summary_prompt = f"""Based on the actual code changes (diff), generate a brief Chinese summary (max 30 Chinese characters).

IMPORTANT: Must be based on the actual code changes, not the user's description.

Code changes (diff):
{diff_output}

User description (for reference only): {user_input}

Requirements:
1. Describe what code/functionality was actually modified
2. Reply in Chinese only, no explanation
3. No quotes
4. Max 30 Chinese characters

Good examples:
- 在 auth.py 添加 JWT 验证
- 修复 parser.py 空指针异常
- 重构 database.py 连接池
- 更新 README 添加安装说明
"""
        
        messages = [{"role": "user", "content": summary_prompt}]
        response = call_api(messages, "You are a code change analyzer, skilled at extracting key information from diffs. Reply in Chinese.", 
                           stream=False, enable_thinking=False, use_tools=False)
        
        # Parse non-streaming response
        data = json.loads(response.read().decode())
        blocks = data.get("content", [])
        
        for block in blocks:
            if block.get("type") == "text":
                summary = block.get("text", "").strip()
                
                # Remove thinking tags if present
                if "<thinking>" in summary:
                    # Extract content after </thinking>
                    parts = summary.split("</thinking>")
                    if len(parts) > 1:
                        summary = parts[-1].strip()
                
                # Clean up and limit length
                summary = summary.replace('"', '').replace("'", "")
                if summary and len(summary) <= 80:
                    return summary
        
        # Fallback to user input
        return user_input[:50]
    except Exception as e:
        # On error, fallback to user input
        return user_input[:50]

def process_stream(response):
    """Consume a streaming Messages API response, echoing it live; ESC stops early.

    Handles server-sent-event ``data:`` lines. Text and thinking deltas are
    printed as they arrive. tool_use input JSON is accumulated and parsed
    when its block ends (invalid JSON yields ``{}``). ESC is only checked
    between stream lines, so a stalled stream does not notice it until the
    next line arrives.

    Args:
        response: iterable of raw byte lines (e.g. the HTTP response from call_api).

    Returns:
        List of completed content blocks. Each is a dict with "type" and "id",
        plus "text" (text blocks) or "name"/"input" (tool_use blocks). A block
        still in progress when ESC is pressed is not included.
        NOTE(review): thinking blocks are appended without their thinking text
        or signature; confirm callers never send them back to the API as-is.
    """
    global stop_flag
    blocks, current, text_buf, json_buf, think_buf = [], None, "", "", ""
    
    # Save terminal settings so `finally` can restore them
    old_settings = termios.tcgetattr(sys.stdin)
    try:
        # cbreak mode: single keypresses (ESC) are readable without Enter
        tty.setcbreak(sys.stdin.fileno())
        
        for line in response:
            # Zero-timeout poll for a keypress between stream lines
            if select.select([sys.stdin], [], [], 0)[0]:
                ch = sys.stdin.read(1)
                if ch == '\x1b':  # ESC key
                    stop_flag = True
                    print(f"\n{YELLOW}⏸ Stopped{RESET}")
                    break
            
            # Only "data: " lines carry events; skip blanks, "event:" lines and the terminator
            line = line.decode("utf-8").strip()
            if not line.startswith("data: "): continue
            if line == "data: [DONE]": continue
            
            try:
                data = json.loads(line[6:])
                etype = data.get("type")
                
                if etype == "content_block_start":
                    block = data.get("content_block", {})
                    current = {"type": block.get("type"), "id": block.get("id")}
                    if current["type"] == "text":
                        text_buf = ""
                        print(f"\n{CYAN}⏺{RESET} ", end="", flush=True)
                    elif current["type"] == "thinking":
                        think_buf = ""
                        print(f"\n{YELLOW}💭{RESET} {DIM}", end="", flush=True)
                    elif current["type"] == "tool_use":
                        current["name"] = block.get("name")
                        json_buf = ""
                        
                elif etype == "content_block_delta":
                    delta = data.get("delta", {})
                    dtype = delta.get("type")
                    if dtype == "text_delta":
                        text = delta.get("text", "")
                        text_buf += text
                        print(text, end="", flush=True)
                    elif dtype == "thinking_delta":
                        # think_buf is accumulated but never stored on the block
                        text = delta.get("thinking", "")
                        think_buf += text
                        print(text, end="", flush=True)
                    elif dtype == "input_json_delta" and current:
                        json_buf += delta.get("partial_json", "")
                        
                elif etype == "content_block_stop" and current:
                    if current["type"] == "text":
                        current["text"] = text_buf
                        print()
                    elif current["type"] == "thinking":
                        print(RESET)
                    elif current["type"] == "tool_use":
                        # Tool input arrives as JSON fragments; parse once complete
                        try: current["input"] = json.loads(json_buf)
                        except: current["input"] = {}
                    blocks.append(current)
                    current = None
            # NOTE(review): bare except silently drops malformed events and any handling bug
            except: pass
    finally:
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
    
    return blocks

def is_file_in_project(filepath, project_path):
    """Return True if ``filepath`` is ``project_path`` itself or lies beneath it.

    Both paths are made absolute (relative to the current directory) and
    compared component-wise. "/a/bc" is therefore not inside "/a/b", while
    everything is inside "/". Symlinks are not resolved. Invalid inputs
    (e.g. None) return False.
    """
    try:
        abs_file = os.path.abspath(filepath)
        abs_project = os.path.abspath(project_path)
        # commonpath compares whole components, so it also handles a root project
        # ("/"), which a "startswith(project + os.sep)" check gets wrong.
        return os.path.commonpath([abs_file, abs_project]) == abs_project
    except (TypeError, ValueError, OSError):
        # TypeError: non-path input; ValueError: e.g. different drives on Windows;
        # OSError: the current directory no longer exists.
        return False

def read_multiline_input():
    """Read multiline input. Enter to submit, Alt+Enter for newline.

    Runs the terminal in cbreak mode with bracketed paste enabled. Supported
    keys:
      - Enter: submit.
      - Alt+Enter (ESC followed by CR/LF): start a new line.
      - Left/Right arrows: move the cursor within the current line.
      - Backspace: delete before the cursor; at column 0 it merges the line
        into the previous one.
      - Ctrl+C: discard all input.
      - Ctrl+D: raise EOFError.
      - Pasted text (between ESC[200~ and ESC[201~): inserted, with newlines
        splitting it into lines.

    Returns:
        The entered lines joined with newlines and stripped.

    Raises:
        EOFError: when Ctrl+D is pressed.
    """
    lines = []       # completed lines (above the one being edited)
    current = ""     # the line being edited
    cursor_pos = 0  # Cursor position in current line
    
    # Enable bracketed paste mode
    print("\033[?2004h", end="", flush=True)
    
    old_settings = termios.tcgetattr(sys.stdin)
    try:
        tty.setcbreak(sys.stdin.fileno())
        print(f"{BOLD}{BLUE}❯{RESET} ", end="", flush=True)
        
        while True:
            ch = sys.stdin.read(1)
            
            if ch == '\x03':  # Ctrl+C - clear input (only the current screen line is erased)
                lines.clear()
                current = ""
                cursor_pos = 0
                print("\r\033[K", end="", flush=True)
                print(f"{BOLD}{BLUE}❯{RESET} ", end="", flush=True)
                continue
            
            if ch == '\x04':  # Ctrl+D
                raise EOFError
            
            if ch == '\x1b':  # Escape sequence
                # NOTE(review): a lone ESC blocks here until another key arrives.
                next_ch = sys.stdin.read(1)
                if next_ch in ('\r', '\n'):  # Alt+Enter
                    lines.append(current)
                    current = ""
                    cursor_pos = 0
                    print(f"\n{BOLD}{BLUE}│{RESET} ", end="", flush=True)
                elif next_ch == '[':  # Escape sequence
                    # NOTE(review): only C, D and "200~" are handled. Other CSI keys
                    # (e.g. Up "A", or Delete "3~") are dropped, but their trailing
                    # bytes (e.g. "~") are left to be read as ordinary input.
                    seq = sys.stdin.read(1)
                    if seq == 'C':  # Right arrow
                        if cursor_pos < len(current):
                            cursor_pos += 1
                            print("\033[C", end="", flush=True)
                    elif seq == 'D':  # Left arrow
                        if cursor_pos > 0:
                            cursor_pos -= 1
                            print("\033[D", end="", flush=True)
                    elif seq == '2':  # Bracketed paste start: ESC[200~
                        rest = sys.stdin.read(3)  # Read "00~"
                        if rest == '00~':
                            # Read pasted content until ESC[201~
                            paste_buf = ""
                            while True:
                                c = sys.stdin.read(1)
                                if c == '\x1b':
                                    # Check for [201~
                                    peek = sys.stdin.read(5)
                                    if peek == '[201~':
                                        break
                                    else:
                                        paste_buf += c + peek
                                else:
                                    paste_buf += c
                            
                            # Process pasted content ('\r' is not treated as a line break)
                            paste_lines = paste_buf.split('\n')
                            
                            if len(paste_lines) == 1:
                                # Single line paste: insert at the cursor and redraw the line
                                current = current[:cursor_pos] + paste_lines[0] + current[cursor_pos:]
                                cursor_pos += len(paste_lines[0])
                                prefix = f"{BOLD}{BLUE}{'│' if lines else '❯'}{RESET} "
                                print(f"\r\033[K{prefix}{current}", end="", flush=True)
                            else:
                                # Multi-line paste
                                # First line appends to current
                                # NOTE(review): text after the cursor (current[cursor_pos:]) is
                                # discarded here, and an empty first line is not kept.
                                first_line = current[:cursor_pos] + paste_lines[0]
                                print(paste_lines[0], end="", flush=True)
                                if first_line:
                                    lines.append(first_line)
                                
                                # Middle lines
                                for line in paste_lines[1:-1]:
                                    print(f"\n{BOLD}{BLUE}│{RESET} {line}", end="", flush=True)
                                    lines.append(line)
                                
                                # Last line becomes new current
                                current = paste_lines[-1]
                                cursor_pos = len(current)
                                print(f"\n{BOLD}{BLUE}│{RESET} {current}", end="", flush=True)
                continue
            
            if ch in ('\r', '\n'):  # Enter - submit
                if current:
                    lines.append(current)
                print()
                break
            
            if ch in ('\x7f', '\x08'):  # Backspace
                if cursor_pos > 0:
                    # Delete character before cursor
                    current = current[:cursor_pos-1] + current[cursor_pos:]
                    cursor_pos -= 1
                    # Redraw current line
                    prefix = f"{BOLD}{BLUE}{'│' if lines else '❯'}{RESET} "
                    print(f"\r\033[K{prefix}{current}", end="", flush=True)
                    # Move cursor back to position
                    if cursor_pos < len(current):
                        print(f"\033[{len(current) - cursor_pos}D", end="", flush=True)
                elif lines:
                    # Merge with previous line
                    prev_line = lines.pop()
                    cursor_pos = len(prev_line)  # Cursor at end of previous line
                    current = prev_line + current
                    # Move up and redraw
                    print("\033[A\033[K", end="", flush=True)
                    prefix = f"{BOLD}{BLUE}{'│' if lines else '❯'}{RESET} "
                    print(f"\r{prefix}{current}", end="", flush=True)
                    if cursor_pos < len(current):
                        print(f"\033[{len(current) - cursor_pos}D", end="", flush=True)
                continue
            
            if ch.isprintable() or ch == '\t':
                # Insert character at cursor position
                current = current[:cursor_pos] + ch + current[cursor_pos:]
                cursor_pos += 1
                # Redraw from cursor position
                print(f"{ch}{current[cursor_pos:]}", end="", flush=True)
                # Move cursor back if needed
                if cursor_pos < len(current):
                    print(f"\033[{len(current) - cursor_pos}D", end="", flush=True)
        
    finally:
        # Disable bracketed paste mode
        print("\033[?2004l", end="", flush=True)
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
    
    return "\n".join(lines).strip()

def main():
    """Program entry point: print the banner, optionally pick a session, run the REPL.

    Flags (looked up anywhere in ``sys.argv``):
      - ``-c``/``--continue``: resume the previous session.
      - ``-l``/``--list``: open the interactive session picker.

    Terminal signal generation (ISIG) is disabled for the whole run, so Ctrl+C
    is read as an ordinary character (the input editor uses it to clear the
    input). The original terminal settings are restored on exit.
    """
    # Parse command line arguments
    continue_session = "-c" in sys.argv or "--continue" in sys.argv
    list_sessions = "-l" in sys.argv or "--list" in sys.argv

    # Disable Ctrl+C signal: clearing ISIG stops the tty from turning Ctrl+C into SIGINT
    old_settings = termios.tcgetattr(sys.stdin)
    new_settings = termios.tcgetattr(sys.stdin)
    new_settings[3] = new_settings[3] & ~termios.ISIG
    termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)

    try:
        proxy = os.environ.get("http_proxy") or os.environ.get("https_proxy")
        proxy_info = f" | {DIM}🌐 {proxy}{RESET}" if proxy else ""
        thinking_info = f" | {YELLOW}💭{RESET}" if os.environ.get("THINKING") else ""

        if list_sessions:
            session_mode = f" | {YELLOW}Select{RESET}"
        elif continue_session:
            session_mode = f" | {GREEN}Continue{RESET}"
        else:
            session_mode = f" | {CYAN}New{RESET}"

        print(f"{BOLD}nanocode{RESET} | {DIM}{MODEL} | {os.getcwd()}{proxy_info}{thinking_info}{session_mode}{RESET}")
        print(f"{DIM}Shortcuts: Enter=submit | Alt+Enter=newline | Ctrl+C=clear input | Ctrl+D=exit | ESC=stop{RESET}")
        print(f"{DIM}Commands: /c [all|baseline|<id>] | /ca | /clear{RESET}")
        print(f"{DIM}Usage: nanocode (new) | nanocode -c (continue) | nanocode -l (select){RESET}\n")

        selected_session_id = None
        if list_sessions:
            # The picker returns None when there are no sessions ("Starting new
            # session..."), on Enter ("press Enter for new session"), or on
            # invalid/aborted input. Honour that promise and start a fresh
            # session instead of exiting.
            selected_session_id = select_session_interactive()

        run_main_loop(continue_session, selected_session_id)
    finally:
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)

def select_session_interactive():
    """List up to 10 saved sessions for the current directory and ask for a choice.

    Returns:
        session_id: The chosen session's ID. None when there are no sessions,
        the user presses Enter, the input is not a valid number, or input ends
        (EOF / KeyboardInterrupt).
    """
    # NOTE(review): assumes list_sessions() returns most-recent first — confirm.
    recent = SessionManager(os.getcwd()).list_sessions()[:10]

    if not recent:
        print(f"{YELLOW}⚠ No previous sessions found{RESET}")
        print(f"{DIM}Starting new session...{RESET}\n")
        return None

    def _stamp(ts):
        return datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M')

    print(f"{BOLD}📂 Recent Sessions:{RESET}\n")

    for number, entry in enumerate(recent, start=1):
        meta = entry['metadata']
        created = _stamp(meta['created_at'])
        last_active = _stamp(meta['last_active'])
        desc = meta.get('description', '(no description)')
        commit = meta.get('git_commit')
        branch = meta.get('git_branch')
        dirty = meta.get('git_dirty', False)

        print(f"{CYAN}{number}.{RESET} {BOLD}{entry['session_id']}{RESET}")
        print(f"   {desc}")

        git_suffix = ""
        if commit and branch:
            mark = f"{YELLOW}*{RESET}" if dirty else ""
            git_suffix = f" | Git: {branch}@{commit}{mark}"

        print(f"   Created: {created} | Last: {last_active} | {entry['message_count']} messages{git_suffix}\n")

    print(f"{DIM}Enter session number (1-{len(recent)}), or press Enter for new session:{RESET}")

    try:
        choice = input(f"{BOLD}{BLUE}❯{RESET} ").strip()
    except (EOFError, KeyboardInterrupt):
        return None

    if not choice:
        # Empty input means "start a new session".
        return None

    try:
        index = int(choice) - 1
    except ValueError:
        print(f"{RED}✗ Invalid input{RESET}")
        return None

    if not 0 <= index < len(recent):
        print(f"{RED}✗ Invalid number{RESET}")
        return None
    return recent[index]['session_id']


def _session_git_suffix(metadata):
    """Return the " | Git: branch@commit" display suffix for session metadata.

    Returns "" unless both ``git_commit`` and ``git_branch`` are set; a
    yellow ``*`` is appended when ``git_dirty`` is truthy.
    """
    git_commit = metadata.get('git_commit')
    git_branch = metadata.get('git_branch')
    if not (git_commit and git_branch):
        return ""
    dirty_mark = f"{YELLOW}*{RESET}" if metadata.get('git_dirty', False) else ""
    return f" | Git: {git_branch}@{git_commit}{dirty_mark}"


def _confirm_session_conflicts(session, session_manager):
    """Warn about files changed outside *session* and ask how to proceed.

    Answers: ``u`` refreshes the recorded file states and keeps the session,
    ``y`` keeps it unchanged, anything else abandons it and creates a new
    session. An aborted prompt (EOF / Ctrl-C) counts as the default ``N``,
    mirroring how the session picker treats aborted input.
    """
    conflicts = session.detect_conflicts()
    if not conflicts:
        print()
        return

    print(f"\n{YELLOW}⚠ File conflicts detected:{RESET}")
    for filepath in conflicts[:5]:
        print(f"  - {filepath}")
    if len(conflicts) > 5:
        print(f"  ... and {len(conflicts)-5} more")
    print(f"\n{DIM}These files have been modified outside this session.{RESET}")
    try:
        confirm = input(f"{BOLD}Continue anyway? (y/N/u=update): {RESET}").strip().lower()
    except (EOFError, KeyboardInterrupt):
        print()
        confirm = ""

    if confirm == 'u':
        session.update_file_states()
        session_manager.save_session()
        print(f"{GREEN}✓ Updated file states{RESET}\n")
    elif confirm != 'y':
        print(f"{DIM}Creating new session instead...{RESET}\n")
        session_manager.create_session()
    else:
        print()


def _rekey_snapshot_after_amend(checkpoint_manager, old_id, new_id):
    """Move the conversation snapshot stored under *old_id* to *new_id*.

    Best effort: a missing, unreadable or corrupt snapshot file is left
    untouched, because the checkpoint commit itself is still valid.
    """
    snapshot_file = os.path.join(checkpoint_manager.nanocode_dir, "conversation_snapshots.json")
    try:
        with open(snapshot_file, 'r') as f:
            snapshots = json.load(f)
        if old_id not in snapshots:
            return
        snapshots[new_id] = snapshots.pop(old_id)
        with open(snapshot_file, 'w') as f:
            json.dump(snapshots, f, indent=2)
    except (OSError, ValueError):
        pass


def _reword_checkpoint(checkpoint_manager, checkpoint_id, summary):
    """Replace the message of the just-created checkpoint commit with *summary*.

    ``git commit --amend`` gives the commit a new hash. The id shown to the
    user and the key of its saved conversation snapshot must follow it;
    otherwise the branch log lists an id that has no saved conversation.

    Returns:
        The checkpoint id after rewording (the old id if the amend failed).
    """
    cm = checkpoint_manager
    amended = cm._git_command(
        "--git-dir", cm.bare_repo,
        "commit", "--amend", "-m", summary
    )
    if amended.startswith("error"):
        return checkpoint_id
    new_hash = cm._git_command("--git-dir", cm.bare_repo, "rev-parse", "HEAD")
    if not new_hash or new_hash.startswith("error"):
        return checkpoint_id
    # Keep the same id length that create_checkpoint handed out.
    new_id = new_hash[:len(checkpoint_id)]
    if new_id != checkpoint_id:
        _rekey_snapshot_after_amend(cm, checkpoint_id, new_id)
    return new_id


def run_main_loop(continue_session=False, selected_session_id=None):
    """Run the interactive chat loop until the user quits.

    First picks the working session: the one named by *selected_session_id*,
    the one returned by ``load_last_session()`` when *continue_session* is
    set, or otherwise a new session branched from the last session's latest
    checkpoint.

    It then repeatedly reads input. ``/q`` and ``exit`` save and quit,
    ``/clear`` starts a branched session, and ``/c``-style commands manage
    checkpoints. Anything else runs one model turn: call the API, execute
    the requested tools, and feed the results back until no tools are
    requested or the turn is stopped.

    When a turn wrote or edited project files, an automatic checkpoint is
    created and reworded with a generated summary. The session is saved
    after every interaction.

    Args:
        continue_session: Resume the most recent session.
        selected_session_id: Resume this session id; takes precedence over
            ``continue_session``.
    """
    # stop_flag is module-level state used to interrupt a turn. Without this
    # declaration the reset below would bind a *local* stop_flag, and every
    # stop check in this function would read that always-False local.
    global stop_flag

    session_manager = SessionManager(os.getcwd())

    if selected_session_id:
        session = session_manager.load_session(selected_session_id)
        if session:
            print(f"{GREEN}✓ Loaded session: {session.session_id}{RESET}")
            print(f"{DIM}  └─ {len(session.messages)} messages{_session_git_suffix(session.metadata)}{RESET}")
            _confirm_session_conflicts(session, session_manager)
        else:
            print(f"{RED}✗ Failed to load session{RESET}")
            print(f"{GREEN}✓ Creating new session instead{RESET}\n")
            session_manager.create_session()
    elif continue_session:
        last_session = session_manager.load_last_session()
        if last_session:
            print(f"{GREEN}✓ Continued session: {last_session.session_id}{RESET}")
            print(f"{DIM}  └─ {len(last_session.messages)} messages{_session_git_suffix(last_session.metadata)}{RESET}")
            _confirm_session_conflicts(last_session, session_manager)
        else:
            session_manager.create_session()
            print(f"{YELLOW}⚠ No previous session found{RESET}")
            print(f"{GREEN}✓ Created new session: {session_manager.current_session.session_id}{RESET}\n")
    else:
        # Default: always start a new session, branched from the latest
        # checkpoint of the last session when there is one.
        parent_checkpoint = None
        parent_session = None

        last_session = session_manager.load_last_session()
        if last_session:
            checkpoints = session_manager.checkpoint_manager.list_checkpoints(show_all=False)
            if checkpoints:
                parent_checkpoint = checkpoints[0][0]  # Latest checkpoint hash
                parent_session = last_session.session_id

        session_manager.create_session(
            parent_checkpoint=parent_checkpoint,
            parent_session=parent_session
        )

        new_session = session_manager.current_session
        git_info = _session_git_suffix(new_session.metadata)
        print(f"{GREEN}✓ Created new session: {new_session.session_id}{RESET}")
        if parent_checkpoint:
            print(f"{DIM}  └─ Branched from {parent_session[:8]}... @ {parent_checkpoint}{git_info}{RESET}\n")
        elif git_info:
            print(f"{DIM}  └─{git_info}{RESET}\n")
        else:
            print()

    files_modified = set()
    auto_checkpoint = True

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    system_prompt = f"""Concise coding assistant. cwd: {os.getcwd()} Current time: {current_time}
IMPORTANT: When you don't have a tool for the task, ALWAYS try search_extension first before saying you can't do it.
Examples:
- User asks about GitHub repo → search_extension({{"query": "github documentation"}})
- User needs web data → search_extension({{"query": "web scraping"}})
- User needs API → search_extension({{"query": "api client"}})"""

    while True:
        try:
            print(f"{DIM}{'─'*80}{RESET}")
            user_input = read_multiline_input()
            print(f"{DIM}{'─'*80}{RESET}")

            if not user_input: continue
            if user_input in ("/q", "exit"):
                session_manager.save_session()
                break

            # /clear must be handled before the "/c" checkpoint prefix below.
            if user_input == "/clear":
                session_manager.save_session()

                # Branch the new session from the current session's latest checkpoint.
                checkpoints = session_manager.checkpoint_manager.list_checkpoints(show_all=False)
                parent_checkpoint = checkpoints[0][0] if checkpoints else None
                parent_session = session_manager.current_session.session_id

                session_manager.create_session(
                    parent_checkpoint=parent_checkpoint,
                    parent_session=parent_session
                )
                files_modified.clear()

                print(f"{GREEN}✓ Started new session: {session_manager.current_session.session_id}{RESET}")
                if parent_checkpoint:
                    print(f"{DIM}  └─ Branched from {parent_session[:8]}... @ {parent_checkpoint}{RESET}")
                continue

            # Checkpoint commands
            if user_input.startswith("/checkpoint") or user_input.startswith("/c") or user_input == "/ca":
                parts = user_input.split()

                # /ca is shortcut for /c all
                if parts[0] == "/ca":
                    parts = ["/c", "all"]

                # /c without args defaults to list
                if len(parts) == 1 and parts[0] in ["/c", "/checkpoint"]:
                    parts.append("list")

                restored_messages = handle_checkpoint_command(parts, session_manager, files_modified)
                if restored_messages is not None:
                    # Restore conversation by replacing session messages
                    session_manager.current_session.messages = restored_messages
                    session_manager.save_session()
                continue

            session_manager.current_session.messages.append({"role": "user", "content": user_input})

            # Reset the (module-level) stop flag for the new turn.
            stop_flag = False

            # Project files written/edited during this turn only.
            files_modified_this_turn = set()

            while True:
                response = call_api(session_manager.current_session.messages, system_prompt)
                blocks = process_stream(response)
                if stop_flag: break

                tool_results = []
                for block in blocks:
                    if block["type"] == "tool_use":
                        name, args = block["name"], block["input"]

                        # Save baseline BEFORE executing write/edit
                        if name in ['write', 'edit']:
                            filepath = args.get('path')
                            if filepath and is_file_in_project(filepath, session_manager.project_path):
                                session_manager.save_baseline_if_needed(filepath)

                        preview = str(list(args.values())[0])[:50] if args else ""
                        print(f"\n{GREEN}⏺ {name}{RESET}({DIM}{preview}{RESET})")

                        result = run_tool(name, args)
                        lines = result.split("\n")
                        prev = lines[0][:60] + ("..." if len(lines[0]) > 60 else "")
                        if len(lines) > 1: prev += f" +{len(lines)-1}"
                        print(f"  {DIM}⎿ {prev}{RESET}")

                        # Track file modifications (only project files)
                        if name in ['write', 'edit']:
                            filepath = args.get('path')
                            if filepath and is_file_in_project(filepath, session_manager.project_path):
                                files_modified.add(filepath)
                                files_modified_this_turn.add(filepath)
                                session_manager.current_session.track_file_state(filepath)

                        tool_results.append({"type": "tool_result", "tool_use_id": block["id"], "content": result})

                        if stop_flag:
                            print(f"{YELLOW}⚠ Tool execution stopped{RESET}")
                            break

                session_manager.current_session.messages.append({"role": "assistant", "content": blocks})
                if not tool_results or stop_flag: break
                session_manager.current_session.messages.append({"role": "user", "content": tool_results})

            # Auto checkpoint after AI work (files_modified_this_turn holds project files only)
            if auto_checkpoint and files_modified_this_turn:
                checkpoint_manager = session_manager.checkpoint_manager
                temp_message = f"Auto: {user_input[:50]}"
                # The first checkpoint of a new session branches from parent_commit.
                parent_commit = session_manager.parent_commit_for_next_checkpoint
                checkpoint_id = checkpoint_manager.create_checkpoint(
                    temp_message,
                    list(files_modified_this_turn),
                    conversation_snapshot=session_manager.current_session.messages.copy(),
                    parent_commit=parent_commit
                )
                if parent_commit:
                    session_manager.parent_commit_for_next_checkpoint = None

                if checkpoint_id:
                    print(f"{DIM}Generating checkpoint summary...{RESET}", end="", flush=True)
                    summary = summarize_changes(
                        user_input,
                        files_modified_this_turn,
                        checkpoint_manager,
                        checkpoint_id
                    )
                    print(f"\r{' ' * 40}\r", end="", flush=True)  # Clear the line

                    # Reword only when the summary adds something over the temp message;
                    # rewording changes the commit hash, so take the new id.
                    if summary != user_input[:50] and summary != temp_message:
                        checkpoint_id = _reword_checkpoint(checkpoint_manager, checkpoint_id, summary)

                    print(f"\n{YELLOW}📍 {checkpoint_id}: {summary}{RESET}")
                else:
                    print(f"\n{DIM}(No project file changes to checkpoint){RESET}")

            # Auto-save session after each interaction
            session_manager.save_session()

            print()
        except EOFError:
            session_manager.save_session()
            break
        except Exception as e: print(f"{RED}⏺ Error: {e}{RESET}")

# ============================================================================
# Checkpoint & Session Management (Phase 1+2)
# ============================================================================

class CheckpointManager:
    """Manage checkpoints using a shadow bare git repository with session isolation.

    Checkpoints are commits in ``<project>/.nanocode/checkpoint.git``. Each
    session commits to its own branch (``session_<id>``, or ``main`` without
    a session) through a per-session scratch work tree
    ``.nanocode/temp_worktree_<id>``, so checkpoints are never committed into
    the project's own ``.git``.

    Conversation snapshots are stored in
    ``.nanocode/conversation_snapshots.json``, keyed by checkpoint id. A
    checkpoint id is the first ``_ID_LEN`` characters of the commit hash.
    """

    # Length of the abbreviated commit hash used as a checkpoint id.
    _ID_LEN = 8

    def __init__(self, project_path, session_id=None):
        self.project_path = project_path
        self.session_id = session_id
        self.nanocode_dir = os.path.join(project_path, ".nanocode")
        self.bare_repo = os.path.join(self.nanocode_dir, "checkpoint.git")
        self._init_bare_repo()

    def set_session(self, session_id):
        """Set current session for checkpoint operations"""
        self.session_id = session_id

    def _get_branch_name(self):
        """Get git branch name for current session"""
        if not self.session_id:
            return "main"
        return f"session_{self.session_id}"

    def _temp_worktree(self):
        """Return the path of this session's scratch work tree."""
        return os.path.join(self.nanocode_dir, f"temp_worktree_{self.session_id}")

    def _snapshot_file(self):
        """Return the path of the JSON file holding conversation snapshots."""
        return os.path.join(self.nanocode_dir, "conversation_snapshots.json")

    def _to_rel_path(self, filepath):
        """Return *filepath* relative to the project, normalized.

        Absolute paths are made relative to ``project_path``; relative paths
        are kept relative. ``normpath`` drops a leading ``./`` without
        damaging dotfiles: ``str.lstrip('./')`` would turn ``.env`` into
        ``env``.
        """
        if os.path.isabs(filepath):
            filepath = os.path.relpath(filepath, self.project_path)
        return os.path.normpath(filepath)

    def _init_bare_repo(self):
        """Initialize the shadow bare repository unless it already is one.

        Checks for ``HEAD`` rather than the directory, so an earlier attempt
        that failed (e.g. git not installed) is retried. Failure is ignored
        here; later git calls report their own errors.
        """
        if os.path.exists(os.path.join(self.bare_repo, "HEAD")):
            return
        os.makedirs(self.bare_repo, exist_ok=True)
        try:
            subprocess.run(
                ["git", "init", "--bare", self.bare_repo],
                capture_output=True, check=True
            )
        except (subprocess.CalledProcessError, OSError):
            # Git not available, will handle gracefully
            pass

    def _git_command(self, *args, cwd=None):
        """Run ``git <args>`` and return its stripped stdout.

        Returns ``"error: ..."`` instead of raising when git fails or cannot
        be started. Runs in ``project_path`` unless *cwd* is given.
        """
        try:
            result = subprocess.run(
                ["git"] + list(args),
                cwd=cwd or self.project_path,
                capture_output=True,
                text=True,
                check=True
            )
            return result.stdout.strip()
        except (subprocess.CalledProcessError, OSError) as e:
            return f"error: {e}"

    def save_file_to_blob(self, filepath):
        """Save file to git blob storage

        Returns:
            str: blob hash, or None on any failure
        """
        try:
            result = subprocess.run(
                ["git", "--git-dir", self.bare_repo, "hash-object", "-w", filepath],
                capture_output=True, text=True, check=True
            )
            return result.stdout.strip()
        except Exception:
            return None

    def restore_file_from_blob(self, blob_hash, filepath):
        """Write the content of blob *blob_hash* to *filepath*.

        Creates parent directories as needed. Returns True on success and
        False on any failure.
        """
        try:
            content = subprocess.run(
                ["git", "--git-dir", self.bare_repo, "cat-file", "-p", blob_hash],
                capture_output=True, check=True
            ).stdout

            os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
            with open(filepath, 'wb') as f:
                f.write(content)
            return True
        except Exception:
            return False

    def get_file_git_info(self, filepath):
        """Get file's git info from project .git

        Returns:
            dict: {"commit": "abc123", "has_changes": True/False}, or None if
            the file is untracked or git fails.
        """
        try:
            # Check if file is tracked
            result = subprocess.run(
                ["git", "ls-files", "--error-unmatch", filepath],
                cwd=self.project_path,
                capture_output=True, text=True
            )

            if result.returncode != 0:
                return None  # Not tracked

            # Get last commit for this file
            commit = subprocess.run(
                ["git", "log", "-1", "--format=%H", "--", filepath],
                cwd=self.project_path,
                capture_output=True, text=True, check=True
            ).stdout.strip()

            # Check if file has local changes
            diff = subprocess.run(
                ["git", "diff", "HEAD", "--", filepath],
                cwd=self.project_path,
                capture_output=True, text=True, check=True
            ).stdout.strip()

            return {
                "commit": commit,
                "has_changes": bool(diff)
            }
        except Exception:
            return None

    def _branch_exists(self, branch_name):
        """Return True if *branch_name* resolves in the bare repository."""
        out = self._git_command("--git-dir", self.bare_repo, "rev-parse", "--verify", branch_name)
        return bool(out) and not out.startswith("error")

    def _head_branch(self):
        """Return the branch HEAD points at, or an ``"error: ..."`` string if detached."""
        return self._git_command("--git-dir", self.bare_repo, "symbolic-ref", "--short", "HEAD")

    def _checkout_session_branch(self, branch_name, temp_worktree, parent_commit, exists):
        """Point HEAD at *branch_name* in *temp_worktree*, creating the branch if needed.

        A new branch starts at *parent_commit* when it is given and valid,
        otherwise as an orphan.

        HEAD and the index are shared by every session in the bare repo.
        Success is therefore confirmed by reading HEAD back, so a failed
        checkout can never lead the caller to commit onto another session's
        branch.

        Returns:
            bool: True if HEAD is on *branch_name*.
        """
        git = ("--git-dir", self.bare_repo, "--work-tree", temp_worktree)
        os.makedirs(temp_worktree, exist_ok=True)

        if exists:
            self._git_command(*git, "checkout", branch_name, "-f")
        else:
            self._git_command(*git, "config", "core.bare", "false")
            branched = False
            if parent_commit:
                created = self._git_command("--git-dir", self.bare_repo, "branch", branch_name, parent_commit)
                if not created.startswith("error"):
                    self._git_command(*git, "checkout", branch_name, "-f")
                    branched = True
            if not branched:
                # No (valid) parent: start an orphan branch.
                self._git_command(*git, "checkout", "--orphan", branch_name)

        return self._head_branch() == branch_name

    def _copy_files_to_worktree(self, files_changed, temp_worktree, action):
        """Copy each existing file in *files_changed* into *temp_worktree*.

        Each file lands at its project-relative location. Missing source
        files are skipped (and logged). *action* only labels the log line.
        """
        for filepath in files_changed:
            print(f"{DIM}[LOG] checkpoint {action}: {filepath}{RESET}", flush=True)
            if not os.path.exists(filepath):
                print(f"{DIM}[LOG] source file NOT exists: {filepath}{RESET}", flush=True)
                continue
            file_size = os.path.getsize(filepath)
            print(f"{DIM}[LOG] source file exists: {filepath} ({file_size} bytes){RESET}", flush=True)
            dest = os.path.join(temp_worktree, self._to_rel_path(filepath))
            os.makedirs(os.path.dirname(dest), exist_ok=True)
            with open(filepath, 'rb') as src:
                content = src.read()
            with open(dest, 'wb') as dst:
                dst.write(content)
            print(f"{DIM}[LOG] copied to temp_worktree: {dest} ({len(content)} bytes){RESET}", flush=True)

    def _commit_worktree(self, temp_worktree, message):
        """Stage *temp_worktree* and commit it.

        Returns:
            str or None: the new checkpoint id, or None if the commit failed.
            On failure HEAD still names the previous commit, and that id must
            not be reported as a new checkpoint.
        """
        git = ("--git-dir", self.bare_repo, "--work-tree", temp_worktree)
        self._git_command(*git, "add", "-A")
        committed = self._git_command(*git, "commit", "-m", message, "--allow-empty")
        if committed.startswith("error"):
            return None
        commit_hash = self._git_command("--git-dir", self.bare_repo, "rev-parse", "HEAD")
        if not commit_hash or commit_hash.startswith("error"):
            return None
        return commit_hash[:self._ID_LEN]

    def _load_snapshots(self):
        """Return the ``{checkpoint_id: messages}`` map, or {} if none is saved yet."""
        snapshot_file = self._snapshot_file()
        if not os.path.exists(snapshot_file):
            return {}
        with open(snapshot_file, 'r') as f:
            return json.load(f)

    def _save_conversation_snapshot(self, checkpoint_id, conversation_snapshot):
        """Store *conversation_snapshot* under *checkpoint_id* in the snapshot file."""
        snapshots = self._load_snapshots()
        snapshots[checkpoint_id] = conversation_snapshot
        with open(self._snapshot_file(), 'w') as f:
            json.dump(snapshots, f, indent=2)

    def create_checkpoint(self, message, files_changed, conversation_snapshot=None, parent_commit=None):
        """Create a checkpoint on current session's branch

        Args:
            message: Commit message
            files_changed: List of modified files
            conversation_snapshot: Conversation state to save
            parent_commit: Parent commit hash to branch from (for new sessions)

        Returns:
            str or None: checkpoint id, or None when there is nothing to
            checkpoint, no session is set, or any step fails.
        """
        print(f"{DIM}[LOG] create_checkpoint: files_changed={files_changed}{RESET}", flush=True)
        if not files_changed or not self.session_id:
            return None

        branch_name = self._get_branch_name()
        temp_worktree = self._temp_worktree()

        try:
            exists = self._branch_exists(branch_name)
            if not self._checkout_session_branch(branch_name, temp_worktree, parent_commit, exists):
                return None

            self._copy_files_to_worktree(files_changed, temp_worktree, "updating" if exists else "copying")
            checkpoint_id = self._commit_worktree(temp_worktree, message)

            if checkpoint_id and conversation_snapshot:
                self._save_conversation_snapshot(checkpoint_id, conversation_snapshot)
            return checkpoint_id
        except Exception:
            return None

    def list_checkpoints(self, limit=10, show_all=False):
        """List recent checkpoints, newest first.

        Args:
            limit: Maximum number of checkpoints to show
            show_all: If True, show all sessions; if False, only show current session

        Returns:
            list: ``[checkpoint_id, subject]`` pairs. Ids use the same
            ``_ID_LEN`` characters as ``create_checkpoint``, so they match
            the snapshot keys. Returns [] on error.
        """
        if not self.session_id and not show_all:
            return []

        target = "--all" if show_all else self._get_branch_name()
        log = self._git_command(
            "--git-dir", self.bare_repo, "log", f"--max-count={limit}", "--format=%H %s", target
        )
        if not log or log.startswith("error"):
            return []

        checkpoints = []
        for line in log.split("\n"):
            if not line:
                continue
            full_hash, _, subject = line.partition(" ")
            checkpoints.append([full_hash[:self._ID_LEN], subject])
        return checkpoints

    def restore_checkpoint(self, checkpoint_id, session_baseline_files):
        """Restore files to checkpoint state and reset current session's branch

        Files modified in the session are handled in two ways:
        1. Files that exist in the checkpoint are restored from it.
        2. Other files are restored to their baseline: checked out from the
           project's git, restored from a blob, or deleted if they were new.

        Args:
            checkpoint_id: Checkpoint hash (or unambiguous prefix) to restore to
            session_baseline_files: Dict of baseline file states from session

        Returns:
            tuple: (success: bool, conversation_snapshot or None).
            Returns (False, None) when no session is set, the checkpoint id
            is unknown, or restoring fails.
        """
        if not self.session_id:
            return False, None

        branch_name = self._get_branch_name()
        temp_worktree = self._temp_worktree()
        git = ("--git-dir", self.bare_repo, "--work-tree", temp_worktree)

        try:
            # Reject unknown ids up front. Otherwise the checkpoint file list
            # would be empty and every session file would be reverted to baseline.
            full_hash = self._git_command(
                "--git-dir", self.bare_repo, "rev-parse", "--verify", f"{checkpoint_id}^{{commit}}"
            )
            if not full_hash or full_hash.startswith("error"):
                print(f"{RED}Error during restore: unknown checkpoint {checkpoint_id}{RESET}")
                return False, None

            os.makedirs(temp_worktree, exist_ok=True)
            self._git_command(*git, "checkout", branch_name, "-f")

            if self._head_branch() == branch_name:
                # Rewind this session's branch (discarding later commits) and
                # sync the scratch work tree. --work-tree is essential: without
                # it git falls back to the process cwd (the project directory,
                # see _git_command), and `reset --hard` would overwrite or delete
                # project files.
                self._git_command(*git, "reset", "--hard", full_hash)
            else:
                # This session has no branch to rewind: only materialize the
                # checkpoint, without moving any other session's branch.
                self._git_command(*git, "checkout", full_hash, "-f")

            # -z: NUL-separated, unquoted paths (non-ASCII names would otherwise be quoted)
            listing = self._git_command(
                "--git-dir", self.bare_repo, "ls-tree", "-r", "-z", "--name-only", full_hash
            )
            files_in_checkpoint = set()
            if listing and not listing.startswith("error"):
                files_in_checkpoint = {f for f in listing.split('\0') if f.strip()}

            print(f"\n{DIM}Files in checkpoint: {len(files_in_checkpoint)}{RESET}")
            print(f"{DIM}Files modified in session: {len(session_baseline_files)}{RESET}\n")

            for filepath, baseline_source in session_baseline_files.items():
                rel_filepath = self._to_rel_path(filepath)

                if rel_filepath in files_in_checkpoint:
                    src = os.path.join(temp_worktree, rel_filepath)
                    dest = os.path.join(self.project_path, rel_filepath)

                    dest_dir = os.path.dirname(dest)
                    if dest_dir:
                        os.makedirs(dest_dir, exist_ok=True)

                    with open(src, 'rb') as s, open(dest, 'wb') as d:
                        d.write(s.read())

                    print(f"  {GREEN}✓{RESET} {filepath} {DIM}(from checkpoint){RESET}")
                    continue

                # Not in the checkpoint: restore to the baseline recorded for this session.
                abs_filepath = filepath if os.path.isabs(filepath) else os.path.join(self.project_path, filepath)
                source_type = baseline_source["type"]

                if source_type == "git":
                    result = subprocess.run(
                        ["git", "checkout", baseline_source["commit"], "--", rel_filepath],
                        cwd=self.project_path,
                        capture_output=True, text=True
                    )
                    if result.returncode == 0:
                        print(f"  {CYAN}↺{RESET} {filepath} {DIM}(to baseline: git {baseline_source['commit'][:8]}){RESET}")

                elif source_type == "blob":
                    if self.restore_file_from_blob(baseline_source["hash"], abs_filepath):
                        print(f"  {CYAN}↺{RESET} {filepath} {DIM}(to baseline: blob {baseline_source['hash'][:8]}){RESET}")

                elif source_type == "new":
                    if os.path.exists(abs_filepath):
                        os.remove(abs_filepath)
                        print(f"  {YELLOW}✗{RESET} {filepath} {DIM}(deleted: was added after checkpoint){RESET}")

            # Snapshots are keyed by the abbreviated id create_checkpoint returned;
            # fall back to the id exactly as given.
            snapshots = self._load_snapshots()
            conversation_snapshot = snapshots.get(full_hash[:self._ID_LEN])
            if conversation_snapshot is None:
                conversation_snapshot = snapshots.get(checkpoint_id)

            return True, conversation_snapshot
        except Exception as e:
            print(f"{RED}Error during restore: {e}{RESET}")
            return False, None


class Session:
    """A conversation session: its messages, tracked file states, rollback
    baselines and descriptive metadata. Serializable via to_dict/from_dict."""

    def __init__(self, session_id=None):
        self.session_id = session_id or self._generate_session_id()
        self.messages = []
        # filepath -> {'hash', 'mtime', 'size'} as last observed by this session
        self.file_states = {}
        # filepath -> baseline source dict; tracks original file versions for rollback
        self.baseline_files = {}
        self.metadata = self._default_metadata()

    @staticmethod
    def _default_metadata():
        """Return a fresh metadata dict containing every field a session carries."""
        now = time.time()
        return {
            'created_at': now,
            'last_active': now,
            'description': '',
            'cwd': os.getcwd(),
            'parent_checkpoint': None,  # Checkpoint this session branched from
            'parent_session': None,     # Session this session branched from
            'git_commit': None,         # Project .git commit (8-char prefix) when session started
            'git_branch': None,         # Project .git branch when session started
            'git_dirty': False,         # Whether project had uncommitted changes
        }

    def _generate_session_id(self):
        """Return a unique-ish id of the form ``YYYYmmdd_HHMMSS_xxxx`` (4 random hex chars)."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        random_suffix = ''.join(random.choices('0123456789abcdef', k=4))
        return f"{timestamp}_{random_suffix}"

    def _get_project_git_info(self):
        """Return ``{'git_commit', 'git_branch', 'git_dirty'}`` for the session's cwd.

        Returns None when the state cannot be determined: cwd is not a git
        repository, HEAD has no commit yet, cwd does not exist, or git is not
        installed.
        """
        cwd = self.metadata.get('cwd', os.getcwd())

        def git(*args):
            return subprocess.run(
                ["git", *args],
                cwd=cwd,
                capture_output=True, text=True, check=True
            ).stdout.strip()

        try:
            commit = git("rev-parse", "HEAD")
            branch = git("rev-parse", "--abbrev-ref", "HEAD")
            status = git("status", "--porcelain")
        except (subprocess.CalledProcessError, OSError):
            # CalledProcessError: git failed (not a repo / no commits).
            # OSError: git binary missing or cwd invalid.
            return None

        return {
            'git_commit': commit[:8],
            'git_branch': branch,
            'git_dirty': bool(status),  # any porcelain output means uncommitted changes
        }

    def capture_git_state(self):
        """Copy the current project git state into metadata (no-op if unavailable)."""
        git_info = self._get_project_git_info()
        if git_info:
            self.metadata.update(git_info)

    @staticmethod
    def _read_file_state(filepath):
        """Return ``{'hash', 'mtime', 'size'}`` for *filepath* (MD5 of its bytes)."""
        with open(filepath, 'rb') as f:
            content = f.read()
        return {
            'hash': hashlib.md5(content).hexdigest(),
            'mtime': os.path.getmtime(filepath),
            'size': len(content),
        }

    def track_file_state(self, filepath):
        """Record *filepath*'s current state for later conflict detection.

        Does nothing if the file does not exist.
        """
        if os.path.exists(filepath):
            self.file_states[filepath] = self._read_file_state(filepath)

    def detect_conflicts(self):
        """Detect if tracked files have been modified outside this session.

        A file counts as changed when either its content hash or its mtime
        differs from the recorded state.

        Returns:
            list: Conflicted file paths; deleted files appear as
            ``"<path> (deleted)"``.
        """
        conflicts = []
        for filepath, saved_state in self.file_states.items():
            if not os.path.exists(filepath):
                conflicts.append(f"{filepath} (deleted)")
                continue
            current = self._read_file_state(filepath)
            if (current['hash'] != saved_state['hash'] or
                    current['mtime'] != saved_state['mtime']):
                conflicts.append(filepath)
        return conflicts

    def update_file_states(self):
        """Refresh all tracked file states; stop tracking files that no longer exist."""
        for filepath in list(self.file_states):  # copy: entries are deleted while iterating
            if os.path.exists(filepath):
                self.track_file_state(filepath)
            else:
                del self.file_states[filepath]

    def to_dict(self):
        """Serialize to a JSON-compatible dict."""
        return {
            'session_id': self.session_id,
            'messages': self.messages,
            'file_states': self.file_states,
            'baseline_files': self.baseline_files,
            'metadata': self.metadata
        }

    @staticmethod
    def from_dict(data):
        """Deserialize a dict produced by to_dict().

        Stored metadata is merged over the defaults. Sessions saved with
        missing (or null) metadata fields therefore still expose every key
        the rest of the code reads.
        """
        session = Session(session_id=data['session_id'])
        session.messages = data.get('messages', [])
        session.file_states = data.get('file_states', {})
        session.baseline_files = data.get('baseline_files', {})
        session.metadata.update(data.get('metadata') or {})
        return session


class SessionManager:
    """Create, persist, load and list sessions for one project.

    Sessions are stored as ``<session_id>.json`` in
    ``<project_path>/.nanocode/sessions``. The manager also owns the project's
    CheckpointManager and points it at whichever session is current.
    """

    def __init__(self, project_path):
        self.project_path = project_path
        self.sessions_dir = os.path.join(project_path, ".nanocode", "sessions")
        self.current_session = None
        self.checkpoint_manager = CheckpointManager(project_path)
        self.parent_commit_for_next_checkpoint = None  # Track parent for first checkpoint
        os.makedirs(self.sessions_dir, exist_ok=True)

    def _abs_path(self, filepath):
        """Resolve *filepath* against the project root unless it is already absolute."""
        return filepath if os.path.isabs(filepath) else os.path.join(self.project_path, filepath)

    def save_baseline_if_needed(self, filepath):
        """Record *filepath*'s pre-modification state the first time it is touched.

        Called before write/edit operations to preserve the original state.
        The recorded baseline is one of:
          - ``{"type": "git", "commit": ...}`` for a clean file tracked by project .git
          - ``{"type": "blob", "hash": ...}`` for a dirty tracked file or an existing
            untracked file, whose content is copied via save_file_to_blob
          - ``{"type": "new"}`` for a file that does not exist yet
        The session is saved afterwards so baselines survive restarts.
        """
        session = self.current_session
        if not session or filepath in session.baseline_files:
            return

        print(f"{DIM}[LOG] Saving baseline for: {filepath}{RESET}", flush=True)

        # Get file info from project .git (falsy when untracked or no .git)
        git_info = self.checkpoint_manager.get_file_git_info(filepath)

        if git_info and not git_info["has_changes"]:
            # Clean tracked file: the commit alone is enough to restore it
            session.baseline_files[filepath] = {
                "type": "git",
                "commit": git_info["commit"]
            }
            print(f"{DIM}[LOG] Recorded git commit: {git_info['commit'][:8]}{RESET}", flush=True)
        elif git_info or os.path.exists(filepath):
            # Dirty tracked file, or existing untracked file: snapshot its content
            blob_hash = self.checkpoint_manager.save_file_to_blob(filepath)
            if blob_hash:
                session.baseline_files[filepath] = {
                    "type": "blob",
                    "hash": blob_hash
                }
                kind = "dirty file" if git_info else "untracked file"
                print(f"{DIM}[LOG] Saved {kind} to blob: {blob_hash[:8]}{RESET}", flush=True)
            else:
                # Nothing is recorded. After this edit, a later call would
                # snapshot the already-modified file as its "baseline".
                print(f"{YELLOW}⚠ Could not save baseline for {filepath}{RESET}", flush=True)
        else:
            session.baseline_files[filepath] = {
                "type": "new"
            }
            print(f"{DIM}[LOG] Marked as new file{RESET}", flush=True)

        # Auto-save session to persist baseline_files
        self.save_session()

    def restore_baseline(self):
        """Restore all files to their baseline state.

        Files recorded as "git" are checked out from the project .git. Files
        recorded as "blob" are restored from the checkpoint store. Files
        recorded as "new" are deleted.

        Returns:
            bool: True if at least one file was restored.
        """
        if not self.current_session:
            return False

        baseline_files = self.current_session.baseline_files
        if not baseline_files:
            print(f"{YELLOW}⚠ No baseline files to restore{RESET}")
            return False

        total = len(baseline_files)
        print(f"\n{BOLD}Restoring {total} files to baseline...{RESET}\n")

        success_count = 0
        for filepath, source in baseline_files.items():
            try:
                abs_filepath = self._abs_path(filepath)

                if source["type"] == "git":
                    # git needs a path relative to the repo root
                    rel_filepath = os.path.relpath(abs_filepath, self.project_path)
                    result = subprocess.run(
                        ["git", "checkout", source["commit"], "--", rel_filepath],
                        cwd=self.project_path,
                        capture_output=True, text=True
                    )
                    if result.returncode == 0:
                        print(f"  {GREEN}✓{RESET} {filepath} {DIM}(from git {source['commit'][:8]}){RESET}")
                        success_count += 1
                    else:
                        print(f"  {RED}✗{RESET} {filepath} {DIM}(git checkout failed){RESET}")

                elif source["type"] == "blob":
                    if self.checkpoint_manager.restore_file_from_blob(source["hash"], abs_filepath):
                        print(f"  {GREEN}✓{RESET} {filepath} {DIM}(from blob {source['hash'][:8]}){RESET}")
                        success_count += 1
                    else:
                        print(f"  {RED}✗{RESET} {filepath} {DIM}(blob restore failed){RESET}")

                elif source["type"] == "new":
                    if os.path.exists(abs_filepath):
                        os.remove(abs_filepath)
                        print(f"  {GREEN}✓{RESET} {filepath} {DIM}(deleted new file){RESET}")
                    else:
                        print(f"  {DIM}○{RESET} {filepath} {DIM}(already deleted){RESET}")
                    success_count += 1

            except Exception as e:
                # Best effort: report and continue with the remaining files
                print(f"  {RED}✗{RESET} {filepath} {DIM}(error: {e}){RESET}")

        print(f"\n{GREEN}✓ Restored {success_count}/{total} files{RESET}")
        return success_count > 0

    def create_session(self, description="", parent_checkpoint=None, parent_session=None):
        """Create, activate and persist a new session.

        Args:
            description: Session description
            parent_checkpoint: Checkpoint ID this session branches from
            parent_session: Session ID this session branches from

        Returns:
            Session: The new current session.
        """
        session = Session()
        session.metadata['description'] = description
        session.metadata['parent_checkpoint'] = parent_checkpoint
        session.metadata['parent_session'] = parent_session
        session.capture_git_state()
        self.current_session = session
        self.checkpoint_manager.set_session(session.session_id)
        # The first checkpoint of this session should branch from parent_checkpoint
        self.parent_commit_for_next_checkpoint = parent_checkpoint
        self.save_session()
        return session

    def save_session(self):
        """Write the current session to ``<sessions_dir>/<session_id>.json``.

        The JSON is written to a temporary file and then atomically moved into
        place. A crash or serialization error mid-write therefore leaves the
        previous version intact instead of a truncated file.
        """
        if not self.current_session:
            return

        self.current_session.metadata['last_active'] = time.time()
        session_file = os.path.join(
            self.sessions_dir,
            f"{self.current_session.session_id}.json"
        )
        tmp_file = session_file + ".tmp"  # not *.json, so list_sessions ignores it

        try:
            with open(tmp_file, 'w') as f:
                json.dump(self.current_session.to_dict(), f, indent=2)
            os.replace(tmp_file, session_file)
        except Exception:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
            raise

    def load_session(self, session_id):
        """Load a session from disk and make it current.

        Returns:
            Session or None: The loaded session, or None if no file exists.
        """
        session_file = os.path.join(self.sessions_dir, f"{session_id}.json")

        if not os.path.exists(session_file):
            return None

        with open(session_file, 'r') as f:
            data = json.load(f)

        session = Session.from_dict(data)
        self.current_session = session
        self.checkpoint_manager.set_session(session.session_id)
        return session

    def list_sessions(self):
        """Return summaries of all readable sessions, most recently active first.

        Each summary is ``{'session_id', 'metadata', 'message_count'}``. Files
        that cannot be read or parsed, that lack ``session_id``/``metadata``, or
        whose metadata is not a dict are skipped.
        """
        sessions = []

        if not os.path.exists(self.sessions_dir):
            return sessions

        for filename in os.listdir(self.sessions_dir):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(self.sessions_dir, filename)
            try:
                with open(filepath, 'r') as f:
                    data = json.load(f)
                summary = {
                    'session_id': data['session_id'],
                    'metadata': data['metadata'],
                    'message_count': len(data.get('messages', [])),
                }
            except (OSError, ValueError, KeyError, TypeError, AttributeError):
                # Unreadable or malformed session file: skip it rather than fail the listing
                continue
            if isinstance(summary['metadata'], dict):  # the sort key below needs .get
                sessions.append(summary)

        return sorted(sessions, key=lambda x: x['metadata'].get('last_active', 0), reverse=True)

    def load_last_session(self):
        """Load and return the most recently active session, or None if there are none."""
        sessions = self.list_sessions()
        if sessions:
            return self.load_session(sessions[0]['session_id'])
        return None


def handle_checkpoint_command(parts, session_manager, files_modified):
    """Handle /checkpoint or /c commands.

    Subcommands (``parts[1]``):
        (none), ``list``      show the current session's latest checkpoints
        ``all``, ``--all``    show the checkpoint graph across all sessions
                              (``list --all`` does the same)
        ``baseline``, ``base``  restore every file modified in this session to
                              its original state and clear the conversation
        ``restore <hash>``    restore files and conversation to a checkpoint
        ``<hash>``            shorthand for ``restore <hash>`` (7-40 hex chars)

    Args:
        parts: The whitespace-split command, e.g. ``["/c", "abc1234"]``.
            Not modified.
        session_manager: The active SessionManager.
        files_modified: Unused; kept for call-site compatibility.

    Returns:
        messages: New messages list if the conversation was restored (``[]``
        after a baseline restore), None otherwise.
    """
    cmd, checkpoint_id = _checkpoint_parse_args(parts)

    if cmd in ("baseline", "base"):
        return _checkpoint_restore_baseline(session_manager)

    if cmd in ("list", "all", "--all"):
        if cmd == "all" or "--all" in parts:
            _checkpoint_show_graph(session_manager.checkpoint_manager)
        else:
            _checkpoint_show_session(session_manager.checkpoint_manager)
        return None

    if cmd == "restore":
        return _checkpoint_restore(session_manager, checkpoint_id)

    print(f"{RED}Unknown command: {cmd}{RESET}")
    return None


def _checkpoint_parse_args(parts):
    """Return ``(subcommand, checkpoint_id)`` for a /c command without mutating *parts*."""
    cmd = parts[1] if len(parts) > 1 else "list"
    # A bare abbreviated (or full) commit hash is shorthand for "restore <hash>".
    # No named subcommand is 7+ hex characters, so this cannot shadow one.
    if 7 <= len(cmd) <= 40 and all(c in '0123456789abcdef' for c in cmd.lower()):
        return "restore", cmd
    if cmd == "restore" and len(parts) > 2:
        return "restore", parts[2]
    return cmd, None


def _checkpoint_git(checkpoint_manager, *args):
    """Run a git command against the checkpoint manager's bare repository."""
    return checkpoint_manager._git_command("--git-dir", checkpoint_manager.bare_repo, *args)


def _checkpoint_restore_baseline(session_manager):
    """Confirm with the user, then restore every baseline file.

    Returns ``[]`` (clear the conversation) on success, otherwise None.
    """
    session = session_manager.current_session
    if session is None:
        print(f"{RED}✗ No active session{RESET}")
        return None

    baseline_files = session.baseline_files
    if not baseline_files:
        print(f"{YELLOW}⚠ No baseline: no files have been modified in this session{RESET}")
        return None

    print(f"{YELLOW}⚠ This will restore all modified files to their original state{RESET}")
    print(f"{YELLOW}⚠ Files to restore: {len(baseline_files)}{RESET}")

    print(f"\n{DIM}Files:{RESET}")
    for filepath in list(baseline_files)[:10]:
        print(f"  {DIM}• {filepath}{RESET}")
    if len(baseline_files) > 10:
        print(f"  {DIM}... and {len(baseline_files) - 10} more{RESET}")

    confirm = input(f"\n{BOLD}Continue? (y/N): {RESET}").strip().lower()
    if confirm != 'y':
        print(f"{DIM}Cancelled{RESET}")
        return None

    if not session_manager.restore_baseline():
        return None
    print(f"{GREEN}✓ Conversation cleared{RESET}")
    return []


def _checkpoint_session_names(branch_info):
    """Return ``s:<id>`` labels for every ``session_*`` branch in ``git branch`` output."""
    if not branch_info or branch_info.startswith("error"):
        return []
    names = []
    for branch_line in branch_info.split('\n'):
        branch = branch_line.strip().lstrip('* ')  # drop the current-branch marker
        if branch.startswith('session_'):
            # session_20260130_103323_f7 -> s:20260130_103323_f7
            names.append('s:' + branch[len('session_'):])
    return names


def _checkpoint_show_graph(checkpoint_manager):
    """Print the last 20 commits across all checkpoint branches as a git graph.

    Each commit is tagged with up to two session branches that contain it.
    """
    print(f"\n{BOLD}📍 Checkpoint Graph:{RESET}\n")

    # %h = short hash, %s = subject, %ar = relative date
    graph_output = _checkpoint_git(
        checkpoint_manager,
        "log", "--graph", "--all", "--oneline",
        "--format=%h %s (%ar)", "-20"
    )

    if graph_output and not graph_output.startswith("error"):
        for line in graph_output.split('\n'):
            if not line.strip():
                continue

            # The hash is the first hex token after the graph prefix (*, |, /, \)
            match = re.search(r'\b([0-9a-f]{7,40})\b', line)
            if match:
                commit_hash = match.group(1)
                label = f"{CYAN}{commit_hash}{RESET}"
                session_names = _checkpoint_session_names(
                    _checkpoint_git(checkpoint_manager, "branch", "-a", "--contains", commit_hash)
                )
                if session_names:
                    label += f" {GREEN}[{', '.join(session_names[:2])}]{RESET}"
                # Replace only the matched hash, not other occurrences in the subject
                line = line[:match.start(1)] + label + line[match.end(1):]

            print(f"  {line}")
        print()
    else:
        print(f"{DIM}No checkpoints yet{RESET}\n")

    print(f"{DIM}Restore: /c <hash>{RESET}")


def _checkpoint_show_session(checkpoint_manager):
    """Print up to 10 of the current session's checkpoints with age and changed files."""
    checkpoints = checkpoint_manager.list_checkpoints(show_all=False)
    if not checkpoints:
        print(f"{DIM}No checkpoints yet{RESET}")
        return

    print(f"\n{BOLD}📍 Checkpoints:{RESET}\n")

    for commit_hash, message in checkpoints[:10]:  # list order is newest first (git log order)
        timestamp_str = _checkpoint_git(checkpoint_manager, "log", "-1", "--format=%ar", commit_hash)
        if not timestamp_str or timestamp_str.startswith("error"):
            timestamp_str = ""

        # --root so the very first checkpoint (which has no parent) still lists its files
        files_str = _checkpoint_git(
            checkpoint_manager,
            "diff-tree", "--root", "--no-commit-id", "--name-only", "-r", commit_hash
        )
        files = []
        if files_str and not files_str.startswith("error"):
            files = [f.strip() for f in files_str.split('\n') if f.strip()]

        # Format: hash | time ago | message
        time_part = f"{DIM}{timestamp_str}{RESET}" if timestamp_str else ""
        print(f"  {CYAN}{commit_hash}{RESET} {time_part}")
        print(f"  {DIM}└─{RESET} {message}")
        if files:
            files_display = ", ".join(files[:3])
            if len(files) > 3:
                files_display += f" +{len(files)-3} more"
            print(f"  {DIM}   Files: {files_display}{RESET}")
        print()

    print(f"{DIM}Tip: Use '/c all' or '/ca' to see git graph{RESET}")
    print(f"{DIM}Restore: /c <hash>{RESET}")


def _checkpoint_restore(session_manager, checkpoint_id):
    """Confirm, then restore files and conversation to *checkpoint_id*.

    Returns the restored message list, or None if cancelled, failed, or no
    conversation snapshot exists for the checkpoint.
    """
    if not checkpoint_id:
        print(f"{RED}Usage: /c <checkpoint_id>{RESET}")
        return None

    session = session_manager.current_session
    if session is None:
        print(f"{RED}✗ No active session{RESET}")
        return None

    print(f"{YELLOW}⚠ This will restore files AND conversation to checkpoint {checkpoint_id}{RESET}")
    print(f"{YELLOW}⚠ Future checkpoints will be discarded from history{RESET}")
    confirm = input(f"{BOLD}Continue? (y/N): {RESET}").strip().lower()

    if confirm != 'y':
        print(f"{DIM}Cancelled{RESET}")
        return None

    success, conversation_snapshot = session_manager.checkpoint_manager.restore_checkpoint(
        checkpoint_id,
        session.baseline_files
    )
    if not success:
        print(f"{RED}✗ Failed to restore checkpoint{RESET}")
        return None

    print(f"{GREEN}✓ Restored files to checkpoint {checkpoint_id}{RESET}")
    if not conversation_snapshot:
        print(f"{YELLOW}⚠ No conversation snapshot found for this checkpoint{RESET}")
        return None
    print(f"{GREEN}✓ Restored conversation ({len(conversation_snapshot)} messages){RESET}")
    return conversation_snapshot


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
