| gmm_diag_meat.hpp | | gmm_diag_meat.hpp | |
| | | | |
| skipping to change at line 260 | | skipping to change at line 260 | |
| return status; | | return status; | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| inline | | inline | |
| Col<eT> | | Col<eT> | |
| gmm_diag<eT>::generate() const | | gmm_diag<eT>::generate() const | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
|
| Col<eT> out( (n_gaus > 0) ? n_dims : uword(0) ); | | Col<eT> out( (N_gaus > 0) ? N_dims : uword(0) ); | |
| | | | |
|
| if(n_gaus > 0) | | if(N_gaus > 0) | |
| { | | { | |
| const double val = randu<double>(); | | const double val = randu<double>(); | |
| | | | |
| double csum = double(0); | | double csum = double(0); | |
| uword gaus_id = 0; | | uword gaus_id = 0; | |
| | | | |
|
| for(uword j=0; j < n_gaus; ++j) | | for(uword j=0; j < N_gaus; ++j) | |
| { | | { | |
| csum += hefts[j]; | | csum += hefts[j]; | |
| | | | |
| if(val <= csum) { gaus_id = j; break; } | | if(val <= csum) { gaus_id = j; break; } | |
| } | | } | |
| | | | |
|
| out = randn< Col<eT> >(n_dims); | | out = randn< Col<eT> >(N_dims); | |
| out %= sqrt(dcovs.col(gaus_id)); | | out %= sqrt(dcovs.col(gaus_id)); | |
| out += means.col(gaus_id); | | out += means.col(gaus_id); | |
| } | | } | |
| | | | |
| return out; | | return out; | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| inline | | inline | |
| Mat<eT> | | Mat<eT> | |
|
| gmm_diag<eT>::generate(const uword N) const | | gmm_diag<eT>::generate(const uword N_vec) const | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
|
| Mat<eT> out( ( (n_gaus > 0) ? n_dims : uword(0) ), N ); | | Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec ); | |
| | | | |
|
| if(n_gaus > 0) | | if(N_gaus > 0) | |
| { | | { | |
| const eT* hefts_mem = hefts.memptr(); | | const eT* hefts_mem = hefts.memptr(); | |
| | | | |
|
| for(uword i=0; i < N; ++i) | | for(uword i=0; i < N_vec; ++i) | |
| { | | { | |
| const double val = randu<double>(); | | const double val = randu<double>(); | |
| | | | |
| double csum = double(0); | | double csum = double(0); | |
| uword gaus_id = 0; | | uword gaus_id = 0; | |
| | | | |
|
| for(uword j=0; j < n_gaus; ++j) | | for(uword j=0; j < N_gaus; ++j) | |
| { | | { | |
| csum += hefts_mem[j]; | | csum += hefts_mem[j]; | |
| | | | |
| if(val <= csum) { gaus_id = j; break; } | | if(val <= csum) { gaus_id = j; break; } | |
| } | | } | |
| | | | |
| subview_col<eT> out_col = out.col(i); | | subview_col<eT> out_col = out.col(i); | |
| | | | |
|
| out_col = randn< Col<eT> >(n_dims); | | out_col = randn< Col<eT> >(N_dims); | |
| out_col %= sqrt(dcovs.col(gaus_id)); | | out_col %= sqrt(dcovs.col(gaus_id)); | |
| out_col += means.col(gaus_id); | | out_col += means.col(gaus_id); | |
| } | | } | |
| } | | } | |
| | | | |
| return out; | | return out; | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| template<typename T1> | | template<typename T1> | |
| | | | |
| skipping to change at line 575 | | skipping to change at line 575 | |
| return out; | | return out; | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| template<typename T1> | | template<typename T1> | |
| inline | | inline | |
| bool | | bool | |
| gmm_diag<eT>::learn | | gmm_diag<eT>::learn | |
| ( | | ( | |
| const Base<eT,T1>& data, | | const Base<eT,T1>& data, | |
|
| const uword n_gaus, | | const uword N_gaus, | |
| const gmm_dist_mode& dist_mode, | | const gmm_dist_mode& dist_mode, | |
| const gmm_seed_mode& seed_mode, | | const gmm_seed_mode& seed_mode, | |
| const uword km_iter, | | const uword km_iter, | |
| const uword em_iter, | | const uword em_iter, | |
| const eT var_floor, | | const eT var_floor, | |
| const bool print_mode | | const bool print_mode | |
| ) | | ) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| | | | |
| skipping to change at line 605 | | skipping to change at line 605 | |
| arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode
must be eucl_dist or maha_dist" ); | | arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode
must be eucl_dist or maha_dist" ); | |
| arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown se
ed_mode" ); | | arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown se
ed_mode" ); | |
| arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance f
loor is negative" ); | | arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance f
loor is negative" ); | |
| | | | |
| const unwrap<T1> tmp_X(data.get_ref()); | | const unwrap<T1> tmp_X(data.get_ref()); | |
| const Mat<eT>& X = tmp_X.M; | | const Mat<eT>& X = tmp_X.M; | |
| | | | |
| if(X.is_empty() ) { arma_warn(true, "gmm_diag::learn(): given m
atrix is empty" ); reset(); return false; } | | if(X.is_empty() ) { arma_warn(true, "gmm_diag::learn(): given m
atrix is empty" ); reset(); return false; } | |
| if(X.is_finite() == false) { arma_warn(true, "gmm_diag::learn(): given m
atrix has non-finite values"); reset(); return false; } | | if(X.is_finite() == false) { arma_warn(true, "gmm_diag::learn(): given m
atrix has non-finite values"); reset(); return false; } | |
| | | | |
|
| if(n_gaus == 0) { reset(); return true; } | | if(N_gaus == 0) { reset(); return true; } | |
| | | | |
| if(dist_mode == maha_dist) | | if(dist_mode == maha_dist) | |
| { | | { | |
| mah_aux = var(X,1,1); | | mah_aux = var(X,1,1); | |
| | | | |
| const uword mah_aux_n_elem = mah_aux.n_elem; | | const uword mah_aux_n_elem = mah_aux.n_elem; | |
| eT* mah_aux_mem = mah_aux.memptr(); | | eT* mah_aux_mem = mah_aux.memptr(); | |
| | | | |
| for(uword i=0; i < mah_aux_n_elem; ++i) | | for(uword i=0; i < mah_aux_n_elem; ++i) | |
| { | | { | |
| | | | |
| skipping to change at line 635 | | skipping to change at line 635 | |
| | | | |
| if(seed_mode == keep_existing) | | if(seed_mode == keep_existing) | |
| { | | { | |
| if(means.is_empty() ) { arma_warn(true, "gmm_diag::learn(): no
existing means" ); reset(); return false; } | | if(means.is_empty() ) { arma_warn(true, "gmm_diag::learn(): no
existing means" ); reset(); return false; } | |
| if(X.n_rows != means.n_rows) { arma_warn(true, "gmm_diag::learn(): dim
ensionality mismatch"); reset(); return false; } | | if(X.n_rows != means.n_rows) { arma_warn(true, "gmm_diag::learn(): dim
ensionality mismatch"); reset(); return false; } | |
| | | | |
| // TODO: also check for number of vectors? | | // TODO: also check for number of vectors? | |
| } | | } | |
| else | | else | |
| { | | { | |
|
| if(X.n_cols < n_gaus) { arma_warn(true, "gmm_diag::learn(): number of
vectors is less than number of gaussians"); reset(); return false; } | | if(X.n_cols < N_gaus) { arma_warn(true, "gmm_diag::learn(): number of
vectors is less than number of gaussians"); reset(); return false; } | |
| | | | |
|
| reset(X.n_rows, n_gaus); | | reset(X.n_rows, N_gaus); | |
| | | | |
| if(print_mode) { get_stream_err2() << "gmm_diag::learn(): generating i
nitial means\n"; } | | if(print_mode) { get_stream_err2() << "gmm_diag::learn(): generating i
nitial means\n"; } | |
| | | | |
| if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mo
de); } | | if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mo
de); } | |
| else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mo
de); } | | else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mo
de); } | |
| } | | } | |
| | | | |
| // k-means | | // k-means | |
| | | | |
| if(km_iter > 0) | | if(km_iter > 0) | |
| | | | |
| skipping to change at line 741 | | skipping to change at line 741 | |
| init_constants(); | | init_constants(); | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| inline | | inline | |
| void | | void | |
| gmm_diag<eT>::init_constants() | | gmm_diag<eT>::init_constants() | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
|
| const eT tmp = (eT(n_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi); | | const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi); | |
| | | | |
|
| log_det_etc.set_size(n_gaus); | | log_det_etc.set_size(N_gaus); | |
| | | | |
|
| for(uword i=0; i<n_gaus; ++i) | | for(uword i=0; i<N_gaus; ++i) | |
| { | | { | |
| const eT logdet = accu( log(dcovs.col(i)) ); | | const eT logdet = accu( log(dcovs.col(i)) ); | |
| | | | |
| log_det_etc[i] = eT(-1) * ( tmp + eT(0.5) * logdet ); | | log_det_etc[i] = eT(-1) * ( tmp + eT(0.5) * logdet ); | |
| } | | } | |
| | | | |
| log_hefts = log(hefts); // TODO: possible issue when one of the hefts is
zero | | log_hefts = log(hefts); // TODO: possible issue when one of the hefts is
zero | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| arma_hot | | arma_hot | |
| inline | | inline | |
| eT | | eT | |
| gmm_diag<eT>::internal_scalar_log_p(const eT* x) const | | gmm_diag<eT>::internal_scalar_log_p(const eT* x) const | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| const eT* log_hefts_mem = log_hefts.mem; | | const eT* log_hefts_mem = log_hefts.mem; | |
| | | | |
|
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
|
| if(n_gaus > 0) | | if(N_gaus > 0) | |
| { | | { | |
| eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0]; | | eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0]; | |
| | | | |
|
| for(uword g=1; g < n_gaus; ++g) | | for(uword g=1; g < N_gaus; ++g) | |
| { | | { | |
| const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g]; | | const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g]; | |
| | | | |
| log_sum = log_add_exp(log_sum, tmp); | | log_sum = log_add_exp(log_sum, tmp); | |
| } | | } | |
| | | | |
| return log_sum; | | return log_sum; | |
| } | | } | |
| else | | else | |
| { | | { | |
| | | | |
| skipping to change at line 800 | | skipping to change at line 800 | |
| arma_hot | | arma_hot | |
| inline | | inline | |
| eT | | eT | |
| gmm_diag<eT>::internal_scalar_log_p(const eT* x, const uword g) const | | gmm_diag<eT>::internal_scalar_log_p(const eT* x, const uword g) const | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| const eT* mean = means.colptr(g); | | const eT* mean = means.colptr(g); | |
| const eT* dcov = dcovs.colptr(g); | | const eT* dcov = dcovs.colptr(g); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| | | | |
| eT val_i = eT(0); | | eT val_i = eT(0); | |
| eT val_j = eT(0); | | eT val_j = eT(0); | |
| | | | |
| uword i,j; | | uword i,j; | |
| | | | |
|
| for(i=0, j=1; j<n_dims; i+=2, j+=2) | | for(i=0, j=1; j<N_dims; i+=2, j+=2) | |
| { | | { | |
| eT tmp_i = x[i]; | | eT tmp_i = x[i]; | |
| eT tmp_j = x[j]; | | eT tmp_j = x[j]; | |
| | | | |
| tmp_i -= mean[i]; | | tmp_i -= mean[i]; | |
| tmp_j -= mean[j]; | | tmp_j -= mean[j]; | |
| | | | |
| val_i += (tmp_i*tmp_i) / dcov[i]; | | val_i += (tmp_i*tmp_i) / dcov[i]; | |
| val_j += (tmp_j*tmp_j) / dcov[j]; | | val_j += (tmp_j*tmp_j) / dcov[j]; | |
| } | | } | |
| | | | |
|
| if(i < n_dims) | | if(i < N_dims) | |
| { | | { | |
| const eT tmp = x[i] - mean[i]; | | const eT tmp = x[i] - mean[i]; | |
| | | | |
| val_i += (tmp*tmp) / dcov[i]; | | val_i += (tmp*tmp) / dcov[i]; | |
| } | | } | |
| | | | |
| return eT(-0.5)*(val_i + val_j) + log_det_etc.mem[g]; | | return eT(-0.5)*(val_i + val_j) + log_det_etc.mem[g]; | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| | | | |
| skipping to change at line 931 | | skipping to change at line 931 | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| template<typename T1> | | template<typename T1> | |
| inline | | inline | |
| uword | | uword | |
| gmm_diag<eT>::internal_scalar_assign(const T1& X, const gmm_dist_mode& dist
_mode) const | | gmm_diag<eT>::internal_scalar_assign(const T1& X, const gmm_dist_mode& dist
_mode) const | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
|
| arma_debug_check( (X.n_rows != n_dims), "gmm_diag::assign(): incompatible | | arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible | |
| dimensions" ); | | dimensions" ); | |
| arma_debug_check( (n_gaus == 0), "gmm_diag::assign(): model has no | | arma_debug_check( (N_gaus == 0), "gmm_diag::assign(): model has no | |
| means" ); | | means" ); | |
| | | | |
| const eT* X_mem = X.colptr(0); | | const eT* X_mem = X.colptr(0); | |
| | | | |
| if(dist_mode == eucl_dist) | | if(dist_mode == eucl_dist) | |
| { | | { | |
| eT best_dist = Datum<eT>::inf; | | eT best_dist = Datum<eT>::inf; | |
| uword best_g = 0; | | uword best_g = 0; | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
|
| const eT tmp_dist = distance<eT,1>::eval(n_dims, X_mem, means.colptr(
g), X_mem); | | const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(
g), X_mem); | |
| | | | |
| if(tmp_dist <= best_dist) | | if(tmp_dist <= best_dist) | |
| { | | { | |
| best_dist = tmp_dist; | | best_dist = tmp_dist; | |
| best_g = g; | | best_g = g; | |
| } | | } | |
| } | | } | |
| | | | |
| return best_g; | | return best_g; | |
| } | | } | |
| else | | else | |
| if(dist_mode == prob_dist) | | if(dist_mode == prob_dist) | |
| { | | { | |
| const eT* log_hefts_mem = log_hefts.memptr(); | | const eT* log_hefts_mem = log_hefts.memptr(); | |
| | | | |
| eT best_p = -Datum<eT>::inf; | | eT best_p = -Datum<eT>::inf; | |
| uword best_g = 0; | | uword best_g = 0; | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
| const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g]; | | const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g]; | |
| | | | |
| if(tmp_p >= best_p) | | if(tmp_p >= best_p) | |
| { | | { | |
| best_p = tmp_p; | | best_p = tmp_p; | |
| best_g = g; | | best_g = g; | |
| } | | } | |
| } | | } | |
| | | | |
| | | | |
| skipping to change at line 994 | | skipping to change at line 994 | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| template<typename T1> | | template<typename T1> | |
| inline | | inline | |
| void | | void | |
| gmm_diag<eT>::internal_vec_assign(urowvec& out, const T1& X, const gmm_dist
_mode& dist_mode) const | | gmm_diag<eT>::internal_vec_assign(urowvec& out, const T1& X, const gmm_dist
_mode& dist_mode) const | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
|
| arma_debug_check( (X.n_rows != n_dims), "gmm_diag::assign(): incompatible
dimensions" ); | | arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible
dimensions" ); | |
| | | | |
|
| const uword X_n_cols = (n_gaus > 0) ? X.n_cols : 0; | | const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0; | |
| | | | |
| out.set_size(1,X_n_cols); | | out.set_size(1,X_n_cols); | |
| | | | |
| uword* out_mem = out.memptr(); | | uword* out_mem = out.memptr(); | |
| | | | |
| if(dist_mode == eucl_dist) | | if(dist_mode == eucl_dist) | |
| { | | { | |
| for(uword i=0; i<X_n_cols; ++i) | | for(uword i=0; i<X_n_cols; ++i) | |
| { | | { | |
| const eT* X_colptr = X.colptr(i); | | const eT* X_colptr = X.colptr(i); | |
| | | | |
| eT best_dist = Datum<eT>::inf; | | eT best_dist = Datum<eT>::inf; | |
| uword best_g = 0; | | uword best_g = 0; | |
| | | | |
|
| for(uword g=0; g<n_gaus; ++g) | | for(uword g=0; g<N_gaus; ++g) | |
| { | | { | |
|
| const eT tmp_dist = distance<eT,1>::eval(n_dims, X_colptr, means.co
lptr(g), X_colptr); | | const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.co
lptr(g), X_colptr); | |
| | | | |
| if(tmp_dist <= best_dist) | | if(tmp_dist <= best_dist) | |
| { | | { | |
| best_dist = tmp_dist; | | best_dist = tmp_dist; | |
| best_g = g; | | best_g = g; | |
| } | | } | |
| } | | } | |
| | | | |
| out_mem[i] = best_g; | | out_mem[i] = best_g; | |
| } | | } | |
| | | | |
| skipping to change at line 1040 | | skipping to change at line 1040 | |
| { | | { | |
| const eT* log_hefts_mem = log_hefts.memptr(); | | const eT* log_hefts_mem = log_hefts.memptr(); | |
| | | | |
| for(uword i=0; i<X_n_cols; ++i) | | for(uword i=0; i<X_n_cols; ++i) | |
| { | | { | |
| const eT* X_colptr = X.colptr(i); | | const eT* X_colptr = X.colptr(i); | |
| | | | |
| eT best_p = -Datum<eT>::inf; | | eT best_p = -Datum<eT>::inf; | |
| uword best_g = 0; | | uword best_g = 0; | |
| | | | |
|
| for(uword g=0; g<n_gaus; ++g) | | for(uword g=0; g<N_gaus; ++g) | |
| { | | { | |
| const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem
[g]; | | const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem
[g]; | |
| | | | |
| if(tmp_p >= best_p) | | if(tmp_p >= best_p) | |
| { | | { | |
| best_p = tmp_p; | | best_p = tmp_p; | |
| best_g = g; | | best_g = g; | |
| } | | } | |
| } | | } | |
| | | | |
| | | | |
| skipping to change at line 1067 | | skipping to change at line 1067 | |
| } | | } | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| inline | | inline | |
| void | | void | |
| gmm_diag<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_
dist_mode& dist_mode) const | | gmm_diag<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_
dist_mode& dist_mode) const | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
| const uword X_n_cols = X.n_cols; | | const uword X_n_cols = X.n_cols; | |
| | | | |
|
| hist.zeros(n_gaus); | | hist.zeros(N_gaus); | |
| | | | |
|
| if(n_gaus == 0) { return; } | | if(N_gaus == 0) { return; } | |
| | | | |
| uword* hist_mem = hist.memptr(); | | uword* hist_mem = hist.memptr(); | |
| | | | |
| if(dist_mode == eucl_dist) | | if(dist_mode == eucl_dist) | |
| { | | { | |
| for(uword i=0; i<X_n_cols; ++i) | | for(uword i=0; i<X_n_cols; ++i) | |
| { | | { | |
| const eT* X_colptr = X.colptr(i); | | const eT* X_colptr = X.colptr(i); | |
| | | | |
| eT best_dist = Datum<eT>::inf; | | eT best_dist = Datum<eT>::inf; | |
| uword best_g = 0; | | uword best_g = 0; | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
|
| const eT tmp_dist = distance<eT,1>::eval(n_dims, X_colptr, means.co
lptr(g), X_colptr); | | const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.co
lptr(g), X_colptr); | |
| | | | |
| if(tmp_dist <= best_dist) | | if(tmp_dist <= best_dist) | |
| { | | { | |
| best_dist = tmp_dist; | | best_dist = tmp_dist; | |
| best_g = g; | | best_g = g; | |
| } | | } | |
| } | | } | |
| | | | |
| hist_mem[best_g]++; | | hist_mem[best_g]++; | |
| } | | } | |
| | | | |
| skipping to change at line 1113 | | skipping to change at line 1113 | |
| { | | { | |
| const eT* log_hefts_mem = log_hefts.memptr(); | | const eT* log_hefts_mem = log_hefts.memptr(); | |
| | | | |
| for(uword i=0; i<X_n_cols; ++i) | | for(uword i=0; i<X_n_cols; ++i) | |
| { | | { | |
| const eT* X_colptr = X.colptr(i); | | const eT* X_colptr = X.colptr(i); | |
| | | | |
| eT best_p = -Datum<eT>::inf; | | eT best_p = -Datum<eT>::inf; | |
| uword best_g = 0; | | uword best_g = 0; | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
| const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem
[g]; | | const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem
[g]; | |
| | | | |
| if(tmp_p >= best_p) | | if(tmp_p >= best_p) | |
| { | | { | |
| best_p = tmp_p; | | best_p = tmp_p; | |
| best_g = g; | | best_g = g; | |
| } | | } | |
| } | | } | |
| | | | |
| | | | |
| skipping to change at line 1135 | | skipping to change at line 1135 | |
| } | | } | |
| } | | } | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| template<uword dist_id> | | template<uword dist_id> | |
| inline | | inline | |
| void | | void | |
| gmm_diag<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode&
seed_mode) | | gmm_diag<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode&
seed_mode) | |
| { | | { | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
| if( (seed_mode == static_subset) || (seed_mode == random_subset) ) | | if( (seed_mode == static_subset) || (seed_mode == random_subset) ) | |
| { | | { | |
| uvec initial_indices; | | uvec initial_indices; | |
| | | | |
|
| if(seed_mode == static_subset) { initial_indices = linspace<uvec> | | if(seed_mode == static_subset) { initial_indices = linspace<uvec> | |
| (0, X.n_cols-1, n_gaus); } | | (0, X.n_cols-1, N_gaus); } | |
| else if(seed_mode == random_subset) { initial_indices = sort_index(ran | | else if(seed_mode == random_subset) { initial_indices = sort_index(ran | |
| du<vec>(X.n_cols)).rows(0,n_gaus-1); } | | du<vec>(X.n_cols)).rows(0,N_gaus-1); } | |
| | | | |
| // not using randi() here as on some primitive systems it produces vect
ors with non-unique values | | // not using randi() here as on some primitive systems it produces vect
ors with non-unique values | |
| | | | |
| // initial_indices.print("initial_indices:"); | | // initial_indices.print("initial_indices:"); | |
| | | | |
| access::rw(means) = X.cols(initial_indices); | | access::rw(means) = X.cols(initial_indices); | |
| } | | } | |
| else | | else | |
| if( (seed_mode == static_spread) || (seed_mode == random_spread) ) | | if( (seed_mode == static_spread) || (seed_mode == random_spread) ) | |
| { | | { | |
| | | | |
| skipping to change at line 1165 | | skipping to change at line 1165 | |
| | | | |
| if(seed_mode == static_spread) { start_index = X.n_cols / 2;
} | | if(seed_mode == static_spread) { start_index = X.n_cols / 2;
} | |
| else if(seed_mode == random_spread) { start_index = as_scalar(randi<uv
ec>(1, distr_param(0,X.n_cols-1))); } | | else if(seed_mode == random_spread) { start_index = as_scalar(randi<uv
ec>(1, distr_param(0,X.n_cols-1))); } | |
| | | | |
| access::rw(means).col(0) = X.unsafe_col(start_index); | | access::rw(means).col(0) = X.unsafe_col(start_index); | |
| | | | |
| const eT* mah_aux_mem = mah_aux.memptr(); | | const eT* mah_aux_mem = mah_aux.memptr(); | |
| | | | |
| running_stat<double> rs; | | running_stat<double> rs; | |
| | | | |
|
| for(uword g=1; g < n_gaus; ++g) | | for(uword g=1; g < N_gaus; ++g) | |
| { | | { | |
| eT max_dist = eT(0); | | eT max_dist = eT(0); | |
| uword best_i = uword(0); | | uword best_i = uword(0); | |
| | | | |
| for(uword i=0; i < X.n_cols; ++i) | | for(uword i=0; i < X.n_cols; ++i) | |
| { | | { | |
| rs.reset(); | | rs.reset(); | |
| | | | |
| const eT* X_colptr = X.colptr(i); | | const eT* X_colptr = X.colptr(i); | |
| | | | |
| bool ignore_i = false; | | bool ignore_i = false; | |
| | | | |
| // find the average distance between sample i and the means so far | | // find the average distance between sample i and the means so far | |
| for(uword h = 0; h < g; ++h) | | for(uword h = 0; h < g; ++h) | |
| { | | { | |
|
| const eT dist = distance<eT,dist_id>::eval(n_dims, X_colptr, mean
s.colptr(h), mah_aux_mem); | | const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, mean
s.colptr(h), mah_aux_mem); | |
| | | | |
| // ignore sample already selected as a mean | | // ignore sample already selected as a mean | |
| if(dist == eT(0)) { ignore_i = true; break; } | | if(dist == eT(0)) { ignore_i = true; break; } | |
| else { rs(dist); } | | else { rs(dist); } | |
| } | | } | |
| | | | |
| if( (rs.mean() >= max_dist) && (ignore_i == false)) | | if( (rs.mean() >= max_dist) && (ignore_i == false)) | |
| { | | { | |
| max_dist = rs.mean(); best_i = i; | | max_dist = rs.mean(); best_i = i; | |
| } | | } | |
| | | | |
| skipping to change at line 1209 | | skipping to change at line 1209 | |
| // get_stream_err2() << "generate_initial_means():" << '\n'; | | // get_stream_err2() << "generate_initial_means():" << '\n'; | |
| // means.print(); | | // means.print(); | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| template<uword dist_id> | | template<uword dist_id> | |
| inline | | inline | |
| void | | void | |
| gmm_diag<eT>::generate_initial_dcovs_and_hefts(const Mat<eT>& X, const eT v
ar_floor) | | gmm_diag<eT>::generate_initial_dcovs_and_hefts(const Mat<eT>& X, const eT v
ar_floor) | |
| { | | { | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
|
| field< running_stat_vec< Col<eT> > > rs(n_gaus); | | field< running_stat_vec< Col<eT> > > rs(N_gaus); | |
| | | | |
| const eT* mah_aux_mem = mah_aux.memptr(); | | const eT* mah_aux_mem = mah_aux.memptr(); | |
| | | | |
| for(uword i=0; i<X.n_cols; ++i) | | for(uword i=0; i<X.n_cols; ++i) | |
| { | | { | |
| const eT* X_colptr = X.colptr(i); | | const eT* X_colptr = X.colptr(i); | |
| | | | |
| double min_dist = Datum<eT>::inf; | | double min_dist = Datum<eT>::inf; | |
| uword best_g = 0; | | uword best_g = 0; | |
| | | | |
|
| for(uword g=0; g<n_gaus; ++g) | | for(uword g=0; g<N_gaus; ++g) | |
| { | | { | |
|
| const double dist = distance<eT,dist_id>::eval(n_dims, X_colptr, mean
s.colptr(g), mah_aux_mem); | | const double dist = distance<eT,dist_id>::eval(N_dims, X_colptr, mean
s.colptr(g), mah_aux_mem); | |
| | | | |
| if(dist <= min_dist) { min_dist = dist; best_g = g; } | | if(dist <= min_dist) { min_dist = dist; best_g = g; } | |
| } | | } | |
| | | | |
| rs(best_g)(X.unsafe_col(i)); | | rs(best_g)(X.unsafe_col(i)); | |
| } | | } | |
| | | | |
|
| for(uword g=0; g<n_gaus; ++g) | | for(uword g=0; g<N_gaus; ++g) | |
| { | | { | |
| if( rs(g).count() >= eT(2) ) | | if( rs(g).count() >= eT(2) ) | |
| { | | { | |
| access::rw(dcovs).col(g) = rs(g).var(1); | | access::rw(dcovs).col(g) = rs(g).var(1); | |
| } | | } | |
| else | | else | |
| { | | { | |
| access::rw(dcovs).col(g).ones(); | | access::rw(dcovs).col(g).ones(); | |
| } | | } | |
| | | | |
| access::rw(hefts)(g) = (std::max)(eT(1), rs(g).count()) / eT(X.n_cols); | | access::rw(hefts)(g) = (std::max)(eT(1), rs(g).count()) / eT(X.n_cols); | |
| } | | } | |
| | | | |
| em_fix_params(var_floor); | | em_fix_params(var_floor); | |
| } | | } | |
| | | | |
|
| | | //! multi-threaded implementation of k-means, inspired by MapReduce | |
| template<typename eT> | | template<typename eT> | |
| template<uword dist_id> | | template<uword dist_id> | |
| inline | | inline | |
| bool | | bool | |
| gmm_diag<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool
verbose) | | gmm_diag<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool
verbose) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| // get_stream_err2() << "km_iterate()" << '\n'; | | | |
| // get_stream_err2() << "dist_id: " << dist_id << '\n'; | | | |
| | | | |
| if(verbose) | | if(verbose) | |
| { | | { | |
| get_stream_err2().unsetf(ios::showbase); | | get_stream_err2().unsetf(ios::showbase); | |
| get_stream_err2().unsetf(ios::uppercase); | | get_stream_err2().unsetf(ios::uppercase); | |
| get_stream_err2().unsetf(ios::showpos); | | get_stream_err2().unsetf(ios::showpos); | |
| get_stream_err2().unsetf(ios::scientific); | | get_stream_err2().unsetf(ios::scientific); | |
| | | | |
| get_stream_err2().setf(ios::right); | | get_stream_err2().setf(ios::right); | |
| get_stream_err2().setf(ios::fixed); | | get_stream_err2().setf(ios::fixed); | |
| } | | } | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
| Mat<eT> old_means = means; | | Mat<eT> old_means = means; | |
| Mat<eT> new_means = means; | | Mat<eT> new_means = means; | |
| | | | |
| running_mean_scalar<double> rs_delta; | | running_mean_scalar<double> rs_delta; | |
| | | | |
|
| field< running_mean_vec<eT> > running_means(n_gaus); | | field< running_mean_vec<eT> > running_means(N_gaus); | |
| | | | |
| const eT* mah_aux_mem = mah_aux.memptr(); | | const eT* mah_aux_mem = mah_aux.memptr(); | |
| | | | |
| #if defined(_OPENMP) | | #if defined(_OPENMP) | |
| const arma_omp_state save_omp_state; | | const arma_omp_state save_omp_state; | |
| | | | |
| omp_set_dynamic(0); | | omp_set_dynamic(0); | |
| | | | |
| //const uword n_cores = 0; | | //const uword n_cores = 0; | |
| const uword n_cores = uword(omp_get_num_procs()); | | const uword n_cores = uword(omp_get_num_procs()); | |
| const uword n_threads = (n_cores > 0) ? ( (n_cores <= X.n_cols) ? n_cor
es : 1 ) : 1; | | const uword n_threads = (n_cores > 0) ? ( (n_cores <= X.n_cols) ? n_cor
es : 1 ) : 1; | |
| | | | |
| field< field< running_mean_vec<eT> > > t_running_means(n_threads); | | field< field< running_mean_vec<eT> > > t_running_means(n_threads); | |
| | | | |
|
| for(uword t=0; t < n_threads; ++t) { t_running_means[t].set_size(n_gau
s); } | | for(uword t=0; t < n_threads; ++t) { t_running_means[t].set_size(N_gau
s); } | |
| | | | |
| field< uvec > t_boundary(n_threads); | | field< uvec > t_boundary(n_threads); | |
| | | | |
| const uword chunk_size = X.n_cols / n_threads; | | const uword chunk_size = X.n_cols / n_threads; | |
| | | | |
|
| uword count = 0; | | uword vec_count = 0; | |
| | | | |
| for(uword t=0; t<n_threads; t++) | | for(uword t=0; t<n_threads; t++) | |
| { | | { | |
| t_boundary[t].set_size(2); | | t_boundary[t].set_size(2); | |
| | | | |
|
| t_boundary[t][0] = count; | | t_boundary[t][0] = vec_count; | |
| | | | |
|
| count += chunk_size; | | vec_count += chunk_size; | |
| | | | |
|
| t_boundary[t][1] = count-1; | | t_boundary[t][1] = vec_count-1; | |
| } | | } | |
| | | | |
| t_boundary[n_threads-1][1] = X.n_cols - 1; | | t_boundary[n_threads-1][1] = X.n_cols - 1; | |
| | | | |
|
| vec tmp_mean(n_dims); | | vec tmp_mean(N_dims); | |
| | | | |
| if(verbose) | | if(verbose) | |
| { | | { | |
| get_stream_err2() << "gmm_diag::learn(): k-means: n_threads: " << n_
threads << '\n'; | | get_stream_err2() << "gmm_diag::learn(): k-means: n_threads: " << n_
threads << '\n'; | |
| get_stream_err2() << "gmm_diag::learn(): k-means: chunk_size: " << ch
unk_size << '\n'; | | get_stream_err2() << "gmm_diag::learn(): k-means: chunk_size: " << ch
unk_size << '\n'; | |
| } | | } | |
| #endif | | #endif | |
| | | | |
| for(uword iter=1; iter <= max_iter; ++iter) | | for(uword iter=1; iter <= max_iter; ++iter) | |
| { | | { | |
| #if defined(_OPENMP) | | #if defined(_OPENMP) | |
| { | | { | |
| for(uword t=0; t < n_threads; ++t) | | for(uword t=0; t < n_threads; ++t) | |
| { | | { | |
|
| for(uword g=0; g < n_gaus; ++g) { t_running_means[t][g].reset(); } | | for(uword g=0; g < N_gaus; ++g) { t_running_means[t][g].reset(); } | |
| } | | } | |
| | | | |
|
| | | // km_update_stats() is the "map" operation, which produces partial m | |
| | | eans | |
| | | | |
| #pragma omp parallel for | | #pragma omp parallel for | |
| for(uword t=0; t < n_threads; ++t) | | for(uword t=0; t < n_threads; ++t) | |
| { | | { | |
| const uvec& boundary = t_boundary[t]; | | const uvec& boundary = t_boundary[t]; | |
| | | | |
| field< running_mean_vec<eT> >& current_running_means = t_running_me
ans[t]; | | field< running_mean_vec<eT> >& current_running_means = t_running_me
ans[t]; | |
| | | | |
| km_update_stats<dist_id>(X, boundary[0], boundary[1], old_means, cu
rrent_running_means); | | km_update_stats<dist_id>(X, boundary[0], boundary[1], old_means, cu
rrent_running_means); | |
| } | | } | |
| | | | |
|
| // the "reduce" operation, which combines the results produced by sep
erate threads; | | // the "reduce" operation, which combines the partial means produced
by the separate threads; | |
| // takes into account the counts for each mean | | // takes into account the counts for each mean | |
|
| for(uword g=0; g < n_gaus; ++g) | | | |
| | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
| uword total_count = 0; | | uword total_count = 0; | |
| | | | |
| for(uword t=0; t < n_threads; ++t) { total_count += t_running_mean
s[t][g].count(); } | | for(uword t=0; t < n_threads; ++t) { total_count += t_running_mean
s[t][g].count(); } | |
| | | | |
| tmp_mean.zeros(); | | tmp_mean.zeros(); | |
| | | | |
| bool dead = true; | | bool dead = true; | |
| uword last_index = 0; | | uword last_index = 0; | |
| | | | |
| | | | |
| skipping to change at line 1380 | | skipping to change at line 1381 | |
| } | | } | |
| } | | } | |
| | | | |
| running_means[g].reset(); | | running_means[g].reset(); | |
| | | | |
| if(dead == false) { running_means[g](tmp_mean, last_index); } | | if(dead == false) { running_means[g](tmp_mean, last_index); } | |
| } | | } | |
| } | | } | |
| #else | | #else | |
| { | | { | |
|
| for(uword g=0; g < n_gaus; ++g) { running_means[g].reset(); } | | for(uword g=0; g < N_gaus; ++g) { running_means[g].reset(); } | |
| | | | |
| km_update_stats<dist_id>(X, 0, X.n_cols-1, old_means, running_means); | | km_update_stats<dist_id>(X, 0, X.n_cols-1, old_means, running_means); | |
| } | | } | |
| #endif | | #endif | |
| | | | |
| uword n_dead_means = 0; | | uword n_dead_means = 0; | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
| if(running_means[g].count() > 0) | | if(running_means[g].count() > 0) | |
| { | | { | |
| new_means.col(g) = running_means[g].mean(); | | new_means.col(g) = running_means[g].mean(); | |
| } | | } | |
| else | | else | |
| { | | { | |
| n_dead_means++; | | n_dead_means++; | |
| } | | } | |
| } | | } | |
| | | | |
| skipping to change at line 1410 | | skipping to change at line 1411 | |
| if(n_dead_means > 0) | | if(n_dead_means > 0) | |
| { | | { | |
| if(verbose) { get_stream_err2() << "gmm_diag::learn(): k-means: reco
vering from dead means\n"; } | | if(verbose) { get_stream_err2() << "gmm_diag::learn(): k-means: reco
vering from dead means\n"; } | |
| | | | |
| if(n_dead_means == 1) | | if(n_dead_means == 1) | |
| { | | { | |
| uword dead_g = 0; | | uword dead_g = 0; | |
| uword populous_g = 0; | | uword populous_g = 0; | |
| uword populous_count = running_means(0).count(); | | uword populous_count = running_means(0).count(); | |
| | | | |
|
| for(uword g=1; g < n_gaus; ++g) | | for(uword g=1; g < N_gaus; ++g) | |
| { | | { | |
| const uword count = running_means(g).count(); | | const uword count = running_means(g).count(); | |
| | | | |
| if(count == 0) { dead_g = g; } | | if(count == 0) { dead_g = g; } | |
| | | | |
| if(populous_count < count) | | if(populous_count < count) | |
| { | | { | |
| populous_count = count; | | populous_count = count; | |
| populous_g = g; | | populous_g = g; | |
| } | | } | |
| } | | } | |
| | | | |
| if( (populous_count <= 2) || (dead_g == populous_g) ) { return fal
se; } | | if( (populous_count <= 2) || (dead_g == populous_g) ) { return fal
se; } | |
| | | | |
| new_means.col(dead_g) = X.unsafe_col( running_means(populous_g).las
t_index() ); | | new_means.col(dead_g) = X.unsafe_col( running_means(populous_g).las
t_index() ); | |
| } | | } | |
| else | | else | |
| { | | { | |
| uword dead_g = 0; | | uword dead_g = 0; | |
| | | | |
|
| for(uword live_g = 0; live_g < n_gaus; ++live_g) | | for(uword live_g = 0; live_g < N_gaus; ++live_g) | |
| { | | { | |
| if(running_means(live_g).count() >= 2) | | if(running_means(live_g).count() >= 2) | |
| { | | { | |
|
| for(; dead_g < n_gaus; ++dead_g) | | for(; dead_g < N_gaus; ++dead_g) | |
| { | | { | |
| if(running_means(dead_g).count() == 0) { break; } | | if(running_means(dead_g).count() == 0) { break; } | |
| } | | } | |
| | | | |
| new_means.col(dead_g) = X.unsafe_col( running_means(live_g).las
t_index() ); | | new_means.col(dead_g) = X.unsafe_col( running_means(live_g).las
t_index() ); | |
| | | | |
| dead_g++; | | dead_g++; | |
| } | | } | |
| } | | } | |
| } | | } | |
| } | | } | |
| | | | |
| rs_delta.reset(); | | rs_delta.reset(); | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
|
| rs_delta( distance<eT,dist_id>::eval(n_dims, old_means.colptr(g), new
_means.colptr(g), mah_aux_mem) ); | | rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new
_means.colptr(g), mah_aux_mem) ); | |
| } | | } | |
| | | | |
| if(verbose) | | if(verbose) | |
| { | | { | |
| get_stream_err2() << "gmm_diag::learn(): k-means: iteration: "; | | get_stream_err2() << "gmm_diag::learn(): k-means: iteration: "; | |
| get_stream_err2().unsetf(ios::scientific); | | get_stream_err2().unsetf(ios::scientific); | |
| get_stream_err2().setf(ios::fixed); | | get_stream_err2().setf(ios::fixed); | |
| get_stream_err2().width(std::streamsize(4)); | | get_stream_err2().width(std::streamsize(4)); | |
| get_stream_err2() << iter; | | get_stream_err2() << iter; | |
| get_stream_err2() << " delta: "; | | get_stream_err2() << " delta: "; | |
| | | | |
| skipping to change at line 1489 | | skipping to change at line 1490 | |
| template<uword dist_id> | | template<uword dist_id> | |
| inline | | inline | |
| void | | void | |
| gmm_diag<eT>::km_update_stats(const Mat<eT>& X, const uword start_index, co
nst uword end_index, const Mat<eT>& old_means, field< running_mean_vec<eT>
>& running_means) const | | gmm_diag<eT>::km_update_stats(const Mat<eT>& X, const uword start_index, co
nst uword end_index, const Mat<eT>& old_means, field< running_mean_vec<eT>
>& running_means) const | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| // get_stream_err2() << "km_update_stats(): start_index: " << start_index
<< '\n'; | | // get_stream_err2() << "km_update_stats(): start_index: " << start_index
<< '\n'; | |
| // get_stream_err2() << "km_update_stats(): end_index: " << end_index
<< '\n'; | | // get_stream_err2() << "km_update_stats(): end_index: " << end_index
<< '\n'; | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
| const eT* mah_aux_mem = mah_aux.memptr(); | | const eT* mah_aux_mem = mah_aux.memptr(); | |
| | | | |
| for(uword i=start_index; i <= end_index; ++i) | | for(uword i=start_index; i <= end_index; ++i) | |
| { | | { | |
| const eT* X_colptr = X.colptr(i); | | const eT* X_colptr = X.colptr(i); | |
| | | | |
| double best_dist = Datum<eT>::inf; | | double best_dist = Datum<eT>::inf; | |
| uword best_g = 0; | | uword best_g = 0; | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
|
| const double dist = distance<eT,dist_id>::eval(n_dims, X_colptr, old_
means.colptr(g), mah_aux_mem); | | const double dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_
means.colptr(g), mah_aux_mem); | |
| | | | |
| // get_stream_err2() << "g: " << g << " dist: " << dist << '\n'; | | // get_stream_err2() << "g: " << g << " dist: " << dist << '\n'; | |
| // old_means.col(g).print("old_means.col(g):"); | | // old_means.col(g).print("old_means.col(g):"); | |
| // vec tmp(old_means.colptr(g), old_means.n_rows); | | // vec tmp(old_means.colptr(g), old_means.n_rows); | |
| // tmp.print("tmp:"); | | // tmp.print("tmp:"); | |
| | | | |
| if(dist <= best_dist) { best_dist = dist; best_g = g; } | | if(dist <= best_dist) { best_dist = dist; best_g = g; } | |
| } | | } | |
| | | | |
| // get_stream_err2() << "best_g: " << best_g << '\n'; | | // get_stream_err2() << "best_g: " << best_g << '\n'; | |
| | | | |
| running_means[best_g]( X.unsafe_col(i), i ); | | running_means[best_g]( X.unsafe_col(i), i ); | |
| } | | } | |
| } | | } | |
| | | | |
|
| | | //! multi-threaded implementation of Expectation-Maximisation, inspired by
MapReduce | |
| template<typename eT> | | template<typename eT> | |
| inline | | inline | |
| bool | | bool | |
| gmm_diag<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT v
ar_floor, const bool verbose) | | gmm_diag<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT v
ar_floor, const bool verbose) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
| if(verbose) | | if(verbose) | |
| { | | { | |
| get_stream_err2().unsetf(ios::showbase); | | get_stream_err2().unsetf(ios::showbase); | |
| get_stream_err2().unsetf(ios::uppercase); | | get_stream_err2().unsetf(ios::uppercase); | |
| get_stream_err2().unsetf(ios::showpos); | | get_stream_err2().unsetf(ios::showpos); | |
| get_stream_err2().unsetf(ios::scientific); | | get_stream_err2().unsetf(ios::scientific); | |
| | | | |
| get_stream_err2().setf(ios::right); | | get_stream_err2().setf(ios::right); | |
| get_stream_err2().setf(ios::fixed); | | get_stream_err2().setf(ios::fixed); | |
| | | | |
| skipping to change at line 1570 | | skipping to change at line 1572 | |
| | | | |
| field< Col<eT> > t_acc_norm_lhoods(n_threads); | | field< Col<eT> > t_acc_norm_lhoods(n_threads); | |
| field< Col<eT> > t_gaus_log_lhoods(n_threads); | | field< Col<eT> > t_gaus_log_lhoods(n_threads); | |
| | | | |
| Col<eT> t_progress_log_lhood(n_threads); | | Col<eT> t_progress_log_lhood(n_threads); | |
| | | | |
| for(uword t=0; t<n_threads; t++) | | for(uword t=0; t<n_threads; t++) | |
| { | | { | |
| t_boundary[t].set_size(2); | | t_boundary[t].set_size(2); | |
| | | | |
|
| t_acc_means[t].set_size(n_dims, n_gaus); | | t_acc_means[t].set_size(N_dims, N_gaus); | |
| t_acc_dcovs[t].set_size(n_dims, n_gaus); | | t_acc_dcovs[t].set_size(N_dims, N_gaus); | |
| | | | |
|
| t_acc_norm_lhoods[t].set_size(n_gaus); | | t_acc_norm_lhoods[t].set_size(N_gaus); | |
| t_gaus_log_lhoods[t].set_size(n_gaus); | | t_gaus_log_lhoods[t].set_size(N_gaus); | |
| } | | } | |
| | | | |
| const uword chunk_size = X.n_cols / n_threads; | | const uword chunk_size = X.n_cols / n_threads; | |
| | | | |
| uword count = 0; | | uword count = 0; | |
| | | | |
| for(uword t=0; t<n_threads; t++) | | for(uword t=0; t<n_threads; t++) | |
| { | | { | |
| t_boundary[t][0] = count; | | t_boundary[t][0] = count; | |
| | | | |
| | | | |
| skipping to change at line 1661 | | skipping to change at line 1663 | |
| field< Mat<eT> >& t_acc_dcovs, | | field< Mat<eT> >& t_acc_dcovs, | |
| field< Col<eT> >& t_acc_norm_lhoods, | | field< Col<eT> >& t_acc_norm_lhoods, | |
| field< Col<eT> >& t_gaus_log_lhoods, | | field< Col<eT> >& t_gaus_log_lhoods, | |
| Col<eT>& t_progress_log_lhood | | Col<eT>& t_progress_log_lhood | |
| ) | | ) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| const uword n_threads = t_boundary.n_elem; | | const uword n_threads = t_boundary.n_elem; | |
| | | | |
|
| | | // em_generate_acc() is the "map" operation, which produces partial accum | |
| | | ulators for means, diagonal covariances and hefts | |
| | | | |
| #if defined(_OPENMP) | | #if defined(_OPENMP) | |
| { | | { | |
| #pragma omp parallel for | | #pragma omp parallel for | |
| for(uword t=0; t<n_threads; t++) | | for(uword t=0; t<n_threads; t++) | |
| { | | { | |
| const uvec& boundary = t_boundary[t]; | | const uvec& boundary = t_boundary[t]; | |
| Mat<eT>& acc_means = t_acc_means[t]; | | Mat<eT>& acc_means = t_acc_means[t]; | |
| Mat<eT>& acc_dcovs = t_acc_dcovs[t]; | | Mat<eT>& acc_dcovs = t_acc_dcovs[t]; | |
| Col<eT>& acc_norm_lhoods = t_acc_norm_lhoods[t]; | | Col<eT>& acc_norm_lhoods = t_acc_norm_lhoods[t]; | |
| Col<eT>& gaus_log_lhoods = t_gaus_log_lhoods[t]; | | Col<eT>& gaus_log_lhoods = t_gaus_log_lhoods[t]; | |
| | | | |
| skipping to change at line 1682 | | skipping to change at line 1686 | |
| | | | |
| em_generate_acc(X, boundary, acc_means, acc_dcovs, acc_norm_lhoods, g
aus_log_lhoods, progress_log_lhood); | | em_generate_acc(X, boundary, acc_means, acc_dcovs, acc_norm_lhoods, g
aus_log_lhoods, progress_log_lhood); | |
| } | | } | |
| } | | } | |
| #else | | #else | |
| { | | { | |
| em_generate_acc(X, t_boundary[0], t_acc_means[0], t_acc_dcovs[0], t_acc
_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]); | | em_generate_acc(X, t_boundary[0], t_acc_means[0], t_acc_dcovs[0], t_acc
_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]); | |
| } | | } | |
| #endif | | #endif | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
| Mat<eT>& final_acc_means = t_acc_means[0]; | | Mat<eT>& final_acc_means = t_acc_means[0]; | |
| Mat<eT>& final_acc_dcovs = t_acc_dcovs[0]; | | Mat<eT>& final_acc_dcovs = t_acc_dcovs[0]; | |
| | | | |
| Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0]; | | Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0]; | |
| | | | |
|
| | | // the "reduce" operation, which combines the partial accumulators produc | |
| | | ed by the separate threads | |
| | | | |
| for(uword t=1; t<n_threads; t++) | | for(uword t=1; t<n_threads; t++) | |
| { | | { | |
| final_acc_means += t_acc_means[t]; | | final_acc_means += t_acc_means[t]; | |
| final_acc_dcovs += t_acc_dcovs[t]; | | final_acc_dcovs += t_acc_dcovs[t]; | |
| | | | |
| final_acc_norm_lhoods += t_acc_norm_lhoods[t]; | | final_acc_norm_lhoods += t_acc_norm_lhoods[t]; | |
| } | | } | |
| | | | |
| eT* hefts_mem = access::rw(hefts).memptr(); | | eT* hefts_mem = access::rw(hefts).memptr(); | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
| eT* mean_mem = access::rw(means).colptr(g); | | eT* mean_mem = access::rw(means).colptr(g); | |
| eT* dcov_mem = access::rw(dcovs).colptr(g); | | eT* dcov_mem = access::rw(dcovs).colptr(g); | |
| | | | |
| eT* acc_mean_mem = final_acc_means.colptr(g); | | eT* acc_mean_mem = final_acc_means.colptr(g); | |
| eT* acc_dcov_mem = final_acc_dcovs.colptr(g); | | eT* acc_dcov_mem = final_acc_dcovs.colptr(g); | |
| | | | |
| const eT acc_norm_lhood = final_acc_norm_lhoods[g]; | | const eT acc_norm_lhood = final_acc_norm_lhoods[g]; | |
| | | | |
| hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); | | hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); | |
| | | | |
|
| for(uword d=0; d < n_dims; ++d) | | for(uword d=0; d < N_dims; ++d) | |
| { | | { | |
| const eT tmp = acc_mean_mem[d] / acc_norm_lhood; | | const eT tmp = acc_mean_mem[d] / acc_norm_lhood; | |
| | | | |
| mean_mem[d] = tmp; | | mean_mem[d] = tmp; | |
| dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp; | | dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp; | |
| } | | } | |
| } | | } | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| | | | |
| skipping to change at line 1748 | | skipping to change at line 1754 | |
| | | | |
| progress_log_lhood = eT(0); | | progress_log_lhood = eT(0); | |
| | | | |
| acc_means.zeros(); | | acc_means.zeros(); | |
| acc_dcovs.zeros(); | | acc_dcovs.zeros(); | |
| | | | |
| acc_norm_lhoods.zeros(); | | acc_norm_lhoods.zeros(); | |
| gaus_log_lhoods.zeros(); | | gaus_log_lhoods.zeros(); | |
| | | | |
| const uword n_dim = means.n_rows; | | const uword n_dim = means.n_rows; | |
|
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
| const eT* log_hefts_mem = log_hefts.memptr(); | | const eT* log_hefts_mem = log_hefts.memptr(); | |
| eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr(); | | eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr(); | |
| | | | |
| const uword start_index = boundary[0]; | | const uword start_index = boundary[0]; | |
| const uword end_index = boundary[1]; | | const uword end_index = boundary[1]; | |
| | | | |
| for(uword i=start_index; i <= end_index; i++) | | for(uword i=start_index; i <= end_index; i++) | |
| { | | { | |
| const eT* x = X.colptr(i); | | const eT* x = X.colptr(i); | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
| gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[
g]; | | gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[
g]; | |
| } | | } | |
| | | | |
| eT log_lhood_sum = gaus_log_lhoods_mem[0]; | | eT log_lhood_sum = gaus_log_lhoods_mem[0]; | |
| | | | |
|
| for(uword g=1; g < n_gaus; ++g) | | for(uword g=1; g < N_gaus; ++g) | |
| { | | { | |
| log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); | | log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); | |
| } | | } | |
| | | | |
| progress_log_lhood += log_lhood_sum; | | progress_log_lhood += log_lhood_sum; | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
| const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum
); | | const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum
); | |
| | | | |
| acc_norm_lhoods[g] += norm_lhood; | | acc_norm_lhoods[g] += norm_lhood; | |
| | | | |
| eT* acc_mean_mem = acc_means.colptr(g); | | eT* acc_mean_mem = acc_means.colptr(g); | |
| eT* acc_dcov_mem = acc_dcovs.colptr(g); | | eT* acc_dcov_mem = acc_dcovs.colptr(g); | |
| | | | |
| for(uword d=0; d < n_dim; ++d) | | for(uword d=0; d < n_dim; ++d) | |
| { | | { | |
| | | | |
| skipping to change at line 1804 | | skipping to change at line 1810 | |
| progress_log_lhood /= eT((end_index - start_index) + 1); | | progress_log_lhood /= eT((end_index - start_index) + 1); | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| inline | | inline | |
| void | | void | |
| gmm_diag<eT>::em_fix_params(const eT var_floor) | | gmm_diag<eT>::em_fix_params(const eT var_floor) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| const uword n_dims = means.n_rows; | | const uword N_dims = means.n_rows; | |
| const uword n_gaus = means.n_cols; | | const uword N_gaus = means.n_cols; | |
| | | | |
|
| for(uword g=0; g < n_gaus; ++g) | | for(uword g=0; g < N_gaus; ++g) | |
| { | | { | |
| eT* dcov_mem = access::rw(dcovs).colptr(g); | | eT* dcov_mem = access::rw(dcovs).colptr(g); | |
| | | | |
|
| for(uword d=0; d < n_dims; ++d) | | for(uword d=0; d < N_dims; ++d) | |
| { | | { | |
| if(dcov_mem[d] < var_floor) { dcov_mem[d] = var_floor; } | | if(dcov_mem[d] < var_floor) { dcov_mem[d] = var_floor; } | |
| } | | } | |
| } | | } | |
| | | | |
| const eT heft_sum = accu(hefts); | | const eT heft_sum = accu(hefts); | |
| | | | |
| if(heft_sum != eT(1)) { access::rw(hefts) / heft_sum; } | | if(heft_sum != eT(1)) { access::rw(hefts) / heft_sum; } | |
| } | | } | |
| | | | |
| | | | |
End of changes. 92 change blocks. |
| 110 lines changed or deleted | | 119 lines changed or added | |
|