[
Date Prev][
Date Next][
Thread Prev][
Thread Next][
Date Index][
Thread Index]
[patch] SAL dispatch for matrix and vector products
- To: VSIPL++ Developers List <vsipl++@xxxxxxxxxxxxxxxx>
- Subject: [patch] SAL dispatch for matrix and vector products
- From: Don McCoy <don@xxxxxxxxxxxxxxxx>
- Date: Wed, 26 Oct 2005 13:48:17 -0600
I am working on BLAS dispatch as well. Patch to follow. This one just
includes SAL. Hopefully I've separated them well.
Two issues worth pointing out:
1) The exec() function checks for row-major ordering because it
calls the newer SAL functions (mat_mul) that allow the column-stride to
be specified. I believe the rows must be unit stride. This is a
little less general than the older ones (mulx) which allow non-unit
strides. Recently, we heard back from Mercury that the older ones
perform better (at this time). Given that the older ones handle
non-unit strides and are faster, should we revert to using those? If
Mercury changes in the future, then we can follow.
2) Split-complex products (other than vector-vector) are not handled
at this time. Just a reminder that we were going to discuss how to
address this issue sometime.
Regards,
--
Don McCoy
CodeSourcery, LLC
2005-10-26 Don McCoy <don@xxxxxxxxxxxxxxxx>
* src/vsip/impl/eval-sal.hpp: New file. Overloaded translation
functions for matrix/vector products in SAL, and dispatch routines
for the same.
* src/vsip/impl/general_dispatch.hpp: New OpTags for matrix-vector
and vector-matrix products. Added ImplTag for Mercury SAL.
* src/vsip/impl/matvec-prod.hpp: Added generic evaluators for m-v
and v-m products. Changed product functions to go through dispatch.
* src/vsip/impl/matvec.hpp: Include eval-sal.hpp.
Index: src/vsip/impl/eval-sal.hpp
===================================================================
RCS file: src/vsip/impl/eval-sal.hpp
diff -N src/vsip/impl/eval-sal.hpp
*** /dev/null 1 Jan 1970 00:00:00 -0000
--- src/vsip/impl/eval-sal.hpp 26 Oct 2005 19:30:18 -0000
***************
*** 0 ****
--- 1,428 ----
+ /* Copyright (c) 2005 by CodeSourcery, LLC. All rights reserved. */
+
+ /** @file vsip/impl/eval-sal.hpp
+ @author Don McCoy
+ @date 2005-10-17
+ @brief VSIPL++ Library: SAL evaluators (for use in general dispatch).
+
+ */
+
+ #ifndef VSIP_IMPL_EVAL_SAL_HPP
+ #define VSIP_IMPL_EVAL_SAL_HPP
+
+ #ifdef VSIP_IMPL_HAVE_SAL
+
+ /***********************************************************************
+ Included Files
+ ***********************************************************************/
+
+ #include <vsip/impl/general_dispatch.hpp>
+ #include <vsip/impl/sal.hpp>
+ #include <sal.h>
+
+
+
+ /***********************************************************************
+ Declarations
+ ***********************************************************************/
+
+ namespace vsip
+ {
+
+ namespace impl
+ {
+
+ namespace sal
+ {
+ // Generates an inline VPPFCN(len, A, A_stride, B, B_stride) returning T: casts A, B and the
+ // result to SAL_T*, scales both strides by STRIDE, and forwards to SALFCN.
+ #define VSIP_IMPL_SAL_DOT( T, SAL_T, VPPFCN, SALFCN, STRIDE ) \
+ inline T \
+ VPPFCN( length_type len, \
+ T * A, stride_type A_stride, \
+ T * B, stride_type B_stride ) \
+ { \
+ T c; \
+ SALFCN( (SAL_T *) A, STRIDE * A_stride, \
+ (SAL_T *) B, STRIDE * B_stride, \
+ (SAL_T *) &c, len, 0 ); \
+ return c; \
+ }
+ // Split-complex dot wrapper: A and B are (real*, imag*) pairs whose addresses are cast to SAL_T*.
+ #define VSIP_IMPL_SAL_DOT_SPLIT( T, SAL_T, VPPFCN, SALFCN, STRIDE ) \
+ inline std::complex<T> \
+ VPPFCN( length_type len, \
+ std::pair<T *, T *> const A, stride_type A_stride, \
+ std::pair<T *, T *> const B, stride_type B_stride )\
+ { \
+ T real_val, imag_val; \
+ SAL_T c = { &real_val, &imag_val }; \
+ SALFCN( (SAL_T *) &A, STRIDE * A_stride, \
+ (SAL_T *) &B, STRIDE * B_stride, \
+ (SAL_T *) &c, len, 0 ); \
+ \
+ std::complex<T> r(*c.realp, *c.imagp); \
+ return r; \
+ }
+
+ VSIP_IMPL_SAL_DOT( float, float, dot, dotprx, 1 );
+ VSIP_IMPL_SAL_DOT( double, double, dot, dotprdx, 1 );
+ VSIP_IMPL_SAL_DOT( std::complex<float>, COMPLEX, dot, cdotprx, 2 );
+ VSIP_IMPL_SAL_DOT( std::complex<double>, DOUBLE_COMPLEX, dot, cdotprdx, 2 );
+ VSIP_IMPL_SAL_DOT_SPLIT( float, COMPLEX_SPLIT, dot, zdotprx, 1 );
+ VSIP_IMPL_SAL_DOT_SPLIT( double, DOUBLE_COMPLEX_SPLIT, dot, zdotprdx, 1 );
+
+ VSIP_IMPL_SAL_DOT( std::complex<float>, COMPLEX, dotc, cidotprx, 2 );
+ VSIP_IMPL_SAL_DOT( std::complex<double>, DOUBLE_COMPLEX, dotc, cidotprdx, 2 );
+ VSIP_IMPL_SAL_DOT_SPLIT( float, COMPLEX_SPLIT, dotc, zidotprx, 1 );
+ VSIP_IMPL_SAL_DOT_SPLIT( double, DOUBLE_COMPLEX_SPLIT, dotc, zidotprdx, 1 );
+
+ #undef VSIP_IMPL_SAL_DOT
+ #undef VSIP_IMPL_SAL_DOT_SPLIT
+
+
+ // Generates an inline VPPFCN(m, n, p, a, as, b, bs, c, cs) that casts a, b and c to SAL_T* and
+ // forwards everything to SALFCN. NOTE(review): the STRIDE parameter is never used in the body.
+ #define VSIP_IMPL_SAL_PROD( T, SAL_T, VPPFCN, SALFCN, STRIDE ) \
+ inline void \
+ VPPFCN( int m, int n, int p, \
+ T *a, int as, \
+ T *b, int bs, \
+ T *c, int cs ) \
+ { \
+ SALFCN( (SAL_T *) a, as, \
+ (SAL_T *) b, bs, \
+ (SAL_T *) c, cs, \
+ m, n, p, 0, 0 ); \
+ }
+
+ VSIP_IMPL_SAL_PROD( float, float, mmul, mat_mulx, 1 );
+ VSIP_IMPL_SAL_PROD( double, double, mmul, mat_muldx, 1 );
+ VSIP_IMPL_SAL_PROD( complex<float>, COMPLEX, mmul, cmat_mulx, 1 );
+ VSIP_IMPL_SAL_PROD( complex<double>, DOUBLE_COMPLEX, mmul, cmat_muldx, 1 );
+
+ #undef VSIP_IMPL_SAL_PROD
+
+
+ // Sal_traits<T>::valid is true only for float, double and their std::complex forms; the
+ // specializations also carry trans = 't' (real) or 'c' (complex), which this file never reads.
+ template <typename T>
+ struct Sal_traits
+ {
+ static bool const valid = false;
+ };
+
+ template <>
+ struct Sal_traits<float>
+ {
+ static bool const valid = true;
+ static char const trans = 't';
+ };
+
+ template <>
+ struct Sal_traits<double>
+ {
+ static bool const valid = true;
+ static char const trans = 't';
+ };
+
+ template <>
+ struct Sal_traits<std::complex<float> >
+ {
+ static bool const valid = true;
+ static char const trans = 'c';
+ };
+
+ template <>
+ struct Sal_traits<std::complex<double> >
+ {
+ static bool const valid = true;
+ static char const trans = 'c';
+ };
+
+
+ } // namespace vsip::impl::sal
+
+
+ // SAL evaluator for vector-vector dot-product (non-conjugated). Compile-time valid when T is a
+ // SAL-supported type, both blocks hold T, and both report zero Ext_data_cost; no run-time check.
+ template <typename T,
+ typename Block1,
+ typename Block2>
+ struct Evaluator<Op_prod_vv, Return_scalar<T>, Op_list_2<Block1, Block2>,
+ Mercury_sal_tag>
+ {
+ static bool const ct_valid =
+ impl::sal::Sal_traits<T>::valid &&
+ Type_equal<T, typename Block1::value_type>::value &&
+ Type_equal<T, typename Block2::value_type>::value &&
+ // check that direct access is supported
+ Ext_data_cost<Block1>::value == 0 &&
+ Ext_data_cost<Block2>::value == 0;
+
+ static bool rt_valid(Block1 const&, Block2 const&) { return true; }
+
+ static T exec(Block1 const& a, Block2 const& b)
+ {
+ assert(a.size(1, 0) == b.size(1, 0));
+
+ Ext_data<Block1> ext_a(const_cast<Block1&>(a));
+ Ext_data<Block2> ext_b(const_cast<Block2&>(b));
+
+ T r = sal::dot( a.size(1, 0),
+ ext_a.data(), ext_a.stride(0),
+ ext_b.data(), ext_b.stride(0) );
+
+ return r;
+ }
+ };
+
+
+
+ // SAL evaluator for vector-vector dot-product (conjugated): matches a second operand that is a
+ // conj_functor Unary_expr_block, unwraps it with b.op(), and calls sal::dotc with operands swapped.
+ template <typename T,
+ typename Block1,
+ typename Block2>
+ struct Evaluator<Op_prod_vv, Return_scalar<complex<T> >,
+ Op_list_2<Block1,
+ Unary_expr_block<1, impl::conj_functor,
+ Block2, complex<T> > const>,
+ Mercury_sal_tag>
+ {
+ static bool const ct_valid =
+ impl::sal::Sal_traits<complex<T> >::valid &&
+ Type_equal<complex<T>, typename Block1::value_type>::value &&
+ Type_equal<complex<T>, typename Block2::value_type>::value &&
+ // check that direct access is supported
+ Ext_data_cost<Block1>::value == 0 &&
+ Ext_data_cost<Block2>::value == 0;
+
+ static bool rt_valid(
+ Block1 const&,
+ Unary_expr_block<1, impl::conj_functor, Block2, complex<T> > const&)
+ { return true; }
+
+ static complex<T> exec(
+ Block1 const& a,
+ Unary_expr_block<1, impl::conj_functor, Block2, complex<T> > const& b)
+ {
+ assert(a.size(1, 0) == b.size(1, 0));
+
+ Ext_data<Block1> ext_a(const_cast<Block1&>(a));
+ Ext_data<Block2> ext_b(const_cast<Block2&>(b.op()));
+
+ return sal::dotc( a.size(1, 0),
+ ext_b.data(), ext_b.stride(0),
+ ext_a.data(), ext_a.stride(0) );
+ // Note:
+ // SAL cidotprx(x, y) => conj(x) * y, while
+ // VSIPL++ cvjdot(x, y) => x * conj(y)
+ }
+ };
+
+
+ // SAL evaluator for matrix-vector product r = a * b (a: matrix, b: vector, r: vector).
+ template <typename Block0,
+ typename Block1,
+ typename Block2>
+ struct Evaluator<Op_prod_mv, Block0, Op_list_2<Block1, Block2>,
+ Mercury_sal_tag>
+ {
+ typedef typename Block0::value_type T;
+
+ // Compile-time gate: a SAL-supported element type shared by the result and both operands.
+ static bool const ct_valid =
+ impl::sal::Sal_traits<T>::valid &&
+ Type_equal<T, typename Block1::value_type>::value &&
+ Type_equal<T, typename Block2::value_type>::value &&
+ // check that direct access is supported (result block included, as for Op_prod_mm)
+ Ext_data_cost<Block0>::value == 0 &&
+ Ext_data_cost<Block1>::value == 0 &&
+ Ext_data_cost<Block2>::value == 0;
+
+ // Run-time gate: a must be row2_type with stride(1) == 1, and b must have stride(0) == 1.
+ static bool rt_valid(Block0&, Block1 const& a, Block2 const& b)
+ {
+ typedef typename Block_layout<Block1>::order_type order1_type;
+
+ Ext_data<Block1> ext_a(const_cast<Block1&>(a));
+ Ext_data<Block2> ext_b(const_cast<Block2&>(b));
+
+ bool is_a_row = Type_equal<order1_type, row2_type>::value;
+
+ bool unit_stride = false;
+ if ( is_a_row )
+ unit_stride = (ext_a.stride(1) == 1) && (ext_b.stride(0) == 1);
+
+ return unit_stride;
+ }
+
+ // Hands b to sal::mmul as a P x 1 operand and r as the M x 1 result.
+ // NOTE(review): relies on the dispatcher having accepted rt_valid() first -- confirm.
+ static void exec(Block0& r, Block1 const& a, Block2 const& b)
+ {
+ assert(a.size(2, 1) == b.size(1, 0));
+
+ typedef typename Block_layout<Block1>::order_type order1_type;
+
+ Ext_data<Block0> ext_r(const_cast<Block0&>(r));
+ Ext_data<Block1> ext_a(const_cast<Block1&>(a));
+ Ext_data<Block2> ext_b(const_cast<Block2&>(b));
+
+ if (Type_equal<order1_type, row2_type>::value)
+ {
+ // SAL supports column-stride parameter only (rows must be unit-stride)
+ sal::mmul( a.size(2, 0), // M
+ 1, // N
+ a.size(2, 1), // P
+ ext_a.data(), ext_a.stride(0),
+ ext_b.data(), ext_b.stride(0),
+ ext_r.data(), ext_r.stride(0) );
+ }
+ else
+ assert(0);
+ }
+ };
+
+
+ // SAL evaluator for vector-matrix product r = a * b (a: vector, b: matrix, r: vector).
+ template <typename Block0,
+ typename Block1,
+ typename Block2>
+ struct Evaluator<Op_prod_vm, Block0, Op_list_2<Block1, Block2>,
+ Mercury_sal_tag>
+ {
+ typedef typename Block0::value_type T;
+
+ static bool const ct_valid =
+ impl::sal::Sal_traits<T>::valid &&
+ Type_equal<T, typename Block1::value_type>::value &&
+ Type_equal<T, typename Block2::value_type>::value &&
+ // check that direct access is supported (result block included, as for Op_prod_mm)
+ Ext_data_cost<Block0>::value == 0 &&
+ Ext_data_cost<Block1>::value == 0 &&
+ Ext_data_cost<Block2>::value == 0;
+
+ // Run-time gate: b row2_type with stride(1) == 1; a and r (SAL's 1 x N result row) unit-stride.
+ static bool rt_valid(Block0& r, Block1 const& a, Block2 const& b)
+ {
+ typedef typename Block_layout<Block2>::order_type order2_type;
+
+ Ext_data<Block0> ext_r(const_cast<Block0&>(r));
+ Ext_data<Block1> ext_a(const_cast<Block1&>(a));
+ Ext_data<Block2> ext_b(const_cast<Block2&>(b));
+
+ bool is_b_row = Type_equal<order2_type, row2_type>::value;
+
+ bool unit_stride = false;
+ if ( is_b_row )
+ unit_stride = (ext_a.stride(0) == 1) && (ext_b.stride(1) == 1) &&
+ (ext_r.stride(0) == 1);
+
+ return unit_stride;
+ }
+
+ static void exec(Block0& r, Block1 const& a, Block2 const& b)
+ {
+ assert(a.size(1, 0) == b.size(2, 0));
+
+ // The layout that matters is the matrix operand's (Block2), the same test rt_valid() makes.
+ typedef typename Block_layout<Block2>::order_type order2_type;
+
+ Ext_data<Block0> ext_r(const_cast<Block0&>(r));
+ Ext_data<Block1> ext_a(const_cast<Block1&>(a));
+ Ext_data<Block2> ext_b(const_cast<Block2&>(b));
+
+ if (Type_equal<order2_type, row2_type>::value)
+ {
+ // SAL supports column-stride parameter only (rows must be unit-stride)
+ sal::mmul( 1, // M
+ b.size(2, 1), // N
+ a.size(1, 0), // P
+ ext_a.data(), ext_a.stride(0),
+ ext_b.data(), ext_b.stride(0),
+ ext_r.data(), ext_r.stride(0) );
+ }
+ else
+ assert(0);
+ }
+ };
+
+
+
+ // SAL evaluator for matrix-matrix products r = a * b.
+ template <typename Block0,
+ typename Block1,
+ typename Block2>
+ struct Evaluator<Op_prod_mm, Block0, Op_list_2<Block1, Block2>,
+ Mercury_sal_tag>
+ {
+ typedef typename Block0::value_type T;
+
+ static bool const ct_valid =
+ impl::sal::Sal_traits<T>::valid &&
+ Type_equal<T, typename Block1::value_type>::value &&
+ Type_equal<T, typename Block2::value_type>::value &&
+ // check that direct access is supported
+ Ext_data_cost<Block0>::value == 0 &&
+ Ext_data_cost<Block1>::value == 0 &&
+ Ext_data_cost<Block2>::value == 0;
+
+ // Run-time gate: all three blocks row2_type, each with stride(1) == 1, since SAL is only
+ // given a column-stride per matrix and therefore needs unit-stride rows in a, b and r.
+ static bool rt_valid(Block0& r, Block1 const& a, Block2 const& b)
+ {
+ typedef typename Block_layout<Block0>::order_type order0_type;
+ typedef typename Block_layout<Block1>::order_type order1_type;
+ typedef typename Block_layout<Block2>::order_type order2_type;
+
+ Ext_data<Block0> ext_r(const_cast<Block0&>(r));
+ Ext_data<Block1> ext_a(const_cast<Block1&>(a));
+ Ext_data<Block2> ext_b(const_cast<Block2&>(b));
+
+ bool is_r_row = Type_equal<order0_type, row2_type>::value;
+ bool is_a_row = Type_equal<order1_type, row2_type>::value;
+ bool is_b_row = Type_equal<order2_type, row2_type>::value;
+
+ bool unit_stride = false;
+ if ( is_r_row && is_a_row && is_b_row )
+ unit_stride = (ext_a.stride(1) == 1) && (ext_b.stride(1) == 1) &&
+ (ext_r.stride(1) == 1);
+
+ return unit_stride;
+ }
+
+ static void exec(Block0& r, Block1 const& a, Block2 const& b)
+ {
+ assert(a.size(2, 1) == b.size(2, 0));
+
+ typedef typename Block_layout<Block0>::order_type order0_type;
+
+ Ext_data<Block0> ext_r(const_cast<Block0&>(r));
+ Ext_data<Block1> ext_a(const_cast<Block1&>(a));
+ Ext_data<Block2> ext_b(const_cast<Block2&>(b));
+
+ if (Type_equal<order0_type, row2_type>::value)
+ {
+ // SAL supports column-stride parameter only (rows must be unit-stride)
+ sal::mmul( a.size(2, 0), // M
+ b.size(2, 1), // N
+ a.size(2, 1), // P
+ ext_a.data(), ext_a.stride(0),
+ ext_b.data(), ext_b.stride(0),
+ ext_r.data(), ext_r.stride(0) );
+ }
+ else assert(0);
+ }
+ };
+
+
+
+ } // namespace vsip::impl
+
+ } // namespace vsip
+
+ #endif // VSIP_IMPL_HAVE_SAL
+
+ #endif // VSIP_IMPL_EVAL_SAL_HPP
Index: src/vsip/impl/general_dispatch.hpp
===================================================================
RCS file: /home/cvs/Repository/vpp/src/vsip/impl/general_dispatch.hpp,v
retrieving revision 1.1
diff -c -p -r1.1 general_dispatch.hpp
*** src/vsip/impl/general_dispatch.hpp 12 Oct 2005 12:45:05 -0000 1.1
--- src/vsip/impl/general_dispatch.hpp 26 Oct 2005 19:30:18 -0000
*************** namespace impl
*** 35,40 ****
--- 35,42 ----
struct Op_prod_vv; // vector-vector dot-product
struct Op_prod_mm; // matrix-matrix product
+ struct Op_prod_mv; // matrix-vector product
+ struct Op_prod_vm; // vector-matrix product
*************** struct Op_prod_mm; // matrix-matrix prod
*** 46,51 ****
--- 48,54 ----
struct Blas_tag; // BLAS implementation (ATLAS, MKL, etc)
struct Intel_ipp_tag; // Intel IPP library.
struct Generic_tag; // Generic implementation.
+ struct Mercury_sal_tag; // Mercury SAL library.
*************** struct Evaluator
*** 78,84 ****
template <typename OpTag>
struct Dispatch_order
{
! typedef typename Make_type_list<Blas_tag, Generic_tag>::type type;
};
--- 81,89 ----
template <typename OpTag>
struct Dispatch_order
{
! typedef typename Make_type_list<
! Blas_tag, Mercury_sal_tag, Generic_tag
! >::type type;
};
Index: src/vsip/impl/matvec-prod.hpp
===================================================================
RCS file: /home/cvs/Repository/vpp/src/vsip/impl/matvec-prod.hpp,v
retrieving revision 1.3
diff -c -p -r1.3 matvec-prod.hpp
*** src/vsip/impl/matvec-prod.hpp 12 Oct 2005 12:45:05 -0000 1.3
--- src/vsip/impl/matvec-prod.hpp 26 Oct 2005 19:30:18 -0000
***************
*** 18,23 ****
--- 18,24 ----
#include <vsip/matrix.hpp>
#include <vsip/impl/matvec.hpp>
#include <vsip/impl/eval-blas.hpp>
+ #include <vsip/impl/eval-sal.hpp>
*************** generic_prod(
*** 90,95 ****
--- 91,125 ----
+ // Generic evaluator for matrix-vector products: always valid; computes each r(i) as the
+ // sum over k of a(i, k) * b(k) using the blocks' get()/put().
+ template <typename Block0,
+ typename Block1,
+ typename Block2>
+ struct Evaluator<Op_prod_mv, Block0, Op_list_2<Block1, Block2>,
+ Generic_tag>
+ {
+ static bool const ct_valid = true;
+ static bool rt_valid(Block0&, Block1 const&, Block2 const&)
+ { return true; }
+
+ static void exec(Block0& r, Block1 const& a, Block2 const& b)
+ {
+ typedef typename Block0::value_type RT;
+
+ for (index_type i=0; i<r.size(1, 0); ++i)
+ {
+ RT sum = RT();
+ for (index_type k=0; k<a.size(2, 1); ++k)
+ {
+ sum += a.get(i, k) * b.get(k);
+ }
+ r.put(i, sum);
+ }
+ }
+ };
+
+
/// Matrix-vector product.
template <typename T0,
*************** generic_prod(
*** 107,112 ****
--- 137,167 ----
assert(r.size() == a.size(0));
assert(a.size(1) == b.size());
+ impl::General_dispatch<
+ impl::Op_prod_mv,
+ Block2,
+ impl::Op_list_2<Block0, Block1>,
+ typename impl::Dispatch_order<impl::Op_prod_mv>::type >
+ ::exec(r.block(), a.block(), b.block());
+ }
+
+
+ /*
+ template <typename T0,
+ typename T1,
+ typename T2,
+ typename Block0,
+ typename Block1,
+ typename Block2>
+ void
+ generic_prod(
+ const_Matrix<T0, Block0> a,
+ const_Vector<T1, Block1> b,
+ Vector<T2, Block2> r)
+ {
+ assert(r.size() == a.size(0));
+ assert(a.size(1) == b.size());
+
for (index_type i=0; i<r.size(0); ++i)
{
T2 sum = T2();
*************** generic_prod(
*** 117,123 ****
--- 172,207 ----
r.put(i, sum);
}
}
+ */
+
+ // Generic evaluator for vector-matrix products: always valid; computes each r(i) as the
+ // sum over k of a(k) * b(k, i) using the blocks' get()/put().
+ template <typename Block0,
+ typename Block1,
+ typename Block2>
+ struct Evaluator<Op_prod_vm, Block0, Op_list_2<Block1, Block2>,
+ Generic_tag>
+ {
+ static bool const ct_valid = true;
+ static bool rt_valid(Block0&, Block1 const&, Block2 const&)
+ { return true; }
+
+ static void exec(Block0& r, Block1 const& a, Block2 const& b)
+ {
+ typedef typename Block0::value_type RT;
+
+ for (index_type i=0; i<r.size(); ++i)
+ {
+ RT sum = RT();
+ for (index_type k=0; k<b.size(2, 0); ++k)
+ {
+ sum += a.get(k) * b.get(k, i);
+ }
+ r.put(i, sum);
+ }
+ }
+ };
/// Vector-matrix product.
*************** generic_prod(
*** 137,151 ****
assert(r.size() == b.size(1));
assert(a.size() == b.size(0));
! for (index_type i=0; i<r.size(0); ++i)
! {
! T2 sum = T2();
! for (index_type k=0; k<b.size(0); ++k)
! {
! sum += a.get(k) * b.get(k, i);
! }
! r.put(i, sum);
! }
}
} // namespace vsip::impl
--- 221,232 ----
assert(r.size() == b.size(1));
assert(a.size() == b.size(0));
! impl::General_dispatch<
! impl::Op_prod_vm,
! Block2,
! impl::Op_list_2<Block0, Block1>,
! typename impl::Dispatch_order<impl::Op_prod_vm>::type >
! ::exec(r.block(), a.block(), b.block());
}
} // namespace vsip::impl
Index: src/vsip/impl/matvec.hpp
===================================================================
RCS file: /home/cvs/Repository/vpp/src/vsip/impl/matvec.hpp,v
retrieving revision 1.5
diff -c -p -r1.5 matvec.hpp
*** src/vsip/impl/matvec.hpp 12 Oct 2005 12:45:05 -0000 1.5
--- src/vsip/impl/matvec.hpp 26 Oct 2005 19:30:19 -0000
***************
*** 21,26 ****
--- 21,27 ----
#include <vsip/impl/fns_elementwise.hpp>
#include <vsip/impl/general_dispatch.hpp>
#include <vsip/impl/eval-blas.hpp>
+ #include <vsip/impl/eval-sal.hpp>