package controller

import (
	"crypto/rand"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
	"golang.org/x/crypto/curve25519"
	"tailscale-clone/internal/types"
)

// Controller is the coordination server: it assigns overlay addresses,
// tracks registered nodes and their WebSocket connections, and publishes
// subnet routes. Construct it with New.
type Controller struct {
	port    int    // API and WebSocket listener port
	webPort int    // static web UI listener port
	dataDir string // stored by New; not read elsewhere in this file

	// network is built once in New and only read afterwards in this file.
	network *types.Network

	// mu guards nodes, clients and subnetRoutes; cleanupRoutine also reads
	// each client's LastPing while holding it.
	mu           sync.RWMutex
	nodes        map[string]*types.Node
	clients      map[string]*ClientConnection
	subnetRoutes map[string]*types.SubnetRoute // nodeID -> subnet route

	upgrader  websocket.Upgrader
	server    *http.Server // assigned inside startAPIServer's goroutine
	webServer *http.Server // assigned inside startWebServer's goroutine
}

// ClientConnection is one client WebSocket session: the connection plus the
// outbound queue that writePump drains. Values are created by handleWebSocket.
type ClientConnection struct {
	NodeID   string          // empty until handleRegister assigns an ID
	Conn     *websocket.Conn // frames read by readPump; data frames written by writePump
	Send     chan []byte     // JSON-encoded outbound messages consumed by writePump
	LastPing time.Time       // set by handlePing; cleanupRoutine reads it under Controller.mu
}

// New creates a controller that will listen on port (API and WebSocket) and
// webPort (static web UI) once started. It generates a fresh X25519 key pair
// for the network; the only error it returns comes from key generation and
// wraps the underlying error so callers can inspect it with errors.Is/As.
func New(port int, dataDir string, webPort int) (*Controller, error) {
	privateKey, publicKey, err := generateKeyPair()
	if err != nil {
		return nil, fmt.Errorf("failed to generate keys: %w", err)
	}

	// One timestamp so CreatedAt and UpdatedAt are identical for a new network.
	now := time.Now()
	network := &types.Network{
		ID:         types.GenerateID(),
		Name:       "Tailscale Clone Network",
		CIDR:       "10.0.0.0/24",
		PrivateKey: privateKey,
		PublicKey:  publicKey,
		CreatedAt:  now,
		UpdatedAt:  now,
	}

	c := &Controller{
		port:         port,
		dataDir:      dataDir,
		webPort:      webPort,
		network:      network,
		nodes:        make(map[string]*types.Node),
		clients:      make(map[string]*ClientConnection),
		subnetRoutes: make(map[string]*types.SubnetRoute),
		upgrader: websocket.Upgrader{
			// NOTE(review): accepting every Origin lets any web page open a
			// WebSocket to the controller from a visitor's browser. Restrict
			// this before exposing the controller beyond a trusted network.
			CheckOrigin: func(r *http.Request) bool {
				return true
			},
		},
	}

	return c, nil
}

// Start launches the API server, the web UI server and the stale-client
// cleanup loop, each on its own goroutine, and returns immediately. It always
// returns nil; listener failures are only logged by the server goroutines.
func (c *Controller) Start() error {
	background := []func(){
		c.startAPIServer,
		c.startWebServer,
		c.cleanupRoutine,
	}
	for _, run := range background {
		go run()
	}
	return nil
}

// Stop closes the API and web servers immediately via http.Server.Close (not a
// graceful Shutdown). Servers that have not been created yet are skipped. It
// does not stop cleanupRoutine.
//
// NOTE(review): c.server and c.webServer are assigned inside the goroutines
// started by Start without synchronization, so this read is racy and a Stop
// issued right after Start may miss a server that is about to be assigned.
func (c *Controller) Stop() {
	for _, srv := range []*http.Server{c.server, c.webServer} {
		if srv == nil {
			continue
		}
		srv.Close()
	}
}

