| blas_wrapper.hpp | | blas_wrapper.hpp | |
|
| // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) | | // Copyright (C) 2008-2012 NICTA (www.nicta.com.au) | |
| // Copyright (C) 2008-2011 Conrad Sanderson | | // Copyright (C) 2008-2012 Conrad Sanderson | |
| // | | // | |
| // This file is part of the Armadillo C++ library. | | // This file is part of the Armadillo C++ library. | |
| // It is provided without any warranty of fitness | | // It is provided without any warranty of fitness | |
| // for any purpose. You can redistribute this file | | // for any purpose. You can redistribute this file | |
| // and/or modify it under the terms of the GNU | | // and/or modify it under the terms of the GNU | |
| // Lesser General Public License (LGPL) as published | | // Lesser General Public License (LGPL) as published | |
| // by the Free Software Foundation, either version 3 | | // by the Free Software Foundation, either version 3 | |
| // of the License or (at your option) any later version. | | // of the License or (at your option) any later version. | |
| // (see http://www.opensource.org/licenses for more info) | | // (see http://www.opensource.org/licenses for more info) | |
| | | | |
| #ifdef ARMA_USE_BLAS | | #ifdef ARMA_USE_BLAS | |
| | | | |
| //! \namespace blas namespace for BLAS functions | | //! \namespace blas namespace for BLAS functions | |
| namespace blas | | namespace blas | |
| { | | { | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| inline | | inline | |
|
| eT | | | |
| dot(const uword n_elem, const eT* x, const eT* y) | | | |
| { | | | |
| arma_ignore(n_elem); | | | |
| arma_ignore(x); | | | |
| arma_ignore(y); | | | |
| | | | |
| return eT(0); | | | |
| } | | | |
| | | | |
| template<> | | | |
| inline | | | |
| float | | | |
| dot(const uword n_elem, const float* x, const float* y) | | | |
| { | | | |
| blas_int n = blas_int(n_elem); | | | |
| blas_int inc = blas_int(1); | | | |
| | | | |
| return arma_fortran(arma_sdot)(&n, x, &inc, y, &inc); | | | |
| } | | | |
| | | | |
| template<> | | | |
| inline | | | |
| double | | | |
| dot(const uword n_elem, const double* x, const double* y) | | | |
| { | | | |
| blas_int n = blas_int(n_elem); | | | |
| blas_int inc = blas_int(1); | | | |
| | | | |
| return arma_fortran(arma_ddot)(&n, x, &inc, y, &inc); | | | |
| } | | | |
| | | | |
| template<typename eT> | | | |
| inline | | | |
| void | | void | |
| gemv(const char* transA, const blas_int* m, const blas_int* n, const eT*
alpha, const eT* A, const blas_int* ldA, const eT* x, const blas_int* incx,
const eT* beta, eT* y, const blas_int* incy) | | gemv(const char* transA, const blas_int* m, const blas_int* n, const eT*
alpha, const eT* A, const blas_int* ldA, const eT* x, const blas_int* incx,
const eT* beta, eT* y, const blas_int* incy) | |
| { | | { | |
| arma_type_check((is_supported_blas_type<eT>::value == false)); | | arma_type_check((is_supported_blas_type<eT>::value == false)); | |
| | | | |
| if(is_float<eT>::value == true) | | if(is_float<eT>::value == true) | |
| { | | { | |
| typedef float T; | | typedef float T; | |
| arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A,
ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); | | arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A,
ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); | |
| } | | } | |
| | | | |
| skipping to change at line 119 | | skipping to change at line 85 | |
| } | | } | |
| else | | else | |
| if(is_supported_complex_double<eT>::value == true) | | if(is_supported_complex_double<eT>::value == true) | |
| { | | { | |
| typedef std::complex<double> T; | | typedef std::complex<double> T; | |
| arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (c
onst T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); | | arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (c
onst T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); | |
| } | | } | |
| | | | |
| } | | } | |
| | | | |
|
| | | template<typename eT> | |
| | | inline | |
| | | eT | |
| | | dot(const uword n_elem, const eT* x, const eT* y) | |
| | | { | |
| | | arma_type_check((is_supported_blas_type<eT>::value == false)); | |
| | | | |
| | | if(is_float<eT>::value == true) | |
| | | { | |
| | | #if defined(ARMA_BLAS_SDOT_BUG) | |
| | | { | |
| | | if(n_elem == 0) { return eT(0); } | |
| | | | |
| | | const char trans = 'T'; | |
| | | | |
| | | const blas_int m = blas_int(n_elem); | |
| | | const blas_int n = 1; | |
| | | //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1 | |
| | | ); | |
| | | const blas_int inc = 1; | |
| | | | |
| | | const eT alpha = eT(1); | |
| | | const eT beta = eT(0); | |
| | | | |
| | | eT result[2]; // paranoia: using two elements instead of one | |
| | | | |
| | | //blas::gemv(&trans, &m, &n, &alpha, x, &lda, y, &inc, &beta, &resu | |
| | | lt[0], &inc); | |
| | | blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0 | |
| | | ], &inc); | |
| | | | |
| | | return result[0]; | |
| | | } | |
| | | #else | |
| | | { | |
| | | blas_int n = blas_int(n_elem); | |
| | | blas_int inc = 1; | |
| | | | |
| | | typedef float T; | |
| | | return arma_fortran(arma_sdot)(&n, (const T*)x, &inc, (const T*)y, | |
| | | &inc); | |
| | | } | |
| | | #endif | |
| | | } | |
| | | else | |
| | | if(is_double<eT>::value == true) | |
| | | { | |
| | | blas_int n = blas_int(n_elem); | |
| | | blas_int inc = 1; | |
| | | | |
| | | typedef double T; | |
| | | return arma_fortran(arma_ddot)(&n, (const T*)x, &inc, (const T*)y, &i | |
| | | nc); | |
| | | } | |
| | | else | |
| | | if( (is_supported_complex_float<eT>::value == true) || (is_supported_co | |
| | | mplex_double<eT>::value == true) ) | |
| | | { | |
| | | if(n_elem == 0) { return eT(0); } | |
| | | | |
| | | // using gemv() workaround due to compatibility issues with cdotu() a | |
| | | nd zdotu() | |
| | | | |
| | | const char trans = 'T'; | |
| | | | |
| | | const blas_int m = blas_int(n_elem); | |
| | | const blas_int n = 1; | |
| | | //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1); | |
| | | const blas_int inc = 1; | |
| | | | |
| | | const eT alpha = eT(1); | |
| | | const eT beta = eT(0); | |
| | | | |
| | | eT result[2]; // paranoia: using two elements instead of one | |
| | | | |
| | | //blas::gemv(&trans, &m, &n, &alpha, x, &lda, y, &inc, &beta, &result | |
| | | [0], &inc); | |
| | | blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], | |
| | | &inc); | |
| | | | |
| | | return result[0]; | |
| | | } | |
| | | else | |
| | | { | |
| | | return eT(0); | |
| | | } | |
| | | } | |
| } | | } | |
| | | | |
| #endif | | #endif | |
| | | | |
End of changes. 3 change blocks. |
| 36 lines changed or deleted | | 89 lines changed or added | |
|
| config.hpp | | config.hpp | |
| | | | |
| skipping to change at line 28 | | skipping to change at line 28 | |
| #endif | | #endif | |
| | | | |
| #if !defined(ARMA_USE_BLAS) | | #if !defined(ARMA_USE_BLAS) | |
| #define ARMA_USE_BLAS | | #define ARMA_USE_BLAS | |
| //// Uncomment the above line if you have BLAS or a high-speed replacement
for BLAS, | | //// Uncomment the above line if you have BLAS or a high-speed replacement
for BLAS, | |
| //// such as GotoBLAS, Intel's MKL, AMD's ACML, or the Accelerate framework
. | | //// such as GotoBLAS, Intel's MKL, AMD's ACML, or the Accelerate framework
. | |
| //// BLAS is used for matrix multiplication. | | //// BLAS is used for matrix multiplication. | |
| //// Without BLAS, matrix multiplication will still work, but might be slow
er. | | //// Without BLAS, matrix multiplication will still work, but might be slow
er. | |
| #endif | | #endif | |
| | | | |
|
| | | #define ARMA_USE_WRAPPER | |
| | | //// Comment out the above line if you prefer to directly link with LAPACK | |
| | | and/or BLAS (eg. -llapack -lblas) | |
| | | //// instead of linking indirectly with LAPACK and/or BLAS via Armadillo's | |
| | | run-time wrapper library. | |
| | | | |
| // #define ARMA_BLAS_CAPITALS | | // #define ARMA_BLAS_CAPITALS | |
| //// Uncomment the above line if your BLAS and LAPACK libraries have capita
lised function names (eg. ACML on 64-bit Windows) | | //// Uncomment the above line if your BLAS and LAPACK libraries have capita
lised function names (eg. ACML on 64-bit Windows) | |
| | | | |
| #define ARMA_BLAS_UNDERSCORE | | #define ARMA_BLAS_UNDERSCORE | |
| //// Uncomment the above line if your BLAS and LAPACK libraries have functi
on names with a trailing underscore. | | //// Uncomment the above line if your BLAS and LAPACK libraries have functi
on names with a trailing underscore. | |
| //// Conversely, comment it out if the function names don't have a trailing
underscore. | | //// Conversely, comment it out if the function names don't have a trailing
underscore. | |
| | | | |
| // #define ARMA_BLAS_LONG | | // #define ARMA_BLAS_LONG | |
| //// Uncomment the above line if your BLAS and LAPACK libraries use "long"
instead of "int" | | //// Uncomment the above line if your BLAS and LAPACK libraries use "long"
instead of "int" | |
| | | | |
| | | | |
| skipping to change at line 57 | | skipping to change at line 61 | |
| #define ARMA_USE_ATLAS | | #define ARMA_USE_ATLAS | |
| #define ARMA_ATLAS_INCLUDE_DIR /usr/include/ | | #define ARMA_ATLAS_INCLUDE_DIR /usr/include/ | |
| //// If you're using ATLAS and the compiler can't find cblas.h and/or clapa
ck.h | | //// If you're using ATLAS and the compiler can't find cblas.h and/or clapa
ck.h | |
| //// uncomment the above define and specify the appropriate include directo
ry. | | //// uncomment the above define and specify the appropriate include directo
ry. | |
| //// Make sure the directory has a trailing / | | //// Make sure the directory has a trailing / | |
| | | | |
| // #define ARMA_64BIT_WORD | | // #define ARMA_64BIT_WORD | |
| //// Uncomment the above line if you require matrices/vectors capable of ho
lding more than 4 billion elements. | | //// Uncomment the above line if you require matrices/vectors capable of ho
lding more than 4 billion elements. | |
| //// Your machine and compiler must have support for 64 bit integers (eg. v
ia "long" or "long long") | | //// Your machine and compiler must have support for 64 bit integers (eg. v
ia "long" or "long long") | |
| | | | |
|
| | | #if !defined(ARMA_USE_CXX11) | |
| // #define ARMA_USE_CXX11 | | // #define ARMA_USE_CXX11 | |
| //// Uncomment the above line if you have a C++ compiler that supports the
C++11 standard | | //// Uncomment the above line if you have a C++ compiler that supports the
C++11 standard | |
| //// This will enable additional features, such as use of initialiser lists | | //// This will enable additional features, such as use of initialiser lists | |
|
| | | #endif | |
| | | | |
| | | #if !defined(ARMA_USE_HDF5) | |
| | | /* #undef ARMA_USE_HDF5 */ | |
| | | //// Uncomment the above line if you want the ability to save and load matr | |
| | | ices stored in the HDF5 format; | |
| | | //// the hdf5.h header file must be available on your system and you will n | |
| | | eed to link with the hdf5 library (eg. -lhdf5) | |
| | | #endif | |
| | | | |
| #if !defined(ARMA_MAT_PREALLOC) | | #if !defined(ARMA_MAT_PREALLOC) | |
| #define ARMA_MAT_PREALLOC 16 | | #define ARMA_MAT_PREALLOC 16 | |
| #endif | | #endif | |
| //// This is the number of preallocated elements used by matrices and vecto
rs; | | //// This is the number of preallocated elements used by matrices and vecto
rs; | |
| //// it must be an integer that is at least 1. | | //// it must be an integer that is at least 1. | |
| //// If you mainly use lots of very small vectors (eg. <= 4 elements), | | //// If you mainly use lots of very small vectors (eg. <= 4 elements), | |
| //// change the number to the size of your vectors. | | //// change the number to the size of your vectors. | |
| | | | |
| #if !defined(ARMA_SPMAT_CHUNKSIZE) | | #if !defined(ARMA_SPMAT_CHUNKSIZE) | |
| | | | |
| skipping to change at line 87 | | skipping to change at line 99 | |
| //// Uncomment the above line if you want to disable all run-time checks. | | //// Uncomment the above line if you want to disable all run-time checks. | |
| //// This will result in faster code, but you first need to make sure that
your code runs correctly! | | //// This will result in faster code, but you first need to make sure that
your code runs correctly! | |
| //// We strongly recommend to have the run-time checks enabled during devel
opment, | | //// We strongly recommend to have the run-time checks enabled during devel
opment, | |
| //// as this greatly aids in finding mistakes in your code, and hence speed
s up development. | | //// as this greatly aids in finding mistakes in your code, and hence speed
s up development. | |
| //// We recommend that run-time checks be disabled _only_ for the shipped v
ersion of your program. | | //// We recommend that run-time checks be disabled _only_ for the shipped v
ersion of your program. | |
| | | | |
| // #define ARMA_EXTRA_DEBUG | | // #define ARMA_EXTRA_DEBUG | |
| //// Uncomment the above line if you want to see the function traces of how
Armadillo evaluates expressions. | | //// Uncomment the above line if you want to see the function traces of how
Armadillo evaluates expressions. | |
| //// This is mainly useful for debugging of the library. | | //// This is mainly useful for debugging of the library. | |
| | | | |
|
| #define ARMA_USE_BOOST | | // #define ARMA_USE_BOOST | |
| #define ARMA_USE_BOOST_DATE | | // #define ARMA_USE_BOOST_DATE | |
| #define ARMA_USE_WRAPPER | | | |
| #define ARMA_USE_HDF5 | | | |
| | | | |
| #if !defined(ARMA_DEFAULT_OSTREAM) | | #if !defined(ARMA_DEFAULT_OSTREAM) | |
| #define ARMA_DEFAULT_OSTREAM std::cout | | #define ARMA_DEFAULT_OSTREAM std::cout | |
| #endif | | #endif | |
| | | | |
| #define ARMA_PRINT_LOGIC_ERRORS | | #define ARMA_PRINT_LOGIC_ERRORS | |
| #define ARMA_PRINT_RUNTIME_ERRORS | | #define ARMA_PRINT_RUNTIME_ERRORS | |
| //#define ARMA_PRINT_HDF5_ERRORS | | //#define ARMA_PRINT_HDF5_ERRORS | |
| | | | |
| #define ARMA_HAVE_STD_ISFINITE | | #define ARMA_HAVE_STD_ISFINITE | |
| | | | |
End of changes. 4 change blocks. |
| 4 lines changed or deleted | | 18 lines changed or added | |
|
| fn_dot.hpp | | fn_dot.hpp | |
|
| // Copyright (C) 2008-2010 NICTA (www.nicta.com.au) | | // Copyright (C) 2008-2012 NICTA (www.nicta.com.au) | |
| // Copyright (C) 2008-2010 Conrad Sanderson | | // Copyright (C) 2008-2012 Conrad Sanderson | |
| | | // Copyright (C) 2012 Ryan Curtin | |
| // | | // | |
| // This file is part of the Armadillo C++ library. | | // This file is part of the Armadillo C++ library. | |
| // It is provided without any warranty of fitness | | // It is provided without any warranty of fitness | |
| // for any purpose. You can redistribute this file | | // for any purpose. You can redistribute this file | |
| // and/or modify it under the terms of the GNU | | // and/or modify it under the terms of the GNU | |
| // Lesser General Public License (LGPL) as published | | // Lesser General Public License (LGPL) as published | |
| // by the Free Software Foundation, either version 3 | | // by the Free Software Foundation, either version 3 | |
| // of the License or (at your option) any later version. | | // of the License or (at your option) any later version. | |
| // (see http://www.opensource.org/licenses for more info) | | // (see http://www.opensource.org/licenses for more info) | |
| | | | |
| //! \addtogroup fn_dot | | //! \addtogroup fn_dot | |
| //! @{ | | //! @{ | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_inline | | arma_inline | |
| arma_warn_unused | | arma_warn_unused | |
|
| typename T1::elem_type | | typename | |
| | | enable_if2 | |
| | | < | |
| | | is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typena | |
| | | me T1::elem_type, typename T2::elem_type>::value, | |
| | | typename T1::elem_type | |
| | | >::result | |
| dot | | dot | |
| ( | | ( | |
|
| const Base<typename T1::elem_type,T1>& A, | | const T1& A, | |
| const Base<typename T1::elem_type,T2>& B | | const T2& B | |
| ) | | ) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| return op_dot::apply(A,B); | | return op_dot::apply(A,B); | |
| } | | } | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_inline | | arma_inline | |
| arma_warn_unused | | arma_warn_unused | |
|
| typename T1::elem_type | | typename | |
| | | enable_if2 | |
| | | < | |
| | | is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typena | |
| | | me T1::elem_type, typename T2::elem_type>::value, | |
| | | typename T1::elem_type | |
| | | >::result | |
| norm_dot | | norm_dot | |
| ( | | ( | |
|
| const Base<typename T1::elem_type,T1>& A, | | const T1& A, | |
| const Base<typename T1::elem_type,T2>& B, | | const T2& B | |
| const typename arma_blas_type_only<typename T1::elem_type>::result* junk | | | |
| = 0 | | | |
| ) | | ) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
|
| arma_ignore(junk); | | | |
| | | | |
| return op_norm_dot::apply(A,B); | | return op_norm_dot::apply(A,B); | |
| } | | } | |
| | | | |
| // | | // | |
| // cdot | | // cdot | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_inline | | arma_inline | |
| arma_warn_unused | | arma_warn_unused | |
|
| typename T1::elem_type | | typename | |
| | | enable_if2 | |
| | | < | |
| | | is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typena | |
| | | me T1::elem_type, typename T2::elem_type>::value && is_not_complex<typename | |
| | | T1::elem_type>::value, | |
| | | typename T1::elem_type | |
| | | >::result | |
| cdot | | cdot | |
| ( | | ( | |
|
| const Base<typename T1::elem_type,T1>& A, | | const T1& A, | |
| const Base<typename T1::elem_type,T2>& B, | | const T2& B | |
| const typename arma_cx_only<typename T1::elem_type>::result* junk = 0 | | | |
| ) | | ) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
|
| arma_ignore(junk); | | | |
| | | | |
|
| return op_cdot::apply(A,B); | | return op_dot::apply(A,B); | |
| } | | } | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_inline | | arma_inline | |
| arma_warn_unused | | arma_warn_unused | |
|
| typename T1::elem_type | | typename | |
| | | enable_if2 | |
| | | < | |
| | | is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typena | |
| | | me T1::elem_type, typename T2::elem_type>::value && is_complex<typename T1: | |
| | | :elem_type>::value, | |
| | | typename T1::elem_type | |
| | | >::result | |
| cdot | | cdot | |
| ( | | ( | |
|
| const Base<typename T1::elem_type,T1>& A, | | const T1& A, | |
| const Base<typename T1::elem_type,T2>& B, | | const T2& B | |
| const typename arma_not_cx<typename T1::elem_type>::result* junk = 0 | | | |
| ) | | ) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
|
| arma_ignore(junk); | | | |
| | | | |
|
| return op_dot::apply(A,B); | | return op_cdot::apply(A,B); | |
| } | | } | |
| | | | |
| // convert dot(htrans(x), y) to cdot(x,y) | | // convert dot(htrans(x), y) to cdot(x,y) | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_inline | | arma_inline | |
| arma_warn_unused | | arma_warn_unused | |
|
| typename T1::elem_type | | typename | |
| | | enable_if2 | |
| | | < | |
| | | is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename | |
| | | T2::elem_type>::value && is_complex<typename T1::elem_type>::value, | |
| | | typename T1::elem_type | |
| | | >::result | |
| dot | | dot | |
| ( | | ( | |
| const Op<T1, op_htrans>& A, | | const Op<T1, op_htrans>& A, | |
|
| const Base<typename T1::elem_type,T2>& B, | | const T2& B | |
| const typename arma_cx_only<typename T1::elem_type>::result* junk = 0 | | | |
| ) | | ) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
|
| arma_ignore(junk); | | | |
| | | | |
| return cdot(A.m, B); | | return cdot(A.m, B); | |
| } | | } | |
| | | | |
| //! dot product of two sparse objects | | //! dot product of two sparse objects | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| inline | | inline | |
| arma_warn_unused | | arma_warn_unused | |
| typename | | typename | |
| enable_if2 | | enable_if2 | |
| | | | |
| skipping to change at line 178 | | skipping to change at line 196 | |
| arma_inline | | arma_inline | |
| arma_warn_unused | | arma_warn_unused | |
| typename | | typename | |
| enable_if2 | | enable_if2 | |
| <(is_arma_sparse_type<T1>::value) && (is_arma_type<T2>::value) && (is_sam
e_type<typename T1::elem_type, typename T2::elem_type>::value), | | <(is_arma_sparse_type<T1>::value) && (is_arma_type<T2>::value) && (is_sam
e_type<typename T1::elem_type, typename T2::elem_type>::value), | |
| typename T1::elem_type | | typename T1::elem_type | |
| >::result | | >::result | |
| dot | | dot | |
| ( | | ( | |
| const SpBase<typename T1::elem_type, T1>& x, | | const SpBase<typename T1::elem_type, T1>& x, | |
|
| const Base<typename T2::elem_type, T2>& y | | const Base<typename T2::elem_type, T2>& y | |
| ) | | ) | |
| { | | { | |
| // this is commutative | | // this is commutative | |
| return dot(y, x); | | return dot(y, x); | |
| } | | } | |
| | | | |
| //! dot product of one dense and one sparse object | | //! dot product of one dense and one sparse object | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| inline | | inline | |
| arma_warn_unused | | arma_warn_unused | |
| typename | | typename | |
| enable_if2 | | enable_if2 | |
| <(is_arma_type<T1>::value) && (is_arma_sparse_type<T2>::value) && (is_sam
e_type<typename T1::elem_type, typename T2::elem_type>::value), | | <(is_arma_type<T1>::value) && (is_arma_sparse_type<T2>::value) && (is_sam
e_type<typename T1::elem_type, typename T2::elem_type>::value), | |
| typename T1::elem_type | | typename T1::elem_type | |
| >::result | | >::result | |
| dot | | dot | |
| ( | | ( | |
|
| const Base<typename T1::elem_type, T1>& x, | | const Base<typename T1::elem_type, T1>& x, | |
| const SpBase<typename T2::elem_type, T2>& y | | const SpBase<typename T2::elem_type, T2>& y | |
| ) | | ) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| const Proxy<T1> pa(x.get_ref()); | | const Proxy<T1> pa(x.get_ref()); | |
| const SpProxy<T2> pb(y.get_ref()); | | const SpProxy<T2> pb(y.get_ref()); | |
| | | | |
| arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_ro
ws(), pb.get_n_cols(), "dot()"); | | arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_ro
ws(), pb.get_n_cols(), "dot()"); | |
| | | | |
| | | | |
End of changes. 19 change blocks. |
| 29 lines changed or deleted | | 53 lines changed or added | |
|
| op_dot_bones.hpp | | op_dot_bones.hpp | |
| | | | |
| skipping to change at line 24 | | skipping to change at line 24 | |
| //! @{ | | //! @{ | |
| | | | |
| //! \brief | | //! \brief | |
| //! dot product operation | | //! dot product operation | |
| | | | |
| class op_dot | | class op_dot | |
| { | | { | |
| public: | | public: | |
| | | | |
| template<typename eT> | | template<typename eT> | |
|
| arma_hot arma_pure arma_inline static eT direct_dot_arma(const uword n_el | | arma_hot arma_pure arma_inline static | |
| em, const eT* const A, const eT* const B); | | typename arma_not_cx<eT>::result | |
| | | direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) | |
| | | ; | |
| | | | |
| | | template<typename eT> | |
| | | arma_hot arma_pure inline static | |
| | | typename arma_cx_only<eT>::result | |
| | | direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) | |
| | | ; | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| arma_hot arma_pure inline static typename arma_float_only<eT>::result | | arma_hot arma_pure inline static typename arma_float_only<eT>::result | |
| direct_dot(const uword n_elem, const eT* const A, const eT* const B); | | direct_dot(const uword n_elem, const eT* const A, const eT* const B); | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| arma_hot arma_pure inline static typename arma_cx_only<eT>::result | | arma_hot arma_pure inline static typename arma_cx_only<eT>::result | |
| direct_dot(const uword n_elem, const eT* const A, const eT* const B); | | direct_dot(const uword n_elem, const eT* const A, const eT* const B); | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| arma_hot arma_pure inline static typename arma_integral_only<eT>::result | | arma_hot arma_pure inline static typename arma_integral_only<eT>::result | |
| direct_dot(const uword n_elem, const eT* const A, const eT* const B); | | direct_dot(const uword n_elem, const eT* const A, const eT* const B); | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| arma_hot arma_pure inline static eT direct_dot(const uword n_elem, const
eT* const A, const eT* const B, const eT* C); | | arma_hot arma_pure inline static eT direct_dot(const uword n_elem, const
eT* const A, const eT* const B, const eT* C); | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
|
| arma_hot arma_inline static typename T1::elem_type apply(const Base<typen | | arma_hot inline static typename T1::elem_type apply(const T1& X, const T2 | |
| ame T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y); | | & Y); | |
| | | | |
| | | template<typename T1, typename T2> | |
| | | arma_hot inline static typename T1::elem_type apply_unwrap(const T1& X, c | |
| | | onst T2& Y); | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
|
| arma_hot inline static typename T1::elem_type apply_unwrap(const Base<typ
ename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y); | | arma_hot inline static typename arma_not_cx<typename T1::elem_type>::resu
lt apply_proxy(const T1& X, const T2& Y); | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
|
| arma_hot inline static typename T1::elem_type apply_proxy (const Base<typ
ename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y); | | arma_hot inline static typename arma_cx_only<typename T1::elem_type>::res
ult apply_proxy(const T1& X, const T2& Y); | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| arma_hot inline static eT dot_and_copy_row(eT* out, const Mat<eT>& A, con
st uword row, const eT* B_mem, const uword N); | | arma_hot inline static eT dot_and_copy_row(eT* out, const Mat<eT>& A, con
st uword row, const eT* B_mem, const uword N); | |
| }; | | }; | |
| | | | |
| //! \brief | | //! \brief | |
| //! normalised dot product operation | | //! normalised dot product operation | |
| | | | |
| class op_norm_dot | | class op_norm_dot | |
| { | | { | |
| public: | | public: | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
|
| arma_hot inline static typename T1::elem_type apply (const Base<typ
ename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y); | | arma_hot inline static typename T1::elem_type apply (const T1& X, c
onst T2& Y); | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
|
| arma_hot inline static typename T1::elem_type apply_unwrap(const Base<typ
ename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y); | | arma_hot inline static typename T1::elem_type apply_unwrap(const T1& X, c
onst T2& Y); | |
| }; | | }; | |
| | | | |
|
| | | //! \brief | |
| | | //! complex conjugate dot product operation | |
| | | | |
| class op_cdot | | class op_cdot | |
| { | | { | |
| public: | | public: | |
| | | | |
|
| | | template<typename eT> | |
| | | arma_hot inline static eT direct_cdot_arma(const uword n_elem, const eT* | |
| | | const A, const eT* const B); | |
| | | | |
| | | template<typename eT> | |
| | | arma_hot inline static eT direct_cdot(const uword n_elem, const eT* const | |
| | | A, const eT* const B); | |
| | | | |
| | | template<typename T1, typename T2> | |
| | | arma_hot inline static typename T1::elem_type apply (const T1& X, c | |
| | | onst T2& Y); | |
| | | | |
| | | template<typename T1, typename T2> | |
| | | arma_hot inline static typename T1::elem_type apply_unwrap(const T1& X, c | |
| | | onst T2& Y); | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
|
| arma_hot arma_inline static typename T1::elem_type apply(const Base<typen
ame T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y); | | arma_hot inline static typename T1::elem_type apply_proxy (const T1& X, c
onst T2& Y); | |
| }; | | }; | |
| | | | |
| //! @} | | //! @} | |
| | | | |
End of changes. 9 change blocks. |
| 9 lines changed or deleted | | 40 lines changed or added | |
|
| op_dot_meat.hpp | | op_dot_meat.hpp | |
| | | | |
| skipping to change at line 16 | | skipping to change at line 16 | |
| // for any purpose. You can redistribute this file | | // for any purpose. You can redistribute this file | |
| // and/or modify it under the terms of the GNU | | // and/or modify it under the terms of the GNU | |
| // Lesser General Public License (LGPL) as published | | // Lesser General Public License (LGPL) as published | |
| // by the Free Software Foundation, either version 3 | | // by the Free Software Foundation, either version 3 | |
| // of the License or (at your option) any later version. | | // of the License or (at your option) any later version. | |
| // (see http://www.opensource.org/licenses for more info) | | // (see http://www.opensource.org/licenses for more info) | |
| | | | |
| //! \addtogroup op_dot | | //! \addtogroup op_dot | |
| //! @{ | | //! @{ | |
| | | | |
|
| //! for two arrays, generic version | | //! for two arrays, generic version for non-complex values | |
| template<typename eT> | | template<typename eT> | |
| arma_hot | | arma_hot | |
| arma_pure | | arma_pure | |
| arma_inline | | arma_inline | |
|
| eT | | typename arma_not_cx<eT>::result | |
| op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* co
nst B) | | op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* co
nst B) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| eT val1 = eT(0); | | eT val1 = eT(0); | |
| eT val2 = eT(0); | | eT val2 = eT(0); | |
| | | | |
| uword i, j; | | uword i, j; | |
| | | | |
| for(i=0, j=1; j<n_elem; i+=2, j+=2) | | for(i=0, j=1; j<n_elem; i+=2, j+=2) | |
| | | | |
| skipping to change at line 45 | | skipping to change at line 45 | |
| } | | } | |
| | | | |
| if(i < n_elem) | | if(i < n_elem) | |
| { | | { | |
| val1 += A[i] * B[i]; | | val1 += A[i] * B[i]; | |
| } | | } | |
| | | | |
| return val1 + val2; | | return val1 + val2; | |
| } | | } | |
| | | | |
|
| | | //! for two arrays, generic version for complex values | |
| | | template<typename eT> | |
| | | arma_hot | |
| | | arma_pure | |
| | | inline | |
| | | typename arma_cx_only<eT>::result | |
| | | op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* co | |
| | | nst B) | |
| | | { | |
| | | arma_extra_debug_sigprint(); | |
| | | | |
| | | typedef typename get_pod_type<eT>::result T; | |
| | | | |
| | | T val_real = T(0); | |
| | | T val_imag = T(0); | |
| | | | |
| | | for(uword i=0; i<n_elem; ++i) | |
| | | { | |
| | | const std::complex<T>& X = A[i]; | |
| | | const std::complex<T>& Y = B[i]; | |
| | | | |
| | | const T a = X.real(); | |
| | | const T b = X.imag(); | |
| | | | |
| | | const T c = Y.real(); | |
| | | const T d = Y.imag(); | |
| | | | |
| | | val_real += (a*c) - (b*d); | |
| | | val_imag += (a*d) + (b*c); | |
| | | } | |
| | | | |
| | | return std::complex<T>(val_real, val_imag); | |
| | | } | |
| | | | |
| //! for two arrays, float and double version | | //! for two arrays, float and double version | |
| template<typename eT> | | template<typename eT> | |
| arma_hot | | arma_hot | |
| arma_pure | | arma_pure | |
| inline | | inline | |
| typename arma_float_only<eT>::result | | typename arma_float_only<eT>::result | |
| op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B
) | | op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B
) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| if( n_elem <= (128/sizeof(eT)) ) | | if( n_elem <= 32u ) | |
| { | | { | |
| return op_dot::direct_dot_arma(n_elem, A, B); | | return op_dot::direct_dot_arma(n_elem, A, B); | |
| } | | } | |
| else | | else | |
| { | | { | |
| #if defined(ARMA_USE_ATLAS) | | #if defined(ARMA_USE_ATLAS) | |
| { | | { | |
| arma_extra_debug_print("atlas::cblas_dot()"); | | arma_extra_debug_print("atlas::cblas_dot()"); | |
| | | | |
| return atlas::cblas_dot(n_elem, A, B); | | return atlas::cblas_dot(n_elem, A, B); | |
| | | | |
| skipping to change at line 89 | | skipping to change at line 122 | |
| } | | } | |
| | | | |
| //! for two arrays, complex version | | //! for two arrays, complex version | |
| template<typename eT> | | template<typename eT> | |
| inline | | inline | |
| arma_hot | | arma_hot | |
| arma_pure | | arma_pure | |
| typename arma_cx_only<eT>::result | | typename arma_cx_only<eT>::result | |
| op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B
) | | op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B
) | |
| { | | { | |
|
| #if defined(ARMA_USE_ATLAS) | | if( n_elem <= 16u ) | |
| { | | | |
| arma_extra_debug_print("atlas::cx_cblas_dot()"); | | | |
| | | | |
| return atlas::cx_cblas_dot(n_elem, A, B); | | | |
| } | | | |
| #elif defined(ARMA_USE_BLAS) | | | |
| { | | { | |
|
| // TODO: work out the mess with zdotu() and zdotu_sub() in BLAS | | | |
| return op_dot::direct_dot_arma(n_elem, A, B); | | return op_dot::direct_dot_arma(n_elem, A, B); | |
| } | | } | |
|
| #else | | else | |
| { | | { | |
|
| return op_dot::direct_dot_arma(n_elem, A, B); | | #if defined(ARMA_USE_ATLAS) | |
| | | { | |
| | | arma_extra_debug_print("atlas::cx_cblas_dot()"); | |
| | | | |
| | | return atlas::cx_cblas_dot(n_elem, A, B); | |
| | | } | |
| | | #elif defined(ARMA_USE_BLAS) | |
| | | { | |
| | | arma_extra_debug_print("blas::dot()"); | |
| | | | |
| | | return blas::dot(n_elem, A, B); | |
| | | } | |
| | | #else | |
| | | { | |
| | | return op_dot::direct_dot_arma(n_elem, A, B); | |
| | | } | |
| | | #endif | |
| } | | } | |
|
| #endif | | | |
| } | | } | |
| | | | |
| //! for two arrays, integral version | | //! for two arrays, integral version | |
| template<typename eT> | | template<typename eT> | |
| arma_hot | | arma_hot | |
| arma_pure | | arma_pure | |
| inline | | inline | |
| typename arma_integral_only<eT>::result | | typename arma_integral_only<eT>::result | |
| op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B
) | | op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B
) | |
| { | | { | |
| | | | |
| skipping to change at line 140 | | skipping to change at line 181 | |
| for(uword i=0; i<n_elem; ++i) | | for(uword i=0; i<n_elem; ++i) | |
| { | | { | |
| val += A[i] * B[i] * C[i]; | | val += A[i] * B[i] * C[i]; | |
| } | | } | |
| | | | |
| return val; | | return val; | |
| } | | } | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_hot | | arma_hot | |
|
| arma_inline | | inline | |
| typename T1::elem_type | | typename T1::elem_type | |
|
| op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename
T1::elem_type,T2>& Y) | | op_dot::apply(const T1& X, const T2& Y) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) ) | | if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) ) | |
| { | | { | |
| return op_dot::apply_unwrap(X,Y); | | return op_dot::apply_unwrap(X,Y); | |
| } | | } | |
| else | | else | |
| { | | { | |
| return op_dot::apply_proxy(X,Y); | | return op_dot::apply_proxy(X,Y); | |
| } | | } | |
| } | | } | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_hot | | arma_hot | |
|
| arma_inline | | inline | |
| typename T1::elem_type | | typename T1::elem_type | |
|
| op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<t
ypename T1::elem_type,T2>& Y) | | op_dot::apply_unwrap(const T1& X, const T2& Y) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| typedef typename T1::elem_type eT; | | typedef typename T1::elem_type eT; | |
| | | | |
|
| const unwrap<T1> tmp1(X.get_ref()); | | const unwrap<T1> tmp1(X); | |
| const unwrap<T2> tmp2(Y.get_ref()); | | const unwrap<T2> tmp2(Y); | |
| | | | |
| const Mat<eT>& A = tmp1.M; | | const Mat<eT>& A = tmp1.M; | |
| const Mat<eT>& B = tmp2.M; | | const Mat<eT>& B = tmp2.M; | |
| | | | |
| arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the s
ame number of elements" ); | | arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the s
ame number of elements" ); | |
| | | | |
| return op_dot::direct_dot(A.n_elem, A.mem, B.mem); | | return op_dot::direct_dot(A.n_elem, A.mem, B.mem); | |
| } | | } | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_hot | | arma_hot | |
| inline | | inline | |
|
| typename T1::elem_type | | typename arma_not_cx<typename T1::elem_type>::result | |
| op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<ty | | op_dot::apply_proxy(const T1& X, const T2& Y) | |
| pename T1::elem_type,T2>& Y) | | | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| typedef typename T1::elem_type eT; | | typedef typename T1::elem_type eT; | |
| typedef typename Proxy<T1>::ea_type ea_type1; | | typedef typename Proxy<T1>::ea_type ea_type1; | |
| typedef typename Proxy<T2>::ea_type ea_type2; | | typedef typename Proxy<T2>::ea_type ea_type2; | |
| | | | |
|
| const Proxy<T1> A(X.get_ref()); | | | |
| const Proxy<T2> B(Y.get_ref()); | | | |
| | | | |
| const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy
<T2>::prefer_at_accessor); | | const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy
<T2>::prefer_at_accessor); | |
| | | | |
| if(prefer_at_accessor == false) | | if(prefer_at_accessor == false) | |
| { | | { | |
|
| arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "dot(): objects m | | const Proxy<T1> PA(X); | |
| ust have the same number of elements" ); | | const Proxy<T2> PB(Y); | |
| | | | |
| | | const uword N = PA.get_n_elem(); | |
| | | | |
| | | arma_debug_check( (N != PB.get_n_elem()), "dot(): objects must have the | |
| | | same number of elements" ); | |
| | | | |
|
| const uword N = A.get_n_elem(); | | ea_type1 A = PA.get_ea(); | |
| ea_type1 PA = A.get_ea(); | | ea_type2 B = PB.get_ea(); | |
| ea_type2 PB = B.get_ea(); | | | |
| | | | |
| eT val1 = eT(0); | | eT val1 = eT(0); | |
| eT val2 = eT(0); | | eT val2 = eT(0); | |
| | | | |
| uword i,j; | | uword i,j; | |
| | | | |
| for(i=0, j=1; j<N; i+=2, j+=2) | | for(i=0, j=1; j<N; i+=2, j+=2) | |
| { | | { | |
|
| val1 += PA[i] * PB[i]; | | val1 += A[i] * B[i]; | |
| val2 += PA[j] * PB[j]; | | val2 += A[j] * B[j]; | |
| } | | } | |
| | | | |
| if(i < N) | | if(i < N) | |
| { | | { | |
|
| val1 += PA[i] * PB[i]; | | val1 += A[i] * B[i]; | |
| } | | } | |
| | | | |
| return val1 + val2; | | return val1 + val2; | |
| } | | } | |
| else | | else | |
| { | | { | |
|
| return op_dot::apply_unwrap(A.Q, B.Q); | | return op_dot::apply_unwrap(X,Y); | |
| | | } | |
| | | } | |
| | | | |
| | | template<typename T1, typename T2> | |
| | | arma_hot | |
| | | inline | |
| | | typename arma_cx_only<typename T1::elem_type>::result | |
| | | op_dot::apply_proxy(const T1& X, const T2& Y) | |
| | | { | |
| | | arma_extra_debug_sigprint(); | |
| | | | |
| | | typedef typename T1::elem_type eT; | |
| | | typedef typename get_pod_type<eT>::result T; | |
| | | | |
| | | typedef typename Proxy<T1>::ea_type ea_type1; | |
| | | typedef typename Proxy<T2>::ea_type ea_type2; | |
| | | | |
| | | const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy | |
| | | <T2>::prefer_at_accessor); | |
| | | | |
| | | if(prefer_at_accessor == false) | |
| | | { | |
| | | const Proxy<T1> PA(X); | |
| | | const Proxy<T2> PB(Y); | |
| | | | |
| | | const uword N = PA.get_n_elem(); | |
| | | | |
| | | arma_debug_check( (N != PB.get_n_elem()), "dot(): objects must have the | |
| | | same number of elements" ); | |
| | | | |
| | | ea_type1 A = PA.get_ea(); | |
| | | ea_type2 B = PB.get_ea(); | |
| | | | |
| | | T val_real = T(0); | |
| | | T val_imag = T(0); | |
| | | | |
| | | for(uword i=0; i<N; ++i) | |
| | | { | |
| | | const std::complex<T> X = A[i]; | |
| | | const std::complex<T> Y = B[i]; | |
| | | | |
| | | const T a = X.real(); | |
| | | const T b = X.imag(); | |
| | | | |
| | | const T c = Y.real(); | |
| | | const T d = Y.imag(); | |
| | | | |
| | | val_real += (a*c) - (b*d); | |
| | | val_imag += (a*d) + (b*c); | |
| | | } | |
| | | | |
| | | return std::complex<T>(val_real, val_imag); | |
| | | } | |
| | | else | |
| | | { | |
| | | return op_dot::apply_unwrap(X,Y); | |
| } | | } | |
| } | | } | |
| | | | |
| template<typename eT> | | template<typename eT> | |
| arma_hot | | arma_hot | |
| inline | | inline | |
| eT | | eT | |
| op_dot::dot_and_copy_row(eT* out, const Mat<eT>& A, const uword row, const
eT* B_mem, const uword N) | | op_dot::dot_and_copy_row(eT* out, const Mat<eT>& A, const uword row, const
eT* B_mem, const uword N) | |
| { | | { | |
| eT acc1 = eT(0); | | eT acc1 = eT(0); | |
| | | | |
| skipping to change at line 267 | | skipping to change at line 363 | |
| return acc1 + acc2; | | return acc1 + acc2; | |
| } | | } | |
| | | | |
| // | | // | |
| // op_norm_dot | | // op_norm_dot | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_hot | | arma_hot | |
| inline | | inline | |
| typename T1::elem_type | | typename T1::elem_type | |
|
| op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typ
ename T1::elem_type,T2>& Y) | | op_norm_dot::apply(const T1& X, const T2& Y) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| typedef typename T1::elem_type eT; | | typedef typename T1::elem_type eT; | |
| typedef typename Proxy<T1>::ea_type ea_type1; | | typedef typename Proxy<T1>::ea_type ea_type1; | |
| typedef typename Proxy<T2>::ea_type ea_type2; | | typedef typename Proxy<T2>::ea_type ea_type2; | |
| | | | |
| const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy
<T2>::prefer_at_accessor); | | const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy
<T2>::prefer_at_accessor); | |
| | | | |
| if(prefer_at_accessor == false) | | if(prefer_at_accessor == false) | |
| { | | { | |
|
| const Proxy<T1> A(X.get_ref()); | | const Proxy<T1> PA(X); | |
| const Proxy<T2> B(Y.get_ref()); | | const Proxy<T2> PB(Y); | |
| | | | |
| | | const uword N = PA.get_n_elem(); | |
| | | | |
|
| arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "norm_dot(): obje
cts must have the same number of elements" ); | | arma_debug_check( (N != PB.get_n_elem()), "norm_dot(): objects must hav
e the same number of elements" ); | |
| | | | |
|
| const uword N = A.get_n_elem(); | | ea_type1 A = PA.get_ea(); | |
| ea_type1 PA = A.get_ea(); | | ea_type2 B = PB.get_ea(); | |
| ea_type2 PB = B.get_ea(); | | | |
| | | | |
| eT acc1 = eT(0); | | eT acc1 = eT(0); | |
| eT acc2 = eT(0); | | eT acc2 = eT(0); | |
| eT acc3 = eT(0); | | eT acc3 = eT(0); | |
| | | | |
| for(uword i=0; i<N; ++i) | | for(uword i=0; i<N; ++i) | |
| { | | { | |
|
| const eT tmpA = PA[i]; | | const eT tmpA = A[i]; | |
| const eT tmpB = PB[i]; | | const eT tmpB = B[i]; | |
| | | | |
| acc1 += tmpA * tmpA; | | acc1 += tmpA * tmpA; | |
| acc2 += tmpB * tmpB; | | acc2 += tmpB * tmpB; | |
| acc3 += tmpA * tmpB; | | acc3 += tmpA * tmpB; | |
| } | | } | |
| | | | |
| return acc3 / ( std::sqrt(acc1 * acc2) ); | | return acc3 / ( std::sqrt(acc1 * acc2) ); | |
| } | | } | |
| else | | else | |
| { | | { | |
| return op_norm_dot::apply_unwrap(X, Y); | | return op_norm_dot::apply_unwrap(X, Y); | |
| } | | } | |
| } | | } | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_hot | | arma_hot | |
| inline | | inline | |
| typename T1::elem_type | | typename T1::elem_type | |
|
| op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const B
ase<typename T1::elem_type,T2>& Y) | | op_norm_dot::apply_unwrap(const T1& X, const T2& Y) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
| typedef typename T1::elem_type eT; | | typedef typename T1::elem_type eT; | |
| | | | |
|
| const unwrap<T1> tmp1(X.get_ref()); | | const unwrap<T1> tmp1(X); | |
| const unwrap<T2> tmp2(Y.get_ref()); | | const unwrap<T2> tmp2(Y); | |
| | | | |
| const Mat<eT>& A = tmp1.M; | | const Mat<eT>& A = tmp1.M; | |
| const Mat<eT>& B = tmp2.M; | | const Mat<eT>& B = tmp2.M; | |
| | | | |
| arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have
the same number of elements" ); | | arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have
the same number of elements" ); | |
| | | | |
| const uword N = A.n_elem; | | const uword N = A.n_elem; | |
| | | | |
| const eT* A_mem = A.memptr(); | | const eT* A_mem = A.memptr(); | |
| const eT* B_mem = B.memptr(); | | const eT* B_mem = B.memptr(); | |
| | | | |
| skipping to change at line 353 | | skipping to change at line 450 | |
| acc2 += tmpB * tmpB; | | acc2 += tmpB * tmpB; | |
| acc3 += tmpA * tmpB; | | acc3 += tmpA * tmpB; | |
| } | | } | |
| | | | |
| return acc3 / ( std::sqrt(acc1 * acc2) ); | | return acc3 / ( std::sqrt(acc1 * acc2) ); | |
| } | | } | |
| | | | |
| // | | // | |
| // op_cdot | | // op_cdot | |
| | | | |
|
| | | template<typename eT> | |
| | | arma_hot | |
| | | arma_pure | |
| | | inline | |
| | | eT | |
| | | op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* | |
| | | const B) | |
| | | { | |
| | | arma_extra_debug_sigprint(); | |
| | | | |
| | | typedef typename get_pod_type<eT>::result T; | |
| | | | |
| | | T val_real = T(0); | |
| | | T val_imag = T(0); | |
| | | | |
| | | for(uword i=0; i<n_elem; ++i) | |
| | | { | |
| | | const std::complex<T>& X = A[i]; | |
| | | const std::complex<T>& Y = B[i]; | |
| | | | |
| | | const T a = X.real(); | |
| | | const T b = X.imag(); | |
| | | | |
| | | const T c = Y.real(); | |
| | | const T d = Y.imag(); | |
| | | | |
| | | val_real += (a*c) + (b*d); | |
| | | val_imag += (a*d) - (b*c); | |
| | | } | |
| | | | |
| | | return std::complex<T>(val_real, val_imag); | |
| | | } | |
| | | | |
| | | template<typename eT> | |
| | | arma_hot | |
| | | arma_pure | |
| | | inline | |
| | | eT | |
| | | op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const | |
| | | B) | |
| | | { | |
| | | arma_extra_debug_sigprint(); | |
| | | | |
| | | if( n_elem <= 32u ) | |
| | | { | |
| | | return op_cdot::direct_cdot_arma(n_elem, A, B); | |
| | | } | |
| | | else | |
| | | { | |
| | | #if defined(ARMA_USE_BLAS) | |
| | | { | |
| | | arma_extra_debug_print("blas::gemv()"); | |
| | | | |
| | | // using gemv() workaround due to compatibility issues with cdotc() a | |
| | | nd zdotc() | |
| | | | |
| | | const char trans = 'C'; | |
| | | | |
| | | const blas_int m = blas_int(n_elem); | |
| | | const blas_int n = 1; | |
| | | //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1); | |
| | | const blas_int inc = 1; | |
| | | | |
| | | const eT alpha = eT(1); | |
| | | const eT beta = eT(0); | |
| | | | |
| | | eT result[2]; // paranoia: using two elements instead of one | |
| | | | |
| | | //blas::gemv(&trans, &m, &n, &alpha, A, &lda, B, &inc, &beta, &result | |
| | | [0], &inc); | |
| | | blas::gemv(&trans, &m, &n, &alpha, A, &m, B, &inc, &beta, &result[0], | |
| | | &inc); | |
| | | | |
| | | return result[0]; | |
| | | } | |
| | | #elif defined(ARMA_USE_ATLAS) | |
| | | { | |
| | | // TODO: use dedicated atlas functions cblas_cdotc_sub() and cblas_zd | |
| | | otc_sub() and retune threshold | |
| | | | |
| | | return op_cdot::direct_cdot_arma(n_elem, A, B); | |
| | | } | |
| | | #else | |
| | | { | |
| | | return op_cdot::direct_cdot_arma(n_elem, A, B); | |
| | | } | |
| | | #endif | |
| | | } | |
| | | } | |
| | | | |
| template<typename T1, typename T2> | | template<typename T1, typename T2> | |
| arma_hot | | arma_hot | |
|
| arma_inline | | inline | |
| typename T1::elem_type | | typename T1::elem_type | |
|
| op_cdot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typenam
e T1::elem_type,T2>& Y) | | op_cdot::apply(const T1& X, const T2& Y) | |
| { | | { | |
| arma_extra_debug_sigprint(); | | arma_extra_debug_sigprint(); | |
| | | | |
|
| typedef typename T1::elem_type eT; | | if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) ) | |
| | | { | |
| | | return op_cdot::apply_unwrap(X,Y); | |
| | | } | |
| | | else | |
| | | { | |
| | | return op_cdot::apply_proxy(X,Y); | |
| | | } | |
| | | } | |
| | | | |
| | | template<typename T1, typename T2> | |
| | | arma_hot | |
| | | inline | |
| | | typename T1::elem_type | |
| | | op_cdot::apply_unwrap(const T1& X, const T2& Y) | |
| | | { | |
| | | arma_extra_debug_sigprint(); | |
| | | | |
| | | typedef typename T1::elem_type eT; | |
| | | | |
| | | const unwrap<T1> tmp1(X); | |
| | | const unwrap<T2> tmp2(Y); | |
| | | | |
| | | const Mat<eT>& A = tmp1.M; | |
| | | const Mat<eT>& B = tmp2.M; | |
| | | | |
| | | arma_debug_check( (A.n_elem != B.n_elem), "cdot(): objects must have the | |
| | | same number of elements" ); | |
| | | | |
| | | return op_cdot::direct_cdot( A.n_elem, A.mem, B.mem ); | |
| | | } | |
| | | | |
| | | template<typename T1, typename T2> | |
| | | arma_hot | |
| | | inline | |
| | | typename T1::elem_type | |
| | | op_cdot::apply_proxy(const T1& X, const T2& Y) | |
| | | { | |
| | | arma_extra_debug_sigprint(); | |
| | | | |
| | | typedef typename T1::elem_type eT; | |
| | | typedef typename get_pod_type<eT>::result T; | |
| | | | |
| typedef typename Proxy<T1>::ea_type ea_type1; | | typedef typename Proxy<T1>::ea_type ea_type1; | |
| typedef typename Proxy<T2>::ea_type ea_type2; | | typedef typename Proxy<T2>::ea_type ea_type2; | |
| | | | |
|
| const Proxy<T1> A(X.get_ref()); | | | |
| const Proxy<T2> B(Y.get_ref()); | | | |
| | | | |
| const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) || (Proxy
<T2>::prefer_at_accessor); | | const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) || (Proxy
<T2>::prefer_at_accessor); | |
| | | | |
| if(prefer_at_accessor == false) | | if(prefer_at_accessor == false) | |
| { | | { | |
|
| arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "cdot(): objects | | const Proxy<T1> PA(X); | |
| must have the same number of elements" ); | | const Proxy<T2> PB(Y); | |
| | | | |
|
| const uword N = A.get_n_elem(); | | const uword N = PA.get_n_elem(); | |
| ea_type1 PA = A.get_ea(); | | | |
| ea_type2 PB = B.get_ea(); | | | |
| | | | |
|
| eT val1 = eT(0); | | arma_debug_check( (N != PB.get_n_elem()), "cdot(): objects must have th | |
| eT val2 = eT(0); | | e same number of elements" ); | |
| | | | |
|
| uword i,j; | | ea_type1 A = PA.get_ea(); | |
| for(i=0, j=1; j<N; i+=2, j+=2) | | ea_type2 B = PB.get_ea(); | |
| { | | | |
| val1 += std::conj(PA[i]) * PB[i]; | | | |
| val2 += std::conj(PA[j]) * PB[j]; | | | |
| } | | | |
| | | | |
|
| if(i < N) | | T val_real = T(0); | |
| | | T val_imag = T(0); | |
| | | | |
| | | for(uword i=0; i<N; ++i) | |
| { | | { | |
|
| val1 += std::conj(PA[i]) * PB[i]; | | const std::complex<T> AA = A[i]; | |
| | | const std::complex<T> BB = B[i]; | |
| | | | |
| | | const T a = AA.real(); | |
| | | const T b = AA.imag(); | |
| | | | |
| | | const T c = BB.real(); | |
| | | const T d = BB.imag(); | |
| | | | |
| | | val_real += (a*c) + (b*d); | |
| | | val_imag += (a*d) - (b*c); | |
| } | | } | |
| | | | |
|
| return val1 + val2; | | return std::complex<T>(val_real, val_imag); | |
| } | | } | |
| else | | else | |
| { | | { | |
|
| const unwrap< typename Proxy<T1>::stored_type > tmp_A(A.Q); | | return op_cdot::apply_unwrap( X, Y ); | |
| const unwrap< typename Proxy<T2>::stored_type > tmp_B(B.Q); | | | |
| | | | |
| return op_cdot::apply( tmp_A.M, tmp_B.M ); | | | |
| } | | } | |
| } | | } | |
| | | | |
| //! @} | | //! @} | |
| | | | |
End of changes. 41 change blocks. |
| 73 lines changed or deleted | | 305 lines changed or added | |
|