/ Drivers / CMSIS / DSP / Source / MatrixFunctions / arm_mat_solve_lower_triangular_f32.c
arm_mat_solve_lower_triangular_f32.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_mat_solve_lower_triangular_f32.c
  4   * Description:  Solve linear system LT X = A with LT lower triangular matrix
  5   *
  6   * $Date:        23 April 2021
  7   * $Revision:    V1.9.0
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/matrix_functions.h"
 30  
 31  /**
 32    @ingroup groupMatrix
 33   */
 34  
 35  
 36  /**
 37    @addtogroup MatrixInv
 38    @{
 39   */
 40  
 41  
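   /**
   * @brief Solve LT . X = A where LT is a lower triangular matrix
   * @param[in]  lt   The lower triangular matrix (square, n x n). Only the
   *                  diagonal and the entries below it are read.
   * @param[in]  a    The right-hand side matrix A (n rows)
   * @param[out] dst  The solution X of LT . X = A (same dimensions as A)
   * @return execution status
   *   - ARM_MATH_SUCCESS       : the system was solved
   *   - ARM_MATH_SIZE_MISMATCH : the matrix dimensions are inconsistent
   *                              (only checked when ARM_MATH_MATRIX_CHECK is defined)
   *   - ARM_MATH_SINGULAR      : a diagonal entry of lt is zero, so the system
   *                              can't be solved
   */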
 49  
 50  #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
 51  
 52  #include "arm_helium_utils.h"
 53  
 54    arm_status arm_mat_solve_lower_triangular_f32(
 55    const arm_matrix_instance_f32 * lt,
 56    const arm_matrix_instance_f32 * a,
 57    arm_matrix_instance_f32 * dst)
 58    {
 59    arm_status status;                             /* status of matrix inverse */
 60  
 61  
 62  #ifdef ARM_MATH_MATRIX_CHECK
 63  
 64    /* Check for matrix mismatch condition */
 65    if ((lt->numRows != lt->numCols) ||
 66        (lt->numRows != a->numRows)   )
 67    {
 68      /* Set status as ARM_MATH_SIZE_MISMATCH */
 69      status = ARM_MATH_SIZE_MISMATCH;
 70    }
 71    else
 72  
 73  #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
 74  
 75    {
 76      /* a1 b1 c1   x1 = a1
 77            b2 c2   x2   a2
 78               c3   x3   a3
 79  
 80      x3 = a3 / c3 
 81      x2 = (a2 - c2 x3) / b2
 82  
 83      */
 84      int i,j,k,n,cols;
 85  
 86      n = dst->numRows;
 87      cols = dst->numCols;
 88  
 89      float32_t *pX = dst->pData;
 90      float32_t *pLT = lt->pData;
 91      float32_t *pA = a->pData;
 92  
 93      float32_t *lt_row;
 94      float32_t *a_col;
 95  
 96      float32_t invLT;
 97  
 98      f32x4_t vecA;
 99      f32x4_t vecX;
100  
101      for(i=0; i < n ; i++)
102      {
103  
104        for(j=0; j+3 < cols; j += 4)
105        {
106              vecA = vld1q_f32(&pA[i * cols + j]);
107  
108              for(k=0; k < i; k++)
109              {
110                  vecX = vld1q_f32(&pX[cols*k+j]);
111                  vecA = vfmsq(vecA,vdupq_n_f32(pLT[n*i + k]),vecX);
112              }
113  
114              if (pLT[n*i + i]==0.0f)
115              {
116                return(ARM_MATH_SINGULAR);
117              }
118  
119              invLT = 1.0f / pLT[n*i + i];
120              vecA = vmulq(vecA,vdupq_n_f32(invLT));
121              vst1q(&pX[i*cols+j],vecA);
122  
123         }
124  
125         for(; j < cols; j ++)
126         {
127              a_col = &pA[j];
128              lt_row = &pLT[n*i];
129  
130              float32_t tmp=a_col[i * cols];
131              
132              for(k=0; k < i; k++)
133              {
134                  tmp -= lt_row[k] * pX[cols*k+j];
135              }
136  
137              if (lt_row[i]==0.0f)
138              {
139                return(ARM_MATH_SINGULAR);
140              }
141              tmp = tmp / lt_row[i];
142              pX[i*cols+j] = tmp;
143          }
144  
145      }
146      status = ARM_MATH_SUCCESS;
147  
148    }
149  
150    /* Return to application */
151    return (status);
152  }
153  #else
154  #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
155    arm_status arm_mat_solve_lower_triangular_f32(
156    const arm_matrix_instance_f32 * lt,
157    const arm_matrix_instance_f32 * a,
158    arm_matrix_instance_f32 * dst)
159    {
160    arm_status status;                             /* status of matrix inverse */
161  
162  
163  #ifdef ARM_MATH_MATRIX_CHECK
164  
165    /* Check for matrix mismatch condition */
166    if ((lt->numRows != lt->numCols) ||
167        (lt->numRows != a->numRows)   )
168    {
169      /* Set status as ARM_MATH_SIZE_MISMATCH */
170      status = ARM_MATH_SIZE_MISMATCH;
171    }
172    else
173  
174  #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
175  
176    {
177      /* a1 b1 c1   x1 = a1
178            b2 c2   x2   a2
179               c3   x3   a3
180  
181      x3 = a3 / c3 
182      x2 = (a2 - c2 x3) / b2
183  
184      */
185      int i,j,k,n,cols;
186  
187      n = dst->numRows;
188      cols = dst->numCols;
189  
190      float32_t *pX = dst->pData;
191      float32_t *pLT = lt->pData;
192      float32_t *pA = a->pData;
193  
194      float32_t *lt_row;
195      float32_t *a_col;
196  
197      float32_t invLT;
198  
199      f32x4_t vecA;
200      f32x4_t vecX;
201  
202      for(i=0; i < n ; i++)
203      {
204  
205        for(j=0; j+3 < cols; j += 4)
206        {
207              vecA = vld1q_f32(&pA[i * cols + j]);
208  
209              for(k=0; k < i; k++)
210              {
211                  vecX = vld1q_f32(&pX[cols*k+j]);
212                  vecA = vfmsq_f32(vecA,vdupq_n_f32(pLT[n*i + k]),vecX);
213              }
214  
215              if (pLT[n*i + i]==0.0f)
216              {
217                return(ARM_MATH_SINGULAR);
218              }
219  
220              invLT = 1.0f / pLT[n*i + i];
221              vecA = vmulq_f32(vecA,vdupq_n_f32(invLT));
222              vst1q_f32(&pX[i*cols+j],vecA);
223  
224         }
225  
226         for(; j < cols; j ++)
227         {
228              a_col = &pA[j];
229              lt_row = &pLT[n*i];
230  
231              float32_t tmp=a_col[i * cols];
232              
233              for(k=0; k < i; k++)
234              {
235                  tmp -= lt_row[k] * pX[cols*k+j];
236              }
237  
238              if (lt_row[i]==0.0f)
239              {
240                return(ARM_MATH_SINGULAR);
241              }
242              tmp = tmp / lt_row[i];
243              pX[i*cols+j] = tmp;
244          }
245  
246      }
247      status = ARM_MATH_SUCCESS;
248  
249    }
250  
251    /* Return to application */
252    return (status);
253  }
254  #else
255    arm_status arm_mat_solve_lower_triangular_f32(
256    const arm_matrix_instance_f32 * lt,
257    const arm_matrix_instance_f32 * a,
258    arm_matrix_instance_f32 * dst)
259    {
260    arm_status status;                             /* status of matrix inverse */
261  
262  
263  #ifdef ARM_MATH_MATRIX_CHECK
264    /* Check for matrix mismatch condition */
265    if ((lt->numRows != lt->numCols) ||
266        (lt->numRows != a->numRows)   )
267    {
268      /* Set status as ARM_MATH_SIZE_MISMATCH */
269      status = ARM_MATH_SIZE_MISMATCH;
270    }
271    else
272  
273  #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
274  
275    {
276      /* a1 b1 c1   x1 = a1
277            b2 c2   x2   a2
278               c3   x3   a3
279  
280      x3 = a3 / c3 
281      x2 = (a2 - c2 x3) / b2
282  
283      */
284      int i,j,k,n,cols;
285  
286      float32_t *pX = dst->pData;
287      float32_t *pLT = lt->pData;
288      float32_t *pA = a->pData;
289  
290      float32_t *lt_row;
291      float32_t *a_col;
292  
293      n = dst->numRows;
294      cols = dst -> numCols;
295  
296  
297      for(j=0; j < cols; j ++)
298      {
299         a_col = &pA[j];
300  
301         for(i=0; i < n ; i++)
302         {
303              float32_t tmp=a_col[i * cols];
304  
305              lt_row = &pLT[n*i];
306              
307              for(k=0; k < i; k++)
308              {
309                  tmp -= lt_row[k] * pX[cols*k+j];
310              }
311  
312              if (lt_row[i]==0.0f)
313              {
314                return(ARM_MATH_SINGULAR);
315              }
316              tmp = tmp / lt_row[i];
317              pX[i*cols+j] = tmp;
318         }
319  
320      }
321      status = ARM_MATH_SUCCESS;
322  
323    }
324  
325    /* Return to application */
326    return (status);
327  }
328  #endif /* #if defined(ARM_MATH_NEON) */
329  #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
330  
331  /**
332    @} end of MatrixInv group
333   */