// startAPIServer serves the WebSocket endpoint (/ws) and the REST API on
// c.port. It blocks until the server stops and is meant to run on its own
// goroutine; errors other than http.ErrServerClosed are logged.
func (c *Controller) startAPIServer() {
	router := mux.NewRouter()

	// WebSocket endpoint for client connections.
	router.HandleFunc("/ws", c.handleWebSocket)

	// REST API endpoints.
	router.HandleFunc("/api/nodes", c.handleGetNodes).Methods("GET")
	router.HandleFunc("/api/network", c.handleGetNetwork).Methods("GET")
	router.HandleFunc("/api/nodes/{id}", c.handleGetNode).Methods("GET")
	router.HandleFunc("/api/nodes/{id}", c.handleDeleteNode).Methods("DELETE")
	router.HandleFunc("/api/subnets", c.handleGetSubnets).Methods("GET")

	// NOTE(review): c.server is assigned on this goroutine but read by Stop
	// without synchronization.
	c.server = &http.Server{
		Addr:    fmt.Sprintf(":%d", c.port),
		Handler: router,
		// net/http sets no timeouts by default, so a client that trickles its
		// request headers could hold a connection open forever. This only
		// bounds the header phase, not established WebSocket sessions.
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Printf("API server starting on port %d", c.port)
	if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Printf("API server error: %v", err)
	}
}

// startWebServer serves static files from the "web" directory (relative to
// the process working directory) on c.webPort. It blocks until the server
// stops and is meant to run on its own goroutine; errors other than
// http.ErrServerClosed are logged.
func (c *Controller) startWebServer() {
	router := mux.NewRouter()

	// Serve static files.
	router.PathPrefix("/").Handler(http.FileServer(http.Dir("web")))

	// NOTE(review): c.webServer is assigned on this goroutine but read by Stop
	// without synchronization.
	c.webServer = &http.Server{
		Addr:    fmt.Sprintf(":%d", c.webPort),
		Handler: router,
		// Bound the request-header phase so slow or idle clients cannot hold
		// connections open indefinitely (net/http sets no timeout by default).
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Printf("Web server starting on port %d", c.webPort)
	if err := c.webServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Printf("Web server error: %v", err)
	}
}

// handleWebSocket upgrades the HTTP request to a WebSocket and starts the
// per-connection reader and writer goroutines. The new client has no NodeID
// until it sends a register message.
func (c *Controller) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	ws, err := c.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket upgrade failed: %v", err)
		return
	}

	cc := new(ClientConnection)
	cc.Conn = ws
	cc.Send = make(chan []byte, 256)
	cc.LastPing = time.Now()

	go c.readPump(cc)
	go c.writePump(cc)
}

// readPump reads frames from the client until the connection fails, handing
// each one to handleMessage in arrival order. Unexpected close errors are
// logged. On exit it unregisters the client and closes the connection.
func (c *Controller) readPump(client *ClientConnection) {
	defer func() {
		c.unregisterClient(client)
		client.Conn.Close()
	}()

	for {
		_, frame, err := client.Conn.ReadMessage()
		if err == nil {
			c.handleMessage(client, frame)
			continue
		}
		if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
			log.Printf("WebSocket read error: %v", err)
		}
		return
	}
}

// writePump is the goroutine that writes data frames to client.Conn. It
// forwards queued messages from client.Send and emits a text "ping" frame
// every 54 seconds. It returns, closing the connection, on the first write
// failure or when client.Send is closed.
func (c *Controller) writePump(client *ClientConnection) {
	// writeWait bounds every write so a peer that stops reading cannot block
	// this goroutine (and keep its socket open) forever.
	const writeWait = 10 * time.Second

	ticker := time.NewTicker(54 * time.Second)
	defer func() {
		ticker.Stop()
		client.Conn.Close()
	}()

	for {
		select {
		case message, ok := <-client.Send:
			if err := client.Conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
				return
			}
			if !ok {
				// Best-effort close frame; the deferred Close runs regardless.
				client.Conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			w, err := client.Conn.NextWriter(websocket.TextMessage)
			if err != nil {
				return
			}
			if _, err := w.Write(message); err != nil {
				return
			}
			if err := w.Close(); err != nil {
				return
			}
		case <-ticker.C:
			if err := client.Conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
				return
			}
			// This is an application-level text frame, not a WebSocket control
			// ping; presumably clients expect this exact payload — confirm
			// before changing it.
			if err := client.Conn.WriteMessage(websocket.TextMessage, []byte("ping")); err != nil {
				return
			}
		}
	}
}

