Source excerpts from the regression.hpp listing (the numbers are the original source line numbers; gaps mark elided lines).

 19   #include <containers/KernelMatrix.hpp>

Ridge (regression.hpp:48), regularized normal equations:

 50   size_t nrhs = Y->col();
 57   xsyrk( "Lower", "No transpose", d, n,
 58       1.0, X->data(), d, 0.0, XXt.data(), d );
 60   for ( size_t i = 0; i < d; i ++ ) XXt( i, i ) += lambda;
 63   xgemm( "No transpose", "No transpose", d, n, nrhs,
          /* remaining xgemm arguments elided in this listing */
 69   xposv( "Lower", d, nrhs, XXt.data(), d, XY.data(), d );
SoftMax (regression.hpp:88), softmax gradient-descent training loop:

 99   for ( size_t it = 0; it < niter; it ++ )
101     Data<T> Gradient( n, nclass, 0.0 );
106     #pragma omp parallel for
107     for ( size_t i = 0; i < n; i ++ )
110       for ( size_t j = 0; j < nclass; j ++ ) prob_all += P( i, j );
111       for ( size_t j = 0; j < nclass; j ++ ) P( i, j ) /= prob_all;
112       P( i, (size_t)(*Y)[ i ] ) -= 1.0;
115     H.Multiply( Gradient, P );
117     #pragma omp parallel for
118     for ( size_t i = 0; i < n; i ++ )
120       for ( size_t j = 0; j < nclass; j ++ )
122         W( i, j ) += ( -1.0 * alpha / n ) * Gradient( i, j );
130     size_t n_correct = 0;
131     for ( size_t i = 0; i < n; i ++ )
133       size_t goal = (*Y)[ i ];
136       for ( size_t j = 0; j < nclass; j ++ )
138         if ( P( i, j ) > prob )
144       if ( pred == goal ) n_correct ++;
147     printf( "Accuracy: %lf\n", (double)n_correct / n );
150   ofstream fout( "weight.dat", ios::out | ios::binary );
151   fout.write( (char*)W.data(), W.size() * sizeof(T) );
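Lines 107 to 122 above are one step of softmax (cross-entropy) gradient descent: normalize each row of P into class probabilities, subtract 1 at the true class to form the residual, push the residual through the kernel matrix H, and move W against the result. The sketch below isolates the per-row probability and residual computation on a plain row-major score matrix; softmaxResidual and the max-shift that guards exp() against overflow are illustrative additions, not code from regression.hpp.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Minimal sketch: turn raw scores S (n-by-nclass, row-major) into softmax
// probabilities and the cross-entropy residual P - onehot(y), i.e. the same
// per-row update the excerpt applies to P before calling H.Multiply().
void softmaxResidual( std::vector<double>& S, const std::vector<size_t>& y,
                      size_t n, size_t nclass )
{
  for ( size_t i = 0; i < n; i ++ )
  {
    double* row = &S[ i * nclass ];
    // Shift by the row maximum so exp() cannot overflow.
    double shift = *std::max_element( row, row + nclass );
    double prob_all = 0.0;
    for ( size_t j = 0; j < nclass; j ++ )
    {
      row[ j ] = std::exp( row[ j ] - shift );
      prob_all += row[ j ];
    }
    for ( size_t j = 0; j < nclass; j ++ ) row[ j ] /= prob_all;
    // Residual of the cross-entropy loss with respect to the scores.
    row[ y[ i ] ] -= 1.0;
  }
}

The weight update then proceeds as on line 122: W( i, j ) += ( -1.0 * alpha / n ) * Gradient( i, j ).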
Solve (regression.hpp:167), gradient descent:

175   Data<T> W( n, (size_t)1.0, 0.0 );
176   Data<T> B( n, (size_t)1.0, 0.0 );
178   for ( size_t it = 0; it < niter; it ++ )
184     H.Multiply( Gradient, W );
187     for ( size_t i = 0; i < n; i ++ )
188       Gradient[ i ] += B[ i ] - (*Y)[ i ];
194     for ( size_t i = 0; i < n; i ++ )
195       Gradient[ i ] += lambda * W[ i ];
197     for ( size_t i = 0; i < n; i ++ )
198       Gradient[ i ] = ( -1.0 * alpha / n ) * Gradient[ i ];
204     H.Multiply( tmp, Gradient );
205     for ( size_t i = 0; i < n; i ++ )
216     size_t n_correct = 0;
217     for ( size_t i = 0; i < n; i ++ )
219       double pred = (int)( Z[ i ] + 0.5 );
220       double goal = (*Y)[ i ];
221       if ( pred == goal ) n_correct ++;
225     printf( "it %4lu Accuracy: %lf\n", it, (double)n_correct / n );
234   size_t n_correct = 0;
235   for ( size_t i = 0; i < n; i ++ )
237     double pred = (int)( Z[ i ] + 0.5 );
238     double goal = (*Y)[ i ];
242     if ( pred == goal ) n_correct ++;
245   printf( "Accuracy: %lf\n", (double)n_correct / n );
249   std::ofstream fout( "weight.dat", std::ios::out | std::ios::binary );
250   fout.write( (char*)W.data(), W.size() * sizeof(T) );
254   std::ofstream fout( "bias.dat", std::ios::out | std::ios::binary );
255   fout.write( (char*)B.data(), B.size() * sizeof(T) );
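Solve() builds, at every iteration, the gradient H W + B − Y + λ W and scales it by −α / n; the excerpt then multiplies the scaled gradient by H once more (line 204) before the elided update of W and B. Below is a minimal sketch of just the gradient computation on a dense row-major kernel matrix, with illustrative names (scaledGradient, H, w, b, y) that are not taken from the library.

#include <cstddef>
#include <vector>

// Minimal sketch of the per-iteration gradient in Solve():
//   g = H w + b - y + lambda * w, scaled by -alpha / n.
// H is n-by-n (row-major); w, b, y are length-n vectors.  How the scaled
// gradient is applied afterwards (the excerpt routes it through a second
// H.Multiply on line 204) is not reproduced here.
std::vector<double> scaledGradient( const std::vector<double>& H,
                                    const std::vector<double>& w,
                                    const std::vector<double>& b,
                                    const std::vector<double>& y,
                                    size_t n, double lambda, double alpha )
{
  std::vector<double> g( n, 0.0 );
  for ( size_t i = 0; i < n; i ++ )                 // g = H w  (H.Multiply)
    for ( size_t j = 0; j < n; j ++ )
      g[ i ] += H[ i * n + j ] * w[ j ];
  for ( size_t i = 0; i < n; i ++ ) g[ i ] += b[ i ] - y[ i ];   // residual
  for ( size_t i = 0; i < n; i ++ ) g[ i ] += lambda * w[ i ];   // ridge term
  for ( size_t i = 0; i < n; i ++ ) g[ i ] = ( -1.0 * alpha / n ) * g[ i ];
  return g;
}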
Referenced symbols (Doxygen cross-references):

Data< T > SoftMax(kernel_s< T > &kernel, size_t nclass, size_t niter)
Definition: regression.hpp:88
Data< T > Lasso(kernel_s< T > &kernel, size_t niter)
Definition: regression.hpp:82
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
size_t col() const noexcept
Definition: Data.hpp:281
Data< T > Ridge(kernel_s< T > &kernel, size_t niter)
Support SVD.
Definition: regression.hpp:48
size_t row() const noexcept
Definition: Data.hpp:278
Data< T > Solve(kernel_s< T > &kernel, size_t niter)
Gradient descent.
Definition: regression.hpp:167
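Of the symbols listed above, only the xgemm wrapper shows its complete signature. The following small call sketch exercises that signature on 2-by-2 column-major matrices; it assumes nothing beyond the declaration shown and would need to be linked against the library's blas_lapack.cpp.

#include <cstdio>

// Declaration copied from the cross-reference above (provided by blas_lapack.cpp).
void xgemm( const char *transA, const char *transB, int m, int n, int k,
            double alpha, const double *A, int lda, const double *B, int ldb,
            double beta, double *C, int ldc );

int main()
{
  double A[ 4 ] = { 1, 2, 3, 4 };   // column-major: [ 1 3 ; 2 4 ]
  double B[ 4 ] = { 5, 6, 7, 8 };   // column-major: [ 5 7 ; 6 8 ]
  double C[ 4 ] = { 0, 0, 0, 0 };
  // C = 1.0 * A * B + 0.0 * C
  xgemm( "No transpose", "No transpose", 2, 2, 2, 1.0, A, 2, B, 2, 0.0, C, 2 );
  printf( "%g %g\n%g %g\n", C[ 0 ], C[ 2 ], C[ 1 ], C[ 3 ] );  // 23 31 / 34 46
  return 0;
}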