/ components / execd / pkg / runtime / helpers_test.go
helpers_test.go
  1  // Copyright 2025 Alibaba Group Holding Ltd.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  
 15  package runtime
 16  
 17  import (
 18  	"context"
 19  	"database/sql"
 20  	"database/sql/driver"
 21  	"errors"
 22  	"fmt"
 23  	"io"
 24  	"sync/atomic"
 25  	"testing"
 26  	"time"
 27  
 28  	"github.com/stretchr/testify/require"
 29  )
 30  
 31  type stubDriver struct {
 32  	columns          []string
 33  	rows             [][]driver.Value
 34  	execRowsAffected int64
 35  	queryErr         error
 36  	execErr          error
 37  	pingErr          error
 38  	execCalled       int32
 39  	queryCalled      int32
 40  }
 41  
 42  type stubConn struct {
 43  	d *stubDriver
 44  }
 45  
 46  func (c *stubConn) Prepare(string) (driver.Stmt, error) { return nil, errors.New("not implemented") }
 47  func (c *stubConn) Close() error                        { return nil }
 48  func (c *stubConn) Begin() (driver.Tx, error)           { return nil, errors.New("not implemented") }
 49  
 50  func (c *stubConn) Ping(context.Context) error {
 51  	return c.d.pingErr
 52  }
 53  
 54  func (c *stubConn) ExecContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Result, error) {
 55  	atomic.AddInt32(&c.d.execCalled, 1)
 56  	if c.d.execErr != nil {
 57  		return nil, c.d.execErr
 58  	}
 59  	return driver.RowsAffected(c.d.execRowsAffected), nil
 60  }
 61  
 62  func (c *stubConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) {
 63  	atomic.AddInt32(&c.d.queryCalled, 1)
 64  	if c.d.queryErr != nil {
 65  		return nil, c.d.queryErr
 66  	}
 67  	return &stubRows{
 68  		columns: c.d.columns,
 69  		rows:    c.d.rows,
 70  	}, nil
 71  }
 72  
 73  type stubRows struct {
 74  	columns []string
 75  	rows    [][]driver.Value
 76  	idx     int
 77  }
 78  
 79  func (r *stubRows) Columns() []string { return r.columns }
 80  func (r *stubRows) Close() error      { return nil }
 81  func (r *stubRows) Next(dest []driver.Value) error {
 82  	if r.idx >= len(r.rows) {
 83  		return io.EOF
 84  	}
 85  	row := r.rows[r.idx]
 86  	r.idx++
 87  	for i, v := range row {
 88  		dest[i] = v
 89  	}
 90  	return nil
 91  }
 92  
 93  type stubConnector struct {
 94  	d *stubDriver
 95  }
 96  
 97  func (c *stubConnector) Connect(context.Context) (driver.Conn, error) {
 98  	return &stubConn{d: c.d}, nil
 99  }
100  
101  func (c *stubConnector) Driver() driver.Driver {
102  	return c
103  }
104  
105  func (c *stubConnector) Open(string) (driver.Conn, error) {
106  	return &stubConn{d: c.d}, nil
107  }
108  
109  func newStubDB(t *testing.T, d *stubDriver) *sql.DB {
110  	t.Helper()
111  	driverName := fmt.Sprintf("stub-%d", time.Now().UnixNano())
112  	sql.Register(driverName, &stubConnector{d: d})
113  	db, err := sql.Open(driverName, "")
114  	require.NoError(t, err)
115  	return db
116  }