// handleMessage decodes one JSON envelope received from client and dispatches
// it by message type. Frames that are not valid JSON envelopes are logged and
// dropped; unrecognized message types are silently ignored.
func (c *Controller) handleMessage(client *ClientConnection, message []byte) {
	var envelope types.Message
	err := json.Unmarshal(message, &envelope)
	if err != nil {
		log.Printf("Failed to unmarshal message: %v", err)
		return
	}

	payload := envelope.Payload
	switch envelope.Type {
	case types.MessageTypeRegister:
		c.handleRegister(client, payload)
	case types.MessageTypeUpdate:
		c.handleUpdate(client, payload)
	case types.MessageTypePing:
		c.handlePing(client)
	case types.MessageTypeSubnetAdvert:
		c.handleSubnetAdvert(client, payload)
	}
}

// handleRegister registers the sending connection as a new node. It assigns a
// fresh node ID and overlay IP, records any advertised subnets, and replies
// with a MessageTypeRegister message carrying the assignment, the current
// peer list and all active subnet routes.
//
// A payload that does not decode into a RegisterRequest, or an exhausted
// address pool, is logged and produces no reply.
func (c *Controller) handleRegister(client *ClientConnection, payload interface{}) {
	var req types.RegisterRequest
	data, err := json.Marshal(payload)
	if err == nil {
		err = json.Unmarshal(data, &req)
	}
	if err != nil {
		log.Printf("Invalid register payload: %v", err)
		return
	}

	nodeID := types.GenerateID()

	c.mu.Lock()
	// Choose and claim the address under one lock hold. Allocating first and
	// inserting later would let two concurrent registrations get the same IP.
	ip := c.allocateIPLocked()
	if ip == nil {
		c.mu.Unlock()
		log.Printf("Cannot register node %q: no free addresses in %s", req.NodeName, c.network.CIDR)
		return
	}

	// A connection that registers again would otherwise leave its previous
	// node marked online forever, because unregisterClient only sees the
	// connection's newest NodeID.
	if prev := client.NodeID; prev != "" {
		if old, ok := c.nodes[prev]; ok {
			old.IsOnline = false
		}
		delete(c.clients, prev)
	}

	c.nodes[nodeID] = &types.Node{
		ID:             nodeID,
		Name:           req.NodeName,
		PublicKey:      req.PublicKey,
		IP:             ip,
		LastSeen:       time.Now(),
		IsOnline:       true,
		IsSubnetRouter: req.IsSubnetRouter,
		Subnets:        req.Subnets,
	}
	client.NodeID = nodeID
	c.clients[nodeID] = client

	// Handle subnet routing.
	if req.IsSubnetRouter && len(req.Subnets) > 0 {
		c.subnetRoutes[nodeID] = &types.SubnetRoute{
			NodeID:   nodeID,
			NodeName: req.NodeName,
			Subnets:  req.Subnets,
			Active:   true,
		}
		log.Printf("Subnet router registered: %s advertising %v", req.NodeName, req.Subnets)
	}
	c.mu.Unlock()

	// getPeersForNode and getAllSubnetRoutes take c.mu themselves, so they
	// must run after the unlock above.
	response := types.RegisterResponse{
		NodeID:       nodeID,
		IP:           ip.String(),
		Network:      c.network.CIDR,
		Peers:        c.getPeersForNode(nodeID),
		SubnetRoutes: c.getAllSubnetRoutes(),
	}

	c.sendToClient(client, types.Message{
		Type:    types.MessageTypeRegister,
		Payload: response,
	})

	log.Printf("Node registered: %s (%s) at %s", req.NodeName, nodeID, ip)
}

// allocateIPLocked returns the lowest unused address in 10.0.0.2–10.0.0.254,
// or nil when every one is assigned. The caller must hold c.mu for writing
// and claim the address before releasing the lock.
func (c *Controller) allocateIPLocked() net.IP {
	used := make(map[string]struct{}, len(c.nodes))
	for _, node := range c.nodes {
		used[node.IP.String()] = struct{}{}
	}
	for i := 2; i <= 254; i++ {
		ip := net.IPv4(10, 0, 0, byte(i))
		if _, taken := used[ip.String()]; !taken {
			return ip
		}
	}
	return nil
}

