This commit is contained in:
hailin 2025-06-12 14:35:31 +08:00
parent f085a55c9a
commit d395130c61
1 changed files with 57 additions and 34 deletions

View File

@ -5,7 +5,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"net" "net"
"os"
"os/exec" "os/exec"
"runtime" "runtime"
"strings" "strings"
@ -19,6 +18,7 @@ type BroadcastMessage struct {
Name string `json:"name"` Name string `json:"name"`
} }
// 获取默认网关 IP
func getGatewayIP() (string, error) { func getGatewayIP() (string, error) {
var cmd *exec.Cmd var cmd *exec.Cmd
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
@ -28,7 +28,6 @@ func getGatewayIP() (string, error) {
} else { } else {
return "", fmt.Errorf("unsupported OS") return "", fmt.Errorf("unsupported OS")
} }
out, err := cmd.Output() out, err := cmd.Output()
if err != nil { if err != nil {
return "", err return "", err
@ -36,36 +35,33 @@ func getGatewayIP() (string, error) {
return strings.TrimSpace(string(out)), nil return strings.TrimSpace(string(out)), nil
} }
// 判断两个 IP 是否在同一子网内
func ipInSameSubnet(ip1 net.IP, ip2 net.IP, mask net.IPMask) bool { func ipInSameSubnet(ip1 net.IP, ip2 net.IP, mask net.IPMask) bool {
if mask == nil { if mask == nil {
return false return false
} }
net1 := ip1.Mask(mask) return ip1.Mask(mask).Equal(ip2.Mask(mask))
net2 := ip2.Mask(mask)
return net1.Equal(net2)
} }
// 获取本地与网关同网段的 IP 地址
func getLocalIP() string { func getLocalIP() string {
gatewayStr, err := getGatewayIP() gatewayStr, err := getGatewayIP()
if err != nil { if err != nil {
fmt.Println("Failed to get gateway IP:", err)
return "127.0.0.1" return "127.0.0.1"
} }
gateway := net.ParseIP(gatewayStr) gateway := net.ParseIP(gatewayStr)
if gateway == nil { if gateway == nil {
fmt.Println("Invalid gateway IP")
return "127.0.0.1" return "127.0.0.1"
} }
interfaces, err := net.Interfaces() interfaces, err := net.Interfaces()
if err != nil { if err != nil {
fmt.Println("Failed to get network interfaces:", err)
return "127.0.0.1" return "127.0.0.1"
} }
for _, iface := range interfaces { for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 { if iface.Flags&net.FlagUp == 0 {
continue // interface down continue
} }
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
@ -74,7 +70,6 @@ func getLocalIP() string {
for _, addr := range addrs { for _, addr := range addrs {
var ip net.IP var ip net.IP
var mask net.IPMask var mask net.IPMask
switch v := addr.(type) { switch v := addr.(type) {
case *net.IPNet: case *net.IPNet:
ip = v.IP ip = v.IP
@ -82,11 +77,9 @@ func getLocalIP() string {
case *net.IPAddr: case *net.IPAddr:
ip = v.IP ip = v.IP
} }
if ip == nil || ip.IsLoopback() || ip.To4() == nil { if ip == nil || ip.IsLoopback() || ip.To4() == nil {
continue continue
} }
if ipInSameSubnet(ip, gateway, mask) { if ipInSameSubnet(ip, gateway, mask) {
return ip.String() return ip.String()
} }
@ -95,46 +88,76 @@ func getLocalIP() string {
return "127.0.0.1" return "127.0.0.1"
} }
// 判断 IP 是否有效(用于判断网络是否连通)
func isValidIP(ip string) bool {
if ip == "" || ip == "127.0.0.1" {
return false
}
parsed := net.ParseIP(ip)
if parsed == nil || parsed.IsLoopback() || parsed.IsUnspecified() || parsed.IsLinkLocalUnicast() || parsed.IsLinkLocalMulticast() {
return false
}
return true
}
func main() { func main() {
// 参数 // 命令行参数
var ( var (
port = flag.Int("port", 8000, "Device service port") port = flag.Int("port", 9876, "Device service port")
interval = flag.Int("interval", 2, "Broadcast interval in seconds") interval = flag.Int("interval", 2, "Broadcast interval in seconds")
name = flag.String("name", "My-AI-Server", "Device name") name = flag.String("name", "PlugAI Server", "T1")
) )
flag.Parse() flag.Parse()
ip := getLocalIP() var (
if ip == "127.0.0.1" { prevIP string
fmt.Println("Warning: failed to find suitable IP address, using loopback") conn *net.UDPConn
} )
addr := net.UDPAddr{
IP: net.IPv4bcast,
Port: 9876,
}
conn, err := net.DialUDP("udp4", nil, &addr)
if err != nil {
fmt.Println("Failed to dial UDP broadcast:", err)
os.Exit(1)
}
defer conn.Close()
for { for {
currentIP := getLocalIP()
if !isValidIP(currentIP) {
fmt.Println("网络未就绪,等待中...")
time.Sleep(time.Duration(*interval) * time.Second)
continue
}
// 如果 IP 改变或连接尚未建立,则重建 UDP 连接
if currentIP != prevIP || conn == nil {
if conn != nil {
conn.Close()
}
addr := net.UDPAddr{
IP: net.IPv4bcast,
Port: 9876,
}
newConn, err := net.DialUDP("udp4", nil, &addr)
if err != nil {
fmt.Println("UDP 连接失败:", err)
time.Sleep(time.Duration(*interval) * time.Second)
continue
}
conn = newConn
prevIP = currentIP
fmt.Println("绑定新 IP 广播:", currentIP)
}
// 构建并发送广播消息
msg := BroadcastMessage{ msg := BroadcastMessage{
Type: "ai_server_announce", Type: "ai_server_announce",
IP: ip, IP: currentIP,
Port: *port, Port: *port,
Name: *name, Name: *name,
} }
data, _ := json.Marshal(msg) data, _ := json.Marshal(msg)
_, err := conn.Write(data) _, err := conn.Write(data)
if err != nil { if err != nil {
fmt.Println("Broadcast failed:", err) fmt.Println("广播失败:", err)
} else { } else {
fmt.Printf("Broadcasted: %s:%d (%s)\n", ip, *port, *name) fmt.Printf("广播中: %s:%d (%s)\n", currentIP, *port, *name)
} }
time.Sleep(time.Duration(*interval) * time.Second) time.Sleep(time.Duration(*interval) * time.Second)
} }
} }