ndarray_serde.go
1 // Package systems implements 4D array serialization helpers 2 package systems 3 4 import ( 5 "encoding/json" 6 "errors" 7 "fmt" 8 ) 9 10 // Array4D represents a 4-dimensional float64 array 11 type Array4D [][][][]float64 12 13 // MarshalJSON implements custom JSON serialization 14 func (a Array4D) MarshalJSON() ([]byte, error) { 15 shape := a.getShape() 16 data := a.flatten() 17 return json.Marshal([]interface{}{shape, data}) 18 } 19 20 // UnmarshalJSON implements custom JSON deserialization 21 func (a *Array4D) UnmarshalJSON(data []byte) error { 22 var raw []json.RawMessage 23 if err := json.Unmarshal(data, &raw); err != nil { 24 return err 25 } 26 27 if len(raw) != 2 { 28 return errors.New("invalid Array4D format") 29 } 30 31 var shape [4]int 32 if err := json.Unmarshal(raw[0], &shape); err != nil { 33 return fmt.Errorf("shape deserialization error: %w", err) 34 } 35 36 var flatData []float64 37 if err := json.Unmarshal(raw[1], &flatData); err != nil { 38 return fmt.Errorf("data deserialization error: %w", err) 39 } 40 41 res, err := reshape4D(shape, flatData) 42 if err != nil { 43 return fmt.Errorf("reshape error: %w", err) 44 } 45 46 *a = res 47 return nil 48 } 49 50 // Helper methods 51 func (a Array4D) getShape() [4]int { 52 if len(a) == 0 { 53 return [4]int{0, 0, 0, 0} 54 } 55 return [4]int{ 56 len(a), 57 len(a[0]), 58 len(a[0][0]), 59 len(a[0][0][0]), 60 } 61 } 62 63 func (a Array4D) flatten() []float64 { 64 size := 1 65 for _, dim := range a.getShape() { 66 size *= dim 67 } 68 result := make([]float64, 0, size) 69 70 for _, d1 := range a { 71 for _, d2 := range d1 { 72 for _, d3 := range d2 { 73 result = append(result, d3...) 74 } 75 } 76 } 77 return result 78 } 79 80 func reshape4D(shape [4]int, data []float64) (Array4D, error) { 81 total := shape[0] * shape[1] * shape[2] * shape[3] 82 if len(data) != total { 83 return nil, fmt.Errorf("data length %d doesn't match shape %v", len(data), shape) 84 } 85 86 arr := make(Array4D, shape[0]) 87 idx := 0 88 89 for i := 0; i < shape[0]; i++ { 90 arr[i] = make([][][]float64, shape[1]) 91 for j := 0; j < shape[1]; j++ { 92 arr[i][j] = make([][]float64, shape[2]) 93 for k := 0; k < shape[2]; k++ { 94 arr[i][j][k] = make([]float64, shape[3]) 95 for l := 0; l < shape[3]; l++ { 96 if idx >= len(data) { 97 return nil, errors.New("data index overflow") 98 } 99 arr[i][j][k][l] = data[idx] 100 idx++ 101 } 102 } 103 } 104 } 105 return arr, nil 106 }