// upstream_test.go
1 // Copyright 2026 Alibaba Group Holding Ltd. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dnsproxy 16 17 import ( 18 "net" 19 "testing" 20 "time" 21 22 "github.com/miekg/dns" 23 "github.com/stretchr/testify/require" 24 25 "github.com/alibaba/opensandbox/egress/pkg/constants" 26 ) 27 28 func TestUpstreamProbeIntervalFromEnv(t *testing.T) { 29 t.Setenv(constants.EnvDNSUpstreamProbeIntervalSec, "") 30 require.Equal(t, defaultUpstreamProbeInterval, upstreamProbeIntervalFromEnv()) 31 32 t.Setenv(constants.EnvDNSUpstreamProbeIntervalSec, "5") 33 require.Equal(t, 5*time.Second, upstreamProbeIntervalFromEnv()) 34 35 t.Setenv(constants.EnvDNSUpstreamProbeIntervalSec, "not-a-number") 36 require.Equal(t, defaultUpstreamProbeInterval, upstreamProbeIntervalFromEnv()) 37 } 38 39 func TestUpstreamProbeFromEnv(t *testing.T) { 40 t.Setenv(constants.EnvDNSUpstreamProbe, "") 41 n, qt := upstreamProbeFromEnv() 42 require.Equal(t, ".", n) 43 require.Equal(t, dns.TypeNS, qt) 44 45 t.Setenv(constants.EnvDNSUpstreamProbe, ".") 46 n, qt = upstreamProbeFromEnv() 47 require.Equal(t, ".", n) 48 require.Equal(t, dns.TypeNS, qt) 49 50 t.Setenv(constants.EnvDNSUpstreamProbe, "intranet.corp") 51 n, qt = upstreamProbeFromEnv() 52 require.Equal(t, "intranet.corp.", n) 53 require.Equal(t, dns.TypeA, qt) 54 } 55 56 func TestNormalizeEnvUpstreamAddr(t *testing.T) { 57 got, err := normalizeEnvUpstreamAddr("8.8.8.8") 58 require.NoError(t, err) 59 
require.Equal(t, "8.8.8.8:53", got) 60 61 got, err = normalizeEnvUpstreamAddr("1.1.1.1:5353") 62 require.NoError(t, err) 63 require.Equal(t, "1.1.1.1:5353", got) 64 65 got, err = normalizeEnvUpstreamAddr("2001:db8::1") 66 require.NoError(t, err) 67 require.Equal(t, "[2001:db8::1]:53", got) 68 69 got, err = normalizeEnvUpstreamAddr("[2001:db8::2]:853") 70 require.NoError(t, err) 71 require.Equal(t, "[2001:db8::2]:853", got) 72 73 got, err = normalizeEnvUpstreamAddr("[2001:db8::3]") 74 require.NoError(t, err) 75 require.Equal(t, "[2001:db8::3]:53", got) 76 77 _, err = normalizeEnvUpstreamAddr("") 78 require.Error(t, err) 79 80 _, err = normalizeEnvUpstreamAddr("dns.google") 81 require.Error(t, err) 82 83 _, err = normalizeEnvUpstreamAddr("dns.google:53") 84 require.Error(t, err) 85 } 86 87 func TestParseEnvDNSUpstreams(t *testing.T) { 88 got, err := parseEnvDNSUpstreams("8.8.8.8,1.1.1.1") 89 require.NoError(t, err) 90 require.Equal(t, []string{"8.8.8.8:53", "1.1.1.1:53"}, got) 91 92 got, err = parseEnvDNSUpstreams("8.8.8.8") 93 require.NoError(t, err) 94 require.Equal(t, []string{"8.8.8.8:53"}, got) 95 96 got, err = parseEnvDNSUpstreams("8.8.8.8, 8.8.8.8, 1.1.1.1") 97 require.NoError(t, err) 98 require.Equal(t, []string{"8.8.8.8:53", "1.1.1.1:53"}, got) 99 100 _, err = parseEnvDNSUpstreams("dns.google,8.8.8.8") 101 require.Error(t, err) 102 } 103 104 func TestAllowIPsFromUpstreamAddrs(t *testing.T) { 105 ips := AllowIPsFromUpstreamAddrs([]string{"8.8.8.8:53", "1.1.1.1:53", "resolver.example.com:53"}) 106 require.Len(t, ips, 2) 107 require.Equal(t, "8.8.8.8", ips[0].String()) 108 require.Equal(t, "1.1.1.1", ips[1].String()) 109 } 110 111 func TestShouldFailoverAfterResponse(t *testing.T) { 112 p2 := &Proxy{upstreams: []string{"198.51.100.254:53", "8.8.8.8:53"}} 113 114 emptyOK := new(dns.Msg) 115 emptyOK.Rcode = dns.RcodeSuccess 116 try, _ := p2.shouldFailoverAfterResponse(emptyOK) 117 require.False(t, try, "empty NOERROR should not failover") 118 119 try, _ = 
p2.shouldFailoverAfterResponse(emptyOK) 120 require.False(t, try, "empty NOERROR on last upstream should not failover") 121 122 emptyNODATA := new(dns.Msg) 123 emptyNODATA.Rcode = dns.RcodeSuccess 124 emptyNODATA.Ns = []dns.RR{ 125 &dns.SOA{ 126 Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 60}, 127 Ns: "ns1.example.com.", 128 Mbox: "hostmaster.example.com.", 129 Serial: 1, 130 Refresh: 60, 131 Retry: 60, 132 Expire: 60, 133 Minttl: 60, 134 }, 135 } 136 try, _ = p2.shouldFailoverAfterResponse(emptyNODATA) 137 require.False(t, try, "A/AAAA NODATA with authority should not failover") 138 139 withA := new(dns.Msg) 140 withA.Rcode = dns.RcodeSuccess 141 withA.Answer = []dns.RR{&dns.A{Hdr: dns.RR_Header{Name: "x."}, A: net.ParseIP("1.1.1.1")}} 142 try, _ = p2.shouldFailoverAfterResponse(withA) 143 require.False(t, try) 144 145 nx := new(dns.Msg) 146 nx.Rcode = dns.RcodeNameError 147 try, _ = p2.shouldFailoverAfterResponse(nx) 148 require.False(t, try) 149 150 sf := new(dns.Msg) 151 sf.Rcode = dns.RcodeServerFailure 152 try, _ = p2.shouldFailoverAfterResponse(sf) 153 require.True(t, try) 154 }