core/service/address.go

207 lines
4.8 KiB
Go
Raw Normal View History

2025-09-23 13:09:27 +08:00
package service
import (
"fmt"
"net"
"net/url"
"strings"
)
type NetworkAddress struct {
Protocol string // tcp, tcp4, tcp6, unix, unixpacket
Host string // IP 地址或主机名
Port string // 端口号
Path string // Unix socket 路径
Raw string // 原始字符串
}
// 解析网络地址字符串
// "tcp://0.0.0.0:1212",
//
// "tcp4://127.0.0.1:8080",
// "tcp6://[::1]:8080",
// "unix:///data/app/passport.sock",
// "unixpacket:///tmp/mysql.sock",
// ":8080", // 传统格式
// "/tmp/server.sock", // 传统Unix格式
// "invalid://address", // 错误格式
func ParseNetworkAddress(addr string) (*NetworkAddress, error) {
// 如果包含 ://,按 URL 解析
if strings.Contains(addr, "://") {
return parseURLStyle(addr)
}
// 否则按传统格式解析
return parseTraditionalStyle(addr)
}
// 解析 tcp://0.0.0.0:1212 或 unix:///path/to/socket 格式
func parseURLStyle(addr string) (*NetworkAddress, error) {
u, err := url.Parse(addr)
if err != nil {
return nil, fmt.Errorf("解析URL失败: %w", err)
}
result := &NetworkAddress{
Protocol: u.Scheme,
Raw: addr,
}
switch u.Scheme {
case "tcp", "tcp4", "tcp6":
return parseTCPURL(u, result)
case "unix", "unixpacket":
return parseUnixURL(u, result)
default:
return nil, fmt.Errorf("不支持的协议: %s", u.Scheme)
}
}
// 解析 TCP 类型的 URL
func parseTCPURL(u *url.URL, result *NetworkAddress) (*NetworkAddress, error) {
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// 如果没有端口,尝试添加默认端口
if strings.Contains(err.Error(), "missing port") {
host = u.Host
port = "0" // 默认端口
} else {
return nil, fmt.Errorf("解析TCP地址失败: %w", err)
}
}
result.Host = host
result.Port = port
// 根据主机地址确定具体的协议类型
if result.Protocol == "tcp" {
result.Protocol = determineTCPProtocol(host)
}
return result, nil
}
// 解析 Unix socket 类型的 URL
func parseUnixURL(u *url.URL, result *NetworkAddress) (*NetworkAddress, error) {
// Unix socket 路径在 URL 的 Path 字段
if u.Path == "" {
return nil, fmt.Errorf("Unix socket 路径不能为空")
}
result.Path = u.Path
// 如果协议是 unix但路径表明需要数据包传输可以自动升级
if result.Protocol == "unix" && strings.Contains(u.Path, "packet") {
result.Protocol = "unixpacket"
}
return result, nil
}
// 根据主机地址确定 TCP 协议类型
func determineTCPProtocol(host string) string {
if host == "" {
return "tcp" // 默认
}
// 解析 IP 地址
ip := net.ParseIP(host)
if ip != nil {
if ip.To4() != nil {
return "tcp4"
}
return "tcp6"
}
// 如果是特殊地址
switch host {
case "0.0.0.0", "127.0.0.1", "localhost":
return "tcp4"
case "::", "::1":
return "tcp6"
default:
return "tcp" // 默认支持双栈
}
}
// 解析传统格式如 ":8080", "127.0.0.1:8080", "/tmp/socket"
func parseTraditionalStyle(addr string) (*NetworkAddress, error) {
result := &NetworkAddress{Raw: addr}
// 检查是否是 Unix socket包含路径分隔符
if strings.Contains(addr, "/") || strings.HasPrefix(addr, "@/") {
result.Protocol = "unix"
result.Path = addr
return result, nil
}
// 否则按 TCP 地址解析
result.Protocol = "tcp"
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("解析地址失败: %w", err)
}
result.Host = host
result.Port = port
result.Protocol = determineTCPProtocol(host)
return result, nil
}
// 获取网络类型用于 net.Dial 或 net.Listen
func (na *NetworkAddress) Network() string {
return na.Protocol
}
// 获取地址字符串用于 net.Dial 或 net.Listen
func (na *NetworkAddress) Address() string {
switch na.Protocol {
case "tcp", "tcp4", "tcp6":
if na.Port == "" {
return na.Host
}
return net.JoinHostPort(na.Host, na.Port)
case "unix", "unixpacket":
return na.Path
default:
return na.Raw
}
}
// 格式化输出
func (na *NetworkAddress) String() string {
switch na.Protocol {
case "tcp", "tcp4", "tcp6":
return fmt.Sprintf("%s://%s", na.Protocol, net.JoinHostPort(na.Host, na.Port))
case "unix", "unixpacket":
return fmt.Sprintf("%s://%s", na.Protocol, na.Path)
default:
return na.Raw
}
}
// 验证地址是否有效
func (na *NetworkAddress) Validate() error {
switch na.Protocol {
case "tcp", "tcp4", "tcp6":
if na.Host == "" && na.Port == "" {
return fmt.Errorf("TCP地址需要主机和端口")
}
// 验证端口
if na.Port != "" {
if _, err := net.LookupPort("tcp", na.Port); err != nil {
return fmt.Errorf("无效的端口: %s", na.Port)
}
}
case "unix", "unixpacket":
if na.Path == "" {
return fmt.Errorf("Unix socket路径不能为空")
}
default:
return fmt.Errorf("不支持的协议: %s", na.Protocol)
}
return nil
}