/ vendor / github.com / btcsuite / btcd / btcjson / cmdparse.go
cmdparse.go
  1  // Copyright (c) 2014 The btcsuite developers
  2  // Use of this source code is governed by an ISC
  3  // license that can be found in the LICENSE file.
  4  
  5  package btcjson
  6  
  7  import (
  8  	"encoding/json"
  9  	"fmt"
 10  	"reflect"
 11  	"strconv"
 12  	"strings"
 13  )
 14  
 15  // makeParams creates a slice of interface values for the given struct.
 16  func makeParams(rt reflect.Type, rv reflect.Value) []interface{} {
 17  	numFields := rt.NumField()
 18  	params := make([]interface{}, 0, numFields)
 19  	for i := 0; i < numFields; i++ {
 20  		rtf := rt.Field(i)
 21  		rvf := rv.Field(i)
 22  		if rtf.Type.Kind() == reflect.Ptr {
 23  			if rvf.IsNil() {
 24  				break
 25  			}
 26  			rvf.Elem()
 27  		}
 28  		params = append(params, rvf.Interface())
 29  	}
 30  
 31  	return params
 32  }
 33  
 34  // MarshalCmd marshals the passed command to a JSON-RPC request byte slice that
 35  // is suitable for transmission to an RPC server.  The provided command type
 36  // must be a registered type.  All commands provided by this package are
 37  // registered by default.
 38  func MarshalCmd(id interface{}, cmd interface{}) ([]byte, error) {
 39  	// Look up the cmd type and error out if not registered.
 40  	rt := reflect.TypeOf(cmd)
 41  	registerLock.RLock()
 42  	method, ok := concreteTypeToMethod[rt]
 43  	registerLock.RUnlock()
 44  	if !ok {
 45  		str := fmt.Sprintf("%q is not registered", method)
 46  		return nil, makeError(ErrUnregisteredMethod, str)
 47  	}
 48  
 49  	// The provided command must not be nil.
 50  	rv := reflect.ValueOf(cmd)
 51  	if rv.IsNil() {
 52  		str := "the specified command is nil"
 53  		return nil, makeError(ErrInvalidType, str)
 54  	}
 55  
 56  	// Create a slice of interface values in the order of the struct fields
 57  	// while respecting pointer fields as optional params and only adding
 58  	// them if they are non-nil.
 59  	params := makeParams(rt.Elem(), rv.Elem())
 60  
 61  	// Generate and marshal the final JSON-RPC request.
 62  	rawCmd, err := NewRequest(id, method, params)
 63  	if err != nil {
 64  		return nil, err
 65  	}
 66  	return json.Marshal(rawCmd)
 67  }
 68  
 69  // checkNumParams ensures the supplied number of params is at least the minimum
 70  // required number for the command and less than the maximum allowed.
 71  func checkNumParams(numParams int, info *methodInfo) error {
 72  	if numParams < info.numReqParams || numParams > info.maxParams {
 73  		if info.numReqParams == info.maxParams {
 74  			str := fmt.Sprintf("wrong number of params (expected "+
 75  				"%d, received %d)", info.numReqParams,
 76  				numParams)
 77  			return makeError(ErrNumParams, str)
 78  		}
 79  
 80  		str := fmt.Sprintf("wrong number of params (expected "+
 81  			"between %d and %d, received %d)", info.numReqParams,
 82  			info.maxParams, numParams)
 83  		return makeError(ErrNumParams, str)
 84  	}
 85  
 86  	return nil
 87  }
 88  
 89  // populateDefaults populates default values into any remaining optional struct
 90  // fields that did not have parameters explicitly provided.  The caller should
 91  // have previously checked that the number of parameters being passed is at
 92  // least the required number of parameters to avoid unnecessary work in this
 93  // function, but since required fields never have default values, it will work
 94  // properly even without the check.
 95  func populateDefaults(numParams int, info *methodInfo, rv reflect.Value) {
 96  	// When there are no more parameters left in the supplied parameters,
 97  	// any remaining struct fields must be optional.  Thus, populate them
 98  	// with their associated default value as needed.
 99  	for i := numParams; i < info.maxParams; i++ {
100  		rvf := rv.Field(i)
101  		if defaultVal, ok := info.defaults[i]; ok {
102  			rvf.Set(defaultVal)
103  		}
104  	}
105  }
106  
107  // UnmarshalCmd unmarshals a JSON-RPC request into a suitable concrete command
108  // so long as the method type contained within the marshalled request is
109  // registered.
110  func UnmarshalCmd(r *Request) (interface{}, error) {
111  	registerLock.RLock()
112  	rtp, ok := methodToConcreteType[r.Method]
113  	info := methodToInfo[r.Method]
114  	registerLock.RUnlock()
115  	if !ok {
116  		str := fmt.Sprintf("%q is not registered", r.Method)
117  		return nil, makeError(ErrUnregisteredMethod, str)
118  	}
119  	rt := rtp.Elem()
120  	rvp := reflect.New(rt)
121  	rv := rvp.Elem()
122  
123  	// Ensure the number of parameters are correct.
124  	numParams := len(r.Params)
125  	if err := checkNumParams(numParams, &info); err != nil {
126  		return nil, err
127  	}
128  
129  	// Loop through each of the struct fields and unmarshal the associated
130  	// parameter into them.
131  	for i := 0; i < numParams; i++ {
132  		rvf := rv.Field(i)
133  		// Unmarshal the parameter into the struct field.
134  		concreteVal := rvf.Addr().Interface()
135  		if err := json.Unmarshal(r.Params[i], &concreteVal); err != nil {
136  			// The most common error is the wrong type, so
137  			// explicitly detect that error and make it nicer.
138  			fieldName := strings.ToLower(rt.Field(i).Name)
139  			if jerr, ok := err.(*json.UnmarshalTypeError); ok {
140  				str := fmt.Sprintf("parameter #%d '%s' must "+
141  					"be type %v (got %v)", i+1, fieldName,
142  					jerr.Type, jerr.Value)
143  				return nil, makeError(ErrInvalidType, str)
144  			}
145  
146  			// Fallback to showing the underlying error.
147  			str := fmt.Sprintf("parameter #%d '%s' failed to "+
148  				"unmarshal: %v", i+1, fieldName, err)
149  			return nil, makeError(ErrInvalidType, str)
150  		}
151  	}
152  
153  	// When there are less supplied parameters than the total number of
154  	// params, any remaining struct fields must be optional.  Thus, populate
155  	// them with their associated default value as needed.
156  	if numParams < info.maxParams {
157  		populateDefaults(numParams, &info, rv)
158  	}
159  
160  	return rvp.Interface(), nil
161  }
162  
// isNumeric reports whether kind is one of the fixed- or platform-sized signed
// integers, unsigned integers, or floats.  Uintptr and the complex kinds are
// not considered numeric.
func isNumeric(kind reflect.Kind) bool {
	switch kind {
	// Signed integers.
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true

	// Unsigned integers.
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return true

	// Floats.
	case reflect.Float32, reflect.Float64:
		return true

	default:
		return false
	}
}
176  
177  // typesMaybeCompatible returns whether the source type can possibly be
178  // assigned to the destination type.  This is intended as a relatively quick
179  // check to weed out obviously invalid conversions.
180  func typesMaybeCompatible(dest reflect.Type, src reflect.Type) bool {
181  	// The same types are obviously compatible.
182  	if dest == src {
183  		return true
184  	}
185  
186  	// When both types are numeric, they are potentially compatibile.
187  	srcKind := src.Kind()
188  	destKind := dest.Kind()
189  	if isNumeric(destKind) && isNumeric(srcKind) {
190  		return true
191  	}
192  
193  	if srcKind == reflect.String {
194  		// Strings can potentially be converted to numeric types.
195  		if isNumeric(destKind) {
196  			return true
197  		}
198  
199  		switch destKind {
200  		// Strings can potentially be converted to bools by
201  		// strconv.ParseBool.
202  		case reflect.Bool:
203  			return true
204  
205  		// Strings can be converted to any other type which has as
206  		// underlying type of string.
207  		case reflect.String:
208  			return true
209  
210  		// Strings can potentially be converted to arrays, slice,
211  		// structs, and maps via json.Unmarshal.
212  		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
213  			return true
214  		}
215  	}
216  
217  	return false
218  }
219  
// baseType strips every level of pointer from arg and returns the resulting
// non-pointer type together with the number of pointer levels removed.  A
// non-pointer type is returned unchanged with a count of zero.
func baseType(arg reflect.Type) (reflect.Type, int) {
	depth := 0
	for ; arg.Kind() == reflect.Ptr; depth++ {
		arg = arg.Elem()
	}
	return arg, depth
}
230  
231  // assignField is the main workhorse for the NewCmd function which handles
232  // assigning the provided source value to the destination field.  It supports
233  // direct type assignments, indirection, conversion of numeric types, and
234  // unmarshaling of strings into arrays, slices, structs, and maps via
235  // json.Unmarshal.
236  func assignField(paramNum int, fieldName string, dest reflect.Value, src reflect.Value) error {
237  	// Just error now when the types have no chance of being compatible.
238  	destBaseType, destIndirects := baseType(dest.Type())
239  	srcBaseType, srcIndirects := baseType(src.Type())
240  	if !typesMaybeCompatible(destBaseType, srcBaseType) {
241  		str := fmt.Sprintf("parameter #%d '%s' must be type %v (got "+
242  			"%v)", paramNum, fieldName, destBaseType, srcBaseType)
243  		return makeError(ErrInvalidType, str)
244  	}
245  
246  	// Check if it's possible to simply set the dest to the provided source.
247  	// This is the case when the base types are the same or they are both
248  	// pointers that can be indirected to be the same without needing to
249  	// create pointers for the destination field.
250  	if destBaseType == srcBaseType && srcIndirects >= destIndirects {
251  		for i := 0; i < srcIndirects-destIndirects; i++ {
252  			src = src.Elem()
253  		}
254  		dest.Set(src)
255  		return nil
256  	}
257  
258  	// When the destination has more indirects than the source, the extra
259  	// pointers have to be created.  Only create enough pointers to reach
260  	// the same level of indirection as the source so the dest can simply be
261  	// set to the provided source when the types are the same.
262  	destIndirectsRemaining := destIndirects
263  	if destIndirects > srcIndirects {
264  		indirectDiff := destIndirects - srcIndirects
265  		for i := 0; i < indirectDiff; i++ {
266  			dest.Set(reflect.New(dest.Type().Elem()))
267  			dest = dest.Elem()
268  			destIndirectsRemaining--
269  		}
270  	}
271  
272  	if destBaseType == srcBaseType {
273  		dest.Set(src)
274  		return nil
275  	}
276  
277  	// Make any remaining pointers needed to get to the base dest type since
278  	// the above direct assign was not possible and conversions are done
279  	// against the base types.
280  	for i := 0; i < destIndirectsRemaining; i++ {
281  		dest.Set(reflect.New(dest.Type().Elem()))
282  		dest = dest.Elem()
283  	}
284  
285  	// Indirect through to the base source value.
286  	for src.Kind() == reflect.Ptr {
287  		src = src.Elem()
288  	}
289  
290  	// Perform supported type conversions.
291  	switch src.Kind() {
292  	// Source value is a signed integer of various magnitude.
293  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
294  		reflect.Int64:
295  
296  		switch dest.Kind() {
297  		// Destination is a signed integer of various magnitude.
298  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
299  			reflect.Int64:
300  
301  			srcInt := src.Int()
302  			if dest.OverflowInt(srcInt) {
303  				str := fmt.Sprintf("parameter #%d '%s' "+
304  					"overflows destination type %v",
305  					paramNum, fieldName, destBaseType)
306  				return makeError(ErrInvalidType, str)
307  			}
308  
309  			dest.SetInt(srcInt)
310  
311  		// Destination is an unsigned integer of various magnitude.
312  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
313  			reflect.Uint64:
314  
315  			srcInt := src.Int()
316  			if srcInt < 0 || dest.OverflowUint(uint64(srcInt)) {
317  				str := fmt.Sprintf("parameter #%d '%s' "+
318  					"overflows destination type %v",
319  					paramNum, fieldName, destBaseType)
320  				return makeError(ErrInvalidType, str)
321  			}
322  			dest.SetUint(uint64(srcInt))
323  
324  		default:
325  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
326  				"%v (got %v)", paramNum, fieldName, destBaseType,
327  				srcBaseType)
328  			return makeError(ErrInvalidType, str)
329  		}
330  
331  	// Source value is an unsigned integer of various magnitude.
332  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
333  		reflect.Uint64:
334  
335  		switch dest.Kind() {
336  		// Destination is a signed integer of various magnitude.
337  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
338  			reflect.Int64:
339  
340  			srcUint := src.Uint()
341  			if srcUint > uint64(1<<63)-1 {
342  				str := fmt.Sprintf("parameter #%d '%s' "+
343  					"overflows destination type %v",
344  					paramNum, fieldName, destBaseType)
345  				return makeError(ErrInvalidType, str)
346  			}
347  			if dest.OverflowInt(int64(srcUint)) {
348  				str := fmt.Sprintf("parameter #%d '%s' "+
349  					"overflows destination type %v",
350  					paramNum, fieldName, destBaseType)
351  				return makeError(ErrInvalidType, str)
352  			}
353  			dest.SetInt(int64(srcUint))
354  
355  		// Destination is an unsigned integer of various magnitude.
356  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
357  			reflect.Uint64:
358  
359  			srcUint := src.Uint()
360  			if dest.OverflowUint(srcUint) {
361  				str := fmt.Sprintf("parameter #%d '%s' "+
362  					"overflows destination type %v",
363  					paramNum, fieldName, destBaseType)
364  				return makeError(ErrInvalidType, str)
365  			}
366  			dest.SetUint(srcUint)
367  
368  		default:
369  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
370  				"%v (got %v)", paramNum, fieldName, destBaseType,
371  				srcBaseType)
372  			return makeError(ErrInvalidType, str)
373  		}
374  
375  	// Source value is a float.
376  	case reflect.Float32, reflect.Float64:
377  		destKind := dest.Kind()
378  		if destKind != reflect.Float32 && destKind != reflect.Float64 {
379  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
380  				"%v (got %v)", paramNum, fieldName, destBaseType,
381  				srcBaseType)
382  			return makeError(ErrInvalidType, str)
383  		}
384  
385  		srcFloat := src.Float()
386  		if dest.OverflowFloat(srcFloat) {
387  			str := fmt.Sprintf("parameter #%d '%s' overflows "+
388  				"destination type %v", paramNum, fieldName,
389  				destBaseType)
390  			return makeError(ErrInvalidType, str)
391  		}
392  		dest.SetFloat(srcFloat)
393  
394  	// Source value is a string.
395  	case reflect.String:
396  		switch dest.Kind() {
397  		// String -> bool
398  		case reflect.Bool:
399  			b, err := strconv.ParseBool(src.String())
400  			if err != nil {
401  				str := fmt.Sprintf("parameter #%d '%s' must "+
402  					"parse to a %v", paramNum, fieldName,
403  					destBaseType)
404  				return makeError(ErrInvalidType, str)
405  			}
406  			dest.SetBool(b)
407  
408  		// String -> signed integer of varying size.
409  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
410  			reflect.Int64:
411  
412  			srcInt, err := strconv.ParseInt(src.String(), 0, 0)
413  			if err != nil {
414  				str := fmt.Sprintf("parameter #%d '%s' must "+
415  					"parse to a %v", paramNum, fieldName,
416  					destBaseType)
417  				return makeError(ErrInvalidType, str)
418  			}
419  			if dest.OverflowInt(srcInt) {
420  				str := fmt.Sprintf("parameter #%d '%s' "+
421  					"overflows destination type %v",
422  					paramNum, fieldName, destBaseType)
423  				return makeError(ErrInvalidType, str)
424  			}
425  			dest.SetInt(srcInt)
426  
427  		// String -> unsigned integer of varying size.
428  		case reflect.Uint, reflect.Uint8, reflect.Uint16,
429  			reflect.Uint32, reflect.Uint64:
430  
431  			srcUint, err := strconv.ParseUint(src.String(), 0, 0)
432  			if err != nil {
433  				str := fmt.Sprintf("parameter #%d '%s' must "+
434  					"parse to a %v", paramNum, fieldName,
435  					destBaseType)
436  				return makeError(ErrInvalidType, str)
437  			}
438  			if dest.OverflowUint(srcUint) {
439  				str := fmt.Sprintf("parameter #%d '%s' "+
440  					"overflows destination type %v",
441  					paramNum, fieldName, destBaseType)
442  				return makeError(ErrInvalidType, str)
443  			}
444  			dest.SetUint(srcUint)
445  
446  		// String -> float of varying size.
447  		case reflect.Float32, reflect.Float64:
448  			srcFloat, err := strconv.ParseFloat(src.String(), 0)
449  			if err != nil {
450  				str := fmt.Sprintf("parameter #%d '%s' must "+
451  					"parse to a %v", paramNum, fieldName,
452  					destBaseType)
453  				return makeError(ErrInvalidType, str)
454  			}
455  			if dest.OverflowFloat(srcFloat) {
456  				str := fmt.Sprintf("parameter #%d '%s' "+
457  					"overflows destination type %v",
458  					paramNum, fieldName, destBaseType)
459  				return makeError(ErrInvalidType, str)
460  			}
461  			dest.SetFloat(srcFloat)
462  
463  		// String -> string (typecast).
464  		case reflect.String:
465  			dest.SetString(src.String())
466  
467  		// String -> arrays, slices, structs, and maps via
468  		// json.Unmarshal.
469  		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
470  			concreteVal := dest.Addr().Interface()
471  			err := json.Unmarshal([]byte(src.String()), &concreteVal)
472  			if err != nil {
473  				str := fmt.Sprintf("parameter #%d '%s' must "+
474  					"be valid JSON which unsmarshals to a %v",
475  					paramNum, fieldName, destBaseType)
476  				return makeError(ErrInvalidType, str)
477  			}
478  			dest.Set(reflect.ValueOf(concreteVal).Elem())
479  		}
480  	}
481  
482  	return nil
483  }
484  
485  // NewCmd provides a generic mechanism to create a new command that can marshal
486  // to a JSON-RPC request while respecting the requirements of the provided
487  // method.  The method must have been registered with the package already along
488  // with its type definition.  All methods associated with the commands exported
489  // by this package are already registered by default.
490  //
491  // The arguments are most efficient when they are the exact same type as the
492  // underlying field in the command struct associated with the the method,
493  // however this function also will perform a variety of conversions to make it
494  // more flexible.  This allows, for example, command line args which are strings
495  // to be passed unaltered.  In particular, the following conversions are
496  // supported:
497  //
498  //   - Conversion between any size signed or unsigned integer so long as the
499  //     value does not overflow the destination type
500  //   - Conversion between float32 and float64 so long as the value does not
501  //     overflow the destination type
502  //   - Conversion from string to boolean for everything strconv.ParseBool
503  //     recognizes
504  //   - Conversion from string to any size integer for everything
505  //     strconv.ParseInt and strconv.ParseUint recognizes
506  //   - Conversion from string to any size float for everything
507  //     strconv.ParseFloat recognizes
508  //   - Conversion from string to arrays, slices, structs, and maps by treating
509  //     the string as marshalled JSON and calling json.Unmarshal into the
510  //     destination field
511  func NewCmd(method string, args ...interface{}) (interface{}, error) {
512  	// Look up details about the provided method.  Any methods that aren't
513  	// registered are an error.
514  	registerLock.RLock()
515  	rtp, ok := methodToConcreteType[method]
516  	info := methodToInfo[method]
517  	registerLock.RUnlock()
518  	if !ok {
519  		str := fmt.Sprintf("%q is not registered", method)
520  		return nil, makeError(ErrUnregisteredMethod, str)
521  	}
522  
523  	// Ensure the number of parameters are correct.
524  	numParams := len(args)
525  	if err := checkNumParams(numParams, &info); err != nil {
526  		return nil, err
527  	}
528  
529  	// Create the appropriate command type for the method.  Since all types
530  	// are enforced to be a pointer to a struct at registration time, it's
531  	// safe to indirect to the struct now.
532  	rvp := reflect.New(rtp.Elem())
533  	rv := rvp.Elem()
534  	rt := rtp.Elem()
535  
536  	// Loop through each of the struct fields and assign the associated
537  	// parameter into them after checking its type validity.
538  	for i := 0; i < numParams; i++ {
539  		// Attempt to assign each of the arguments to the according
540  		// struct field.
541  		rvf := rv.Field(i)
542  		fieldName := strings.ToLower(rt.Field(i).Name)
543  		err := assignField(i+1, fieldName, rvf, reflect.ValueOf(args[i]))
544  		if err != nil {
545  			return nil, err
546  		}
547  	}
548  
549  	return rvp.Interface(), nil
550  }