/ components / egress / pkg / dnsproxy / upstream_test.go
upstream_test.go
  1  // Copyright 2026 Alibaba Group Holding Ltd.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  
 15  package dnsproxy
 16  
 17  import (
 18  	"net"
 19  	"testing"
 20  	"time"
 21  
 22  	"github.com/miekg/dns"
 23  	"github.com/stretchr/testify/require"
 24  
 25  	"github.com/alibaba/opensandbox/egress/pkg/constants"
 26  )
 27  
 28  func TestUpstreamProbeIntervalFromEnv(t *testing.T) {
 29  	t.Setenv(constants.EnvDNSUpstreamProbeIntervalSec, "")
 30  	require.Equal(t, defaultUpstreamProbeInterval, upstreamProbeIntervalFromEnv())
 31  
 32  	t.Setenv(constants.EnvDNSUpstreamProbeIntervalSec, "5")
 33  	require.Equal(t, 5*time.Second, upstreamProbeIntervalFromEnv())
 34  
 35  	t.Setenv(constants.EnvDNSUpstreamProbeIntervalSec, "not-a-number")
 36  	require.Equal(t, defaultUpstreamProbeInterval, upstreamProbeIntervalFromEnv())
 37  }
 38  
 39  func TestUpstreamProbeFromEnv(t *testing.T) {
 40  	t.Setenv(constants.EnvDNSUpstreamProbe, "")
 41  	n, qt := upstreamProbeFromEnv()
 42  	require.Equal(t, ".", n)
 43  	require.Equal(t, dns.TypeNS, qt)
 44  
 45  	t.Setenv(constants.EnvDNSUpstreamProbe, ".")
 46  	n, qt = upstreamProbeFromEnv()
 47  	require.Equal(t, ".", n)
 48  	require.Equal(t, dns.TypeNS, qt)
 49  
 50  	t.Setenv(constants.EnvDNSUpstreamProbe, "intranet.corp")
 51  	n, qt = upstreamProbeFromEnv()
 52  	require.Equal(t, "intranet.corp.", n)
 53  	require.Equal(t, dns.TypeA, qt)
 54  }
 55  
 56  func TestNormalizeEnvUpstreamAddr(t *testing.T) {
 57  	got, err := normalizeEnvUpstreamAddr("8.8.8.8")
 58  	require.NoError(t, err)
 59  	require.Equal(t, "8.8.8.8:53", got)
 60  
 61  	got, err = normalizeEnvUpstreamAddr("1.1.1.1:5353")
 62  	require.NoError(t, err)
 63  	require.Equal(t, "1.1.1.1:5353", got)
 64  
 65  	got, err = normalizeEnvUpstreamAddr("2001:db8::1")
 66  	require.NoError(t, err)
 67  	require.Equal(t, "[2001:db8::1]:53", got)
 68  
 69  	got, err = normalizeEnvUpstreamAddr("[2001:db8::2]:853")
 70  	require.NoError(t, err)
 71  	require.Equal(t, "[2001:db8::2]:853", got)
 72  
 73  	got, err = normalizeEnvUpstreamAddr("[2001:db8::3]")
 74  	require.NoError(t, err)
 75  	require.Equal(t, "[2001:db8::3]:53", got)
 76  
 77  	_, err = normalizeEnvUpstreamAddr("")
 78  	require.Error(t, err)
 79  
 80  	_, err = normalizeEnvUpstreamAddr("dns.google")
 81  	require.Error(t, err)
 82  
 83  	_, err = normalizeEnvUpstreamAddr("dns.google:53")
 84  	require.Error(t, err)
 85  }
 86  
 87  func TestParseEnvDNSUpstreams(t *testing.T) {
 88  	got, err := parseEnvDNSUpstreams("8.8.8.8,1.1.1.1")
 89  	require.NoError(t, err)
 90  	require.Equal(t, []string{"8.8.8.8:53", "1.1.1.1:53"}, got)
 91  
 92  	got, err = parseEnvDNSUpstreams("8.8.8.8")
 93  	require.NoError(t, err)
 94  	require.Equal(t, []string{"8.8.8.8:53"}, got)
 95  
 96  	got, err = parseEnvDNSUpstreams("8.8.8.8, 8.8.8.8, 1.1.1.1")
 97  	require.NoError(t, err)
 98  	require.Equal(t, []string{"8.8.8.8:53", "1.1.1.1:53"}, got)
 99  
100  	_, err = parseEnvDNSUpstreams("dns.google,8.8.8.8")
101  	require.Error(t, err)
102  }
103  
104  func TestAllowIPsFromUpstreamAddrs(t *testing.T) {
105  	ips := AllowIPsFromUpstreamAddrs([]string{"8.8.8.8:53", "1.1.1.1:53", "resolver.example.com:53"})
106  	require.Len(t, ips, 2)
107  	require.Equal(t, "8.8.8.8", ips[0].String())
108  	require.Equal(t, "1.1.1.1", ips[1].String())
109  }
110  
111  func TestShouldFailoverAfterResponse(t *testing.T) {
112  	p2 := &Proxy{upstreams: []string{"198.51.100.254:53", "8.8.8.8:53"}}
113  
114  	emptyOK := new(dns.Msg)
115  	emptyOK.Rcode = dns.RcodeSuccess
116  	try, _ := p2.shouldFailoverAfterResponse(emptyOK)
117  	require.False(t, try, "empty NOERROR should not failover")
118  
119  	try, _ = p2.shouldFailoverAfterResponse(emptyOK)
120  	require.False(t, try, "empty NOERROR on last upstream should not failover")
121  
122  	emptyNODATA := new(dns.Msg)
123  	emptyNODATA.Rcode = dns.RcodeSuccess
124  	emptyNODATA.Ns = []dns.RR{
125  		&dns.SOA{
126  			Hdr:     dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 60},
127  			Ns:      "ns1.example.com.",
128  			Mbox:    "hostmaster.example.com.",
129  			Serial:  1,
130  			Refresh: 60,
131  			Retry:   60,
132  			Expire:  60,
133  			Minttl:  60,
134  		},
135  	}
136  	try, _ = p2.shouldFailoverAfterResponse(emptyNODATA)
137  	require.False(t, try, "A/AAAA NODATA with authority should not failover")
138  
139  	withA := new(dns.Msg)
140  	withA.Rcode = dns.RcodeSuccess
141  	withA.Answer = []dns.RR{&dns.A{Hdr: dns.RR_Header{Name: "x."}, A: net.ParseIP("1.1.1.1")}}
142  	try, _ = p2.shouldFailoverAfterResponse(withA)
143  	require.False(t, try)
144  
145  	nx := new(dns.Msg)
146  	nx.Rcode = dns.RcodeNameError
147  	try, _ = p2.shouldFailoverAfterResponse(nx)
148  	require.False(t, try)
149  
150  	sf := new(dns.Msg)
151  	sf.Rcode = dns.RcodeServerFailure
152  	try, _ = p2.shouldFailoverAfterResponse(sf)
153  	require.True(t, try)
154  }