arma_version.hpp | arma_version.hpp | |||
---|---|---|---|---|
skipping to change at line 13 | skipping to change at line 13 | |||
// | // | |||
// This Source Code Form is subject to the terms of the Mozilla Public | // This Source Code Form is subject to the terms of the Mozilla Public | |||
// License, v. 2.0. If a copy of the MPL was not distributed with this | // License, v. 2.0. If a copy of the MPL was not distributed with this | |||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | // file, You can obtain one at http://mozilla.org/MPL/2.0/. | |||
//! \addtogroup arma_version | //! \addtogroup arma_version | |||
//! @{ | //! @{ | |||
#define ARMA_VERSION_MAJOR 4 | #define ARMA_VERSION_MAJOR 4 | |||
#define ARMA_VERSION_MINOR 550 | #define ARMA_VERSION_MINOR 550 | |||
#define ARMA_VERSION_PATCH 0 | #define ARMA_VERSION_PATCH 1 | |||
#define ARMA_VERSION_NAME "Singapore Sling Deluxe" | #define ARMA_VERSION_NAME "Singapore Sling Deluxe" | |||
struct arma_version | struct arma_version | |||
{ | { | |||
static const unsigned int major = ARMA_VERSION_MAJOR; | static const unsigned int major = ARMA_VERSION_MAJOR; | |||
static const unsigned int minor = ARMA_VERSION_MINOR; | static const unsigned int minor = ARMA_VERSION_MINOR; | |||
static const unsigned int patch = ARMA_VERSION_PATCH; | static const unsigned int patch = ARMA_VERSION_PATCH; | |||
static | static | |||
inline | inline | |||
End of changes. 1 change blocks. | ||||
1 lines changed or deleted | 1 lines changed or added | |||
gmm_diag_meat.hpp | gmm_diag_meat.hpp | |||
---|---|---|---|---|
skipping to change at line 602 | skipping to change at line 602 | |||
|| (seed_mode == random_subset) | || (seed_mode == random_subset) | |||
|| (seed_mode == random_spread); | || (seed_mode == random_spread); | |||
arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" ); | arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" ); | |||
arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown se ed_mode" ); | arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown se ed_mode" ); | |||
arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance f loor is negative" ); | arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance f loor is negative" ); | |||
const unwrap<T1> tmp_X(data.get_ref()); | const unwrap<T1> tmp_X(data.get_ref()); | |||
const Mat<eT>& X = tmp_X.M; | const Mat<eT>& X = tmp_X.M; | |||
if(X.is_empty() ) { arma_warn(true, "gmm_diag::learn(): given m | if(X.is_empty() ) { arma_warn(true, "gmm_diag::learn(): given m | |||
atrix is empty" ); reset(); return false; } | atrix is empty" ); return false; } | |||
if(X.is_finite() == false) { arma_warn(true, "gmm_diag::learn(): given m | if(X.is_finite() == false) { arma_warn(true, "gmm_diag::learn(): given m | |||
atrix has non-finite values"); reset(); return false; } | atrix has non-finite values"); return false; } | |||
if(N_gaus == 0) { reset(); return true; } | if(N_gaus == 0) { reset(); return true; } | |||
if(dist_mode == maha_dist) | if(dist_mode == maha_dist) | |||
{ | { | |||
mah_aux = var(X,1,1); | mah_aux = var(X,1,1); | |||
const uword mah_aux_n_elem = mah_aux.n_elem; | const uword mah_aux_n_elem = mah_aux.n_elem; | |||
eT* mah_aux_mem = mah_aux.memptr(); | eT* mah_aux_mem = mah_aux.memptr(); | |||
for(uword i=0; i < mah_aux_n_elem; ++i) | for(uword i=0; i < mah_aux_n_elem; ++i) | |||
{ | { | |||
const eT val = mah_aux_mem[i]; | const eT val = mah_aux_mem[i]; | |||
mah_aux_mem[i] = (val != eT(0)) ? eT(1) / val : eT(1); | mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1); | |||
} | } | |||
} | } | |||
// mah_aux.print("mah_aux:"); | // copy current model, in case of failure by k-means and/or EM | |||
const gmm_diag<eT> orig = (*this); | ||||
// initial means | // initial means | |||
if(seed_mode == keep_existing) | if(seed_mode == keep_existing) | |||
{ | { | |||
if(means.is_empty() ) { arma_warn(true, "gmm_diag::learn(): no | if(means.is_empty() ) { arma_warn(true, "gmm_diag::learn(): no | |||
existing means" ); reset(); return false; } | existing means" ); return false; } | |||
if(X.n_rows != means.n_rows) { arma_warn(true, "gmm_diag::learn(): dim | if(X.n_rows != means.n_rows) { arma_warn(true, "gmm_diag::learn(): dim | |||
ensionality mismatch"); reset(); return false; } | ensionality mismatch"); return false; } | |||
// TODO: also check for number of vectors? | // TODO: also check for number of vectors? | |||
} | } | |||
else | else | |||
{ | { | |||
if(X.n_cols < N_gaus) { arma_warn(true, "gmm_diag::learn(): number of vectors is less than number of gaussians"); reset(); return false; } | if(X.n_cols < N_gaus) { arma_warn(true, "gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; } | |||
reset(X.n_rows, N_gaus); | reset(X.n_rows, N_gaus); | |||
if(print_mode) { get_stream_err2() << "gmm_diag::learn(): generating i nitial means\n"; } | if(print_mode) { get_stream_err2() << "gmm_diag::learn(): generating i nitial means\n"; } | |||
if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mo de); } | if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mo de); } | |||
else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mo de); } | else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mo de); } | |||
} | } | |||
// k-means | // k-means | |||
skipping to change at line 658 | skipping to change at line 660 | |||
{ | { | |||
const arma_ostream_state stream_state(get_stream_err2()); | const arma_ostream_state stream_state(get_stream_err2()); | |||
bool status = false; | bool status = false; | |||
if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, p rint_mode); } | if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, p rint_mode); } | |||
else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, p rint_mode); } | else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, p rint_mode); } | |||
stream_state.restore(get_stream_err2()); | stream_state.restore(get_stream_err2()); | |||
if(status == false) { arma_warn(true, "gmm_diag::learn(): k-means algo rithm failed"); reset(); return false; } | if(status == false) { arma_warn(true, "gmm_diag::learn(): k-means algo rithm failed"); init(orig); return false; } | |||
} | } | |||
// initial dcovs | // initial dcovs | |||
const eT vfloor = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_ limits<eT>::min(); | const eT vfloor = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_ limits<eT>::min(); | |||
if(seed_mode != keep_existing) | if(seed_mode != keep_existing) | |||
{ | { | |||
if(print_mode) { get_stream_err2() << "gmm_diag::learn(): generating i nitial covariances\n"; } | if(print_mode) { get_stream_err2() << "gmm_diag::learn(): generating i nitial covariances\n"; } | |||
skipping to change at line 683 | skipping to change at line 685 | |||
// EM algorithm | // EM algorithm | |||
if(em_iter > 0) | if(em_iter > 0) | |||
{ | { | |||
const arma_ostream_state stream_state(get_stream_err2()); | const arma_ostream_state stream_state(get_stream_err2()); | |||
const bool status = em_iterate(X, em_iter, vfloor, print_mode); | const bool status = em_iterate(X, em_iter, vfloor, print_mode); | |||
stream_state.restore(get_stream_err2()); | stream_state.restore(get_stream_err2()); | |||
if(status == false) { arma_warn(true, "gmm_diag::learn(): EM algorithm failed"); reset(); return false; } | if(status == false) { arma_warn(true, "gmm_diag::learn(): EM algorithm failed"); init(orig); return false; } | |||
} | } | |||
mah_aux.reset(); | mah_aux.reset(); | |||
init_constants(); | init_constants(); | |||
return true; | return true; | |||
} | } | |||
// | // | |||
End of changes. 7 change blocks. | ||||
13 lines changed or deleted | 15 lines changed or added | |||