exec.go
  1  package util
  2  
  3  import (
  4  	"bytes"
  5  	"context"
  6  	"io"
  7  	"os/exec"
  8  	"reflect"
  9  	"strings"
 10  	"sync"
 11  	"time"
 12  
 13  	pkgerrors "github.com/pkg/errors"
 14  	"github.com/zimmski/osutil"
 15  	"github.com/zimmski/osutil/bytesutil"
 16  
 17  	"github.com/symflower/eval-dev-quality/log"
 18  )
 19  
 20  // Command defines a command that should be executed.
 21  type Command struct {
 22  	// Command holds the command with its optional arguments.
 23  	Command []string
 24  	// Stdin holds a string which is passed on as STDIN.
 25  	Stdin string
 26  
 27  	// Directory defines the directory the execution should run in, without changing the working directory of the caller.
 28  	Directory string
 29  	// Env overwrites the environment variables of the executed command.
 30  	Env map[string]string
 31  }
 32  
 33  // CommandWithResult executes a command and returns its output, while printing the same output to the given logger.
 34  func CommandWithResult(ctx context.Context, logger *log.Logger, command *Command) (output string, err error) {
 35  	logger.Printf("$ %s", strings.Join(command.Command, " "))
 36  
 37  	var writer bytesutil.SynchronizedBuffer
 38  	c := exec.CommandContext(ctx, command.Command[0], command.Command[1:]...)
 39  	if command.Directory != "" {
 40  		c.Dir = command.Directory
 41  	}
 42  	if command.Env != nil {
 43  		envs := osutil.EnvironMap()
 44  		for k, v := range command.Env {
 45  			envs[k] = v
 46  		}
 47  		for k, v := range envs {
 48  			c.Env = append(c.Env, k+"="+v)
 49  		}
 50  	}
 51  	if command.Stdin != "" {
 52  		c.Stdin = bytes.NewBufferString(command.Stdin)
 53  	}
 54  	c.Stdout = io.MultiWriter(logger.Writer(), &writer)
 55  	c.Stderr = c.Stdout
 56  
 57  	c.WaitDelay = 3 * time.Second // Some binaries do not like to be killed, e.g. "ollama", so we kill them after some time automatically.
 58  
 59  	if err := c.Run(); err != nil {
 60  		return writer.String(), pkgerrors.WithStack(pkgerrors.WithMessage(err, writer.String()))
 61  	}
 62  
 63  	return writer.String(), nil
 64  }
 65  
 66  // Flags returns a list of `long` flags bound on the command or nil.
 67  func Flags(cmd any) (args []string) {
 68  	typ := reflect.TypeOf(cmd)
 69  
 70  	// Dereference pointer
 71  	if typ.Kind() == reflect.Pointer {
 72  		typ = typ.Elem()
 73  	}
 74  
 75  	if typ.Kind() != reflect.Struct {
 76  		return nil
 77  	}
 78  
 79  	for i := 0; i < typ.NumField(); i++ {
 80  		field := typ.Field(i)
 81  		arg, ok := field.Tag.Lookup("long")
 82  		if !ok {
 83  			continue
 84  		}
 85  
 86  		args = append(args, arg)
 87  	}
 88  
 89  	return args
 90  }
 91  
 92  // FilterArgs filters the arguments by either ignoring/allowing them in the result.
 93  func FilterArgs(args []string, filter []string, ignore bool) (filtered []string) {
 94  	filterMap := map[string]bool{}
 95  	for _, v := range filter {
 96  		filterMap["--"+v] = true
 97  	}
 98  
 99  	// Resolve args with equals sign.
100  	var resolvedArgs []string
101  	for _, v := range args {
102  		if strings.HasPrefix(v, "--") && strings.Contains(v, "=") {
103  			resolvedArgs = append(resolvedArgs, strings.SplitN(v, "=", 2)...)
104  		} else {
105  			resolvedArgs = append(resolvedArgs, v)
106  		}
107  	}
108  
109  	skip := true
110  	for _, v := range resolvedArgs {
111  		if strings.HasPrefix(v, "--") {
112  			if ignore {
113  				skip = filterMap[v]
114  			} else {
115  				skip = !filterMap[v]
116  			}
117  		}
118  
119  		if skip {
120  			continue
121  		}
122  
123  		filtered = append(filtered, v)
124  	}
125  
126  	return filtered
127  }
128  
129  // FilterArgsKeep filters the given argument list and only returns arguments defined present in "filter".
130  func FilterArgsKeep(args []string, filter []string) (filtered []string) {
131  	return FilterArgs(args, filter, false)
132  }
133  
134  // FilterArgsRemove filters the given argument list and returns arguments where "filter" entries are removed.
135  func FilterArgsRemove(args []string, filter []string) (filtered []string) {
136  	return FilterArgs(args, filter, true)
137  }
138  
139  // Parallel holds a buffered channel for limiting parallel executions.
140  type Parallel struct {
141  	ch chan struct{}
142  	wg sync.WaitGroup
143  }
144  
145  // NewParallel returns a Parallel execution helper.
146  func NewParallel(maxWorkers uint) *Parallel {
147  	return &Parallel{
148  		ch: make(chan struct{}, maxWorkers),
149  	}
150  }
151  
152  // acquire slot.
153  func (p *Parallel) acquire() {
154  	p.ch <- struct{}{}
155  }
156  
157  // release slot.
158  func (p *Parallel) release() {
159  	<-p.ch
160  }
161  
162  // Execute runs the given function while checking for a execution limit.
163  func (p *Parallel) Execute(f func()) {
164  	p.acquire()
165  	p.wg.Add(1)
166  
167  	go func() {
168  		defer p.release()
169  		defer p.wg.Done()
170  		f()
171  	}()
172  }
173  
174  // Wait waits until all executions are done.
175  func (l *Parallel) Wait() {
176  	l.wg.Wait()
177  }