// Copyright Michael Monagan 2019-2020 // Compile with gcc -O3 -shared -o linalg.so -fPIC linalg.c #define LONG long long int #include #include #include "int128g.c" /******************************************************************************************/ /* Zp arithmetic */ /******************************************************************************************/ LONG add64s(LONG a, LONG b, LONG p) { LONG t; t = (a-p)+b; t += (t>>63) & p; return t; } LONG sub64s(LONG a, LONG b, LONG p) { LONG t; t = a-b; t += (t>>63) & p; return t; } LONG neg64s(LONG a, LONG p) { return (a==0) ? 0 : p-a; } LONG mul64s(LONG a, LONG b, LONG p) { LONG q, r; __asm__ __volatile__( \ " mulq %%rdx \n\t" \ " divq %4 \n\t" \ : "=a"(q), "=d"(r) : "0"(a), "1"(b), "rm"(p)); return r; } /* c^(-1) mod p assuming 0 < c < p < 2^63 */ LONG inv64s( LONG c, LONG p ) { LONG d,r,q,r1,c1,d1; d = p; c1 = 1; d1 = 0; while( d != 0 ) { q = c / d; r = c - q*d; r1 = c1 - q*d1; c = d; c1 = d1; d = r; d1 = r1; } if( c!=1 ) return( 0 ); if( c1 < 0 ) c1 += p; return( c1 ); } /******************************************************************************************/ /* Linear algebra routines */ /******************************************************************************************/ LONG * matrix64s( int n ) { LONG *A; LONG N; N = n; N = n*N; N = sizeof(LONG) * N; A = (LONG *) malloc(N); return A; } /* print an array in form [a0,a1,...,an-1] */ void vecprint64s( LONG *A, int n ) { int i; printf("["); for( i=0; i0 && k%100 == 0 ) printf("elimination at row %d\n",k); for( i=k; i=n ) { d = 0; break; } if( i!=k ) { // interchange row k with row i for( j=k; j M = [[1,1],[1,2]] int i,j,k,r,c,rank; LONG t,m64; //printf("n=%d m=%d\n",n,m); //printf("A=\n"); matprint64s( A, n, m ); recint P; P = recip1(p); for( i=0; i=n ) { c++; continue; } // move to next column if( i!=r ) { j = rows[i]; rows[i] = rows[r]; rows[r] = j; }; // row interchange cols[rank] = c; rank ++; k = rows[r]; // pivot row t = inv64s(A[k*m64+c],p); A[k*m64+c] = 1; // make the pivot row have a pivot = 1 for( j=c+1; j