/ src / systems / ndarray_serde.go
ndarray_serde.go
// Package systems implements JSON serialization helpers for 4-dimensional float64 arrays.
  2  package systems
  3  
  4  import (
  5  	"encoding/json"
  6  	"errors"
  7  	"fmt"
  8  )
  9  
// Array4D represents a 4-dimensional float64 array stored as nested slices,
// indexed as a[i][j][k][l].
//
// It serializes to JSON as a two-element array [shape, data] (see
// MarshalJSON). A single shape can only describe a rectangular array, i.e.
// one where every sub-slice at a given nesting level has the same length.
type Array4D [][][][]float64
 12  
 13  // MarshalJSON implements custom JSON serialization
 14  func (a Array4D) MarshalJSON() ([]byte, error) {
 15  	shape := a.getShape()
 16  	data := a.flatten()
 17  	return json.Marshal([]interface{}{shape, data})
 18  }
 19  
 20  // UnmarshalJSON implements custom JSON deserialization
 21  func (a *Array4D) UnmarshalJSON(data []byte) error {
 22  	var raw []json.RawMessage
 23  	if err := json.Unmarshal(data, &raw); err != nil {
 24  		return err
 25  	}
 26  	
 27  	if len(raw) != 2 {
 28  		return errors.New("invalid Array4D format")
 29  	}
 30  
 31  	var shape [4]int
 32  	if err := json.Unmarshal(raw[0], &shape); err != nil {
 33  		return fmt.Errorf("shape deserialization error: %w", err)
 34  	}
 35  
 36  	var flatData []float64
 37  	if err := json.Unmarshal(raw[1], &flatData); err != nil {
 38  		return fmt.Errorf("data deserialization error: %w", err)
 39  	}
 40  
 41  	res, err := reshape4D(shape, flatData)
 42  	if err != nil {
 43  		return fmt.Errorf("reshape error: %w", err)
 44  	}
 45  
 46  	*a = res
 47  	return nil
 48  }
 49  
 50  // Helper methods
 51  func (a Array4D) getShape() [4]int {
 52  	if len(a) == 0 {
 53  		return [4]int{0, 0, 0, 0}
 54  	}
 55  	return [4]int{
 56  		len(a),
 57  		len(a[0]),
 58  		len(a[0][0]),
 59  		len(a[0][0][0]),
 60  	}
 61  }
 62  
 63  func (a Array4D) flatten() []float64 {
 64  	size := 1
 65  	for _, dim := range a.getShape() {
 66  		size *= dim
 67  	}
 68  	result := make([]float64, 0, size)
 69  	
 70  	for _, d1 := range a {
 71  		for _, d2 := range d1 {
 72  			for _, d3 := range d2 {
 73  				result = append(result, d3...)
 74  			}
 75  		}
 76  	}
 77  	return result
 78  }
 79  
 80  func reshape4D(shape [4]int, data []float64) (Array4D, error) {
 81  	total := shape[0] * shape[1] * shape[2] * shape[3]
 82  	if len(data) != total {
 83  		return nil, fmt.Errorf("data length %d doesn't match shape %v", len(data), shape)
 84  	}
 85  
 86  	arr := make(Array4D, shape[0])
 87  	idx := 0
 88  
 89  	for i := 0; i < shape[0]; i++ {
 90  		arr[i] = make([][][]float64, shape[1])
 91  		for j := 0; j < shape[1]; j++ {
 92  			arr[i][j] = make([][]float64, shape[2])
 93  			for k := 0; k < shape[2]; k++ {
 94  				arr[i][j][k] = make([]float64, shape[3])
 95  				for l := 0; l < shape[3]; l++ {
 96  					if idx >= len(data) {
 97  						return nil, errors.New("data index overflow")
 98  					}
 99  					arr[i][j][k][l] = data[idx]
100  					idx++
101  				}
102  			}
103  		}
104  	}
105  	return arr, nil
106  }