// handleUpdate refreshes the sending node's LastSeen and IsOnline, then
// replies with a MessageTypePeers message listing that node's current peers.
//
// The node is identified by the connection's registered ID. A payload that
// names a different node is rejected, so one client cannot mark other nodes
// online or read their peer lists. An empty NodeID in the payload is treated
// as the connection's own. Unregistered connections and undecodable payloads
// are logged and get no reply.
func (c *Controller) handleUpdate(client *ClientConnection, payload interface{}) {
	var req types.UpdateRequest
	data, err := json.Marshal(payload)
	if err == nil {
		err = json.Unmarshal(data, &req)
	}
	if err != nil {
		log.Printf("Invalid update payload: %v", err)
		return
	}

	// client.NodeID is only written by handleRegister, which runs on this same
	// readPump goroutine, so reading it here needs no lock.
	nodeID := client.NodeID
	if nodeID == "" {
		log.Printf("Ignoring update from unregistered connection")
		return
	}
	if req.NodeID != "" && req.NodeID != nodeID {
		log.Printf("Ignoring update for node %q from connection registered as %s", req.NodeID, nodeID)
		return
	}

	c.mu.Lock()
	if node, exists := c.nodes[nodeID]; exists {
		node.LastSeen = time.Now()
		node.IsOnline = true
	}
	c.mu.Unlock()

	// Send updated peers; getPeersForNode takes c.mu itself.
	c.sendToClient(client, types.Message{
		Type:    types.MessageTypePeers,
		Payload: c.getPeersForNode(nodeID),
	})
}

// handlePing records when the client last pinged and answers with a
// MessageTypePong message carrying an empty object payload.
func (c *Controller) handlePing(client *ClientConnection) {
	// cleanupRoutine reads LastPing from another goroutine while holding c.mu,
	// so the write must happen under the same lock to avoid a data race.
	c.mu.Lock()
	client.LastPing = time.Now()
	c.mu.Unlock()

	c.sendToClient(client, types.Message{
		Type:    types.MessageTypePong,
		Payload: map[string]interface{}{},
	})
}

// handleSubnetAdvert marks the sending node as a subnet router for the
// advertised subnets, replaces its subnet route, and broadcasts the new route
// set to every connected client.
//
// Only a connection's own node may advertise routes. A payload naming another
// node is rejected, so one client cannot rewrite the routes attributed to
// other nodes. An empty NodeID in the payload is treated as the connection's
// own. Nothing is broadcast unless a node was actually updated.
func (c *Controller) handleSubnetAdvert(client *ClientConnection, payload interface{}) {
	var req types.SubnetAdvertRequest
	data, err := json.Marshal(payload)
	if err == nil {
		err = json.Unmarshal(data, &req)
	}
	if err != nil {
		log.Printf("Invalid subnet advertisement payload: %v", err)
		return
	}

	// client.NodeID is only written on this readPump goroutine (handleRegister).
	nodeID := client.NodeID
	if nodeID == "" || (req.NodeID != "" && req.NodeID != nodeID) {
		log.Printf("Ignoring subnet advertisement for node %q from connection registered as %q", req.NodeID, nodeID)
		return
	}

	c.mu.Lock()
	node, exists := c.nodes[nodeID]
	if exists {
		node.IsSubnetRouter = true
		node.Subnets = req.Subnets

		// Replace (or create) the route record for this node.
		c.subnetRoutes[nodeID] = &types.SubnetRoute{
			NodeID:   nodeID,
			NodeName: node.Name,
			Subnets:  req.Subnets,
			Active:   true,
		}
		log.Printf("Subnet advertisement from %s: %v", node.Name, req.Subnets)
	}
	c.mu.Unlock()

	// broadcastSubnetUpdate takes c.mu itself, so call it after unlocking.
	if exists {
		c.broadcastSubnetUpdate()
	}
}

// getAllSubnetRoutes returns value copies of every subnet route whose Active
// flag is set, in unspecified (map iteration) order, or nil when there are
// none. It takes c.mu for reading, so callers must not already hold c.mu.
func (c *Controller) getAllSubnetRoutes() []types.SubnetRoute {
	c.mu.RLock()
	defer c.mu.RUnlock()

	var active []types.SubnetRoute
	for _, r := range c.subnetRoutes {
		if !r.Active {
			continue
		}
		active = append(active, *r)
	}
	return active
}

// broadcastSubnetUpdate sends the current active subnet routes to every
// connected client as a MessageTypeSubnetUpdate message. Delivery is
// best-effort and non-blocking (see sendToClient). It takes c.mu itself, so
// callers must not hold c.mu.
func (c *Controller) broadcastSubnetUpdate() {
	update := types.SubnetUpdateRequest{
		SubnetRoutes: c.getAllSubnetRoutes(),
	}
	msg := types.Message{
		Type:    types.MessageTypeSubnetUpdate,
		Payload: update,
	}

	// Snapshot the recipients and release the lock before sending.
	// sendToClient JSON-encodes the message once per client, and doing that
	// while holding c.mu would stall registrations and updates that need the
	// write lock.
	c.mu.RLock()
	recipients := make([]*ClientConnection, 0, len(c.clients))
	for _, client := range c.clients {
		recipients = append(recipients, client)
	}
	c.mu.RUnlock()

	for _, client := range recipients {
		c.sendToClient(client, msg)
	}
}

// sendToClient JSON-encodes msg and queues it on client.Send without
// blocking. Delivery is best-effort. If the client's queue is full, the
// message is dropped and the client's connection is closed, which makes
// readPump fail and unregister the client. Marshal failures are logged and
// the message is dropped.
func (c *Controller) sendToClient(client *ClientConnection, msg types.Message) {
	data, err := json.Marshal(msg)
	if err != nil {
		log.Printf("Failed to marshal message: %v", err)
		return
	}

	select {
	case client.Send <- data:
	default:
		// Never close client.Send here. Other goroutines (for example
		// broadcastSubnetUpdate) may still send on it, and sending on or
		// re-closing a closed channel panics, which would take down the whole
		// controller. Closing the socket is safe to call concurrently and makes
		// the client's pending or next write fail.
		log.Printf("Send buffer full for client %v; closing connection", client.Conn.RemoteAddr())
		client.Conn.Close()
	}
}

// unregisterClient forgets the client's connection and marks its node
// offline. Connections that never registered are ignored. The node record and
// any subnet route stay in place.
func (c *Controller) unregisterClient(client *ClientConnection) {
	id := client.NodeID
	if id == "" {
		return
	}

	c.mu.Lock()
	delete(c.clients, id)
	if node, ok := c.nodes[id]; ok {
		node.IsOnline = false
	}
	c.mu.Unlock()

	log.Printf("Client disconnected: %s", id)
}

// allocateIP returns the lowest overlay address in 10.0.0.2–10.0.0.254 that no
// registered node is using, or nil when every address in that range is
// assigned. Callers must check for nil.
//
// It takes c.mu itself and releases it before returning, so the address is
// not reserved. A caller that inserts the node afterwards can race with
// another goroutine that picks the same address.
func (c *Controller) allocateIP() net.IP {
	c.mu.RLock()
	defer c.mu.RUnlock()

	// Index assigned addresses once, instead of rescanning every node for each
	// of the 253 candidates.
	used := make(map[string]struct{}, len(c.nodes))
	for _, node := range c.nodes {
		used[node.IP.String()] = struct{}{}
	}

	for i := 2; i <= 254; i++ {
		ip := net.IPv4(10, 0, 0, byte(i))
		if _, taken := used[ip.String()]; !taken {
			return ip
		}
	}

	// Pool exhausted. Returning any address here would hand one that is
	// already in use to a second node, so signal exhaustion instead.
	return nil
}

// getPeersForNode builds the peer list for nodeID: every other node that is
// currently online. Each peer's AllowedIPs is its overlay address as a /32,
// followed by its advertised subnets if it is a subnet router. KeepAlive is
// always 25 (presumably seconds, WireGuard-style — confirm with the client).
// It returns nil when there are no such peers and takes c.mu for reading.
func (c *Controller) getPeersForNode(nodeID string) []types.Peer {
	c.mu.RLock()
	defer c.mu.RUnlock()

	var peers []types.Peer
	for id, n := range c.nodes {
		if id == nodeID || !n.IsOnline {
			continue
		}

		allowed := []string{n.IP.String() + "/32"}
		if n.IsSubnetRouter && len(n.Subnets) > 0 {
			allowed = append(allowed, n.Subnets...)
		}

		peers = append(peers, types.Peer{
			PublicKey:      n.PublicKey,
			AllowedIPs:     allowed,
			KeepAlive:      25,
			IsSubnetRouter: n.IsSubnetRouter,
			Subnets:        n.Subnets,
		})
	}
	return peers
}

// cleanupRoutine runs forever. Every 30 seconds it evicts clients whose last
// ping is more than 2 minutes old: it closes the connection, removes the
// client entry and marks the node offline.
//
// NOTE(review): nothing signals this loop to stop; Stop does not end it.
func (c *Controller) cleanupRoutine() {
	const (
		sweepInterval = 30 * time.Second
		staleAfter    = 2 * time.Minute
	)

	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()

	for range ticker.C {
		c.mu.Lock()
		now := time.Now()
		for id, cl := range c.clients {
			if now.Sub(cl.LastPing) <= staleAfter {
				continue
			}
			log.Printf("Removing stale client: %s", id)
			cl.Conn.Close()
			delete(c.clients, id)
			if n, ok := c.nodes[id]; ok {
				n.IsOnline = false
			}
		}
		c.mu.Unlock()
	}
}

// HTTP handlers.

// handleGetNodes responds with a JSON array of every registered node, online
// or not, in unspecified order.
func (c *Controller) handleGetNodes(w http.ResponseWriter, r *http.Request) {
	// Copy the nodes while holding the lock. LastSeen and IsOnline are mutated
	// under c.mu by other goroutines, so encoding the shared pointers after
	// RUnlock would race with those writes.
	c.mu.RLock()
	nodes := make([]types.Node, 0, len(c.nodes))
	for _, node := range c.nodes {
		nodes = append(nodes, *node)
	}
	c.mu.RUnlock()

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(nodes); err != nil {
		log.Printf("Failed to encode nodes: %v", err)
	}
}

// handleGetNetwork responds with the network description as JSON, with the
// private key blanked out.
func (c *Controller) handleGetNetwork(w http.ResponseWriter, r *http.Request) {
	// The API router applies no authentication, so never serve the network's
	// private key. Encoding a copy leaves c.network untouched. If
	// types.Network already tags PrivateKey with `json:"-"`, this is redundant
	// but harmless.
	network := *c.network
	network.PrivateKey = ""

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&network); err != nil {
		log.Printf("Failed to encode network: %v", err)
	}
}

// handleGetNode responds with the node named by the {id} path variable as
// JSON, or 404 if no such node is registered.
func (c *Controller) handleGetNode(w http.ResponseWriter, r *http.Request) {
	nodeID := mux.Vars(r)["id"]

	// Copy the node under the lock. Encoding the shared pointer after RUnlock
	// would race with goroutines that update LastSeen and IsOnline under c.mu.
	c.mu.RLock()
	node, exists := c.nodes[nodeID]
	var snapshot types.Node
	if exists {
		snapshot = *node
	}
	c.mu.RUnlock()

	if !exists {
		http.NotFound(w, r)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&snapshot); err != nil {
		log.Printf("Failed to encode node %q: %v", nodeID, err)
	}
}

// handleDeleteNode removes the node named by the {id} path variable, along
// with its subnet route and, if connected, its client connection. It responds
// 204 No Content whether or not the node existed. If a subnet route was
// removed, the remaining clients are sent the updated route list, just as
// handleSubnetAdvert does when a route is added.
func (c *Controller) handleDeleteNode(w http.ResponseWriter, r *http.Request) {
	nodeID := mux.Vars(r)["id"]

	c.mu.Lock()
	if client, exists := c.clients[nodeID]; exists {
		// Closing the socket makes the client's readPump exit.
		client.Conn.Close()
		delete(c.clients, nodeID)
	}
	delete(c.nodes, nodeID)
	_, hadRoute := c.subnetRoutes[nodeID]
	delete(c.subnetRoutes, nodeID)
	c.mu.Unlock()

	// broadcastSubnetUpdate takes c.mu itself, so call it after unlocking.
	if hadRoute {
		c.broadcastSubnetUpdate()
	}

	w.WriteHeader(http.StatusNoContent)
}

// handleGetSubnets responds with a JSON array of every recorded subnet route,
// active or not, in unspecified order.
func (c *Controller) handleGetSubnets(w http.ResponseWriter, r *http.Request) {
	c.mu.RLock()
	routes := make([]*types.SubnetRoute, 0, len(c.subnetRoutes))
	for _, sr := range c.subnetRoutes {
		routes = append(routes, sr)
	}
	c.mu.RUnlock()

	enc := json.NewEncoder(w)
	w.Header().Set("Content-Type", "application/json")
	enc.Encode(routes)
}

// generateKeyPair creates a random X25519 key pair. It returns the private key
// and the public key (the private scalar multiplied by the base point) as
// lowercase hex strings, plus any error from the random source or from
// curve25519.X25519.
func generateKeyPair() (string, string, error) {
	var priv [curve25519.ScalarSize]byte
	if _, err := rand.Read(priv[:]); err != nil {
		return "", "", err
	}

	pub, err := curve25519.X25519(priv[:], curve25519.Basepoint)
	if err != nil {
		return "", "", err
	}

	return fmt.Sprintf("%x", priv[:]), fmt.Sprintf("%x", pub), nil
}