summaryrefslogtreecommitdiff
path: root/boost/compute/lambda/context.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'boost/compute/lambda/context.hpp')
-rw-r--r--boost/compute/lambda/context.hpp329
1 files changed, 329 insertions, 0 deletions
diff --git a/boost/compute/lambda/context.hpp b/boost/compute/lambda/context.hpp
new file mode 100644
index 0000000000..ed25b79475
--- /dev/null
+++ b/boost/compute/lambda/context.hpp
@@ -0,0 +1,329 @@
+//---------------------------------------------------------------------------//
+// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
+//
+// Distributed under the Boost Software License, Version 1.0
+// See accompanying file LICENSE_1_0.txt or copy at
+// http://www.boost.org/LICENSE_1_0.txt
+//
+// See http://boostorg.github.com/compute for more information.
+//---------------------------------------------------------------------------//
+
+#ifndef BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
+#define BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
+
+#include <boost/proto/core.hpp>
+#include <boost/proto/context.hpp>
+#include <boost/type_traits.hpp>
+#include <boost/preprocessor/repetition.hpp>
+
+#include <boost/compute/config.hpp>
+#include <boost/compute/function.hpp>
+#include <boost/compute/lambda/result_of.hpp>
+#include <boost/compute/lambda/functional.hpp>
+#include <boost/compute/type_traits/result_of.hpp>
+#include <boost/compute/type_traits/type_name.hpp>
+#include <boost/compute/detail/meta_kernel.hpp>
+
+namespace boost {
+namespace compute {
+namespace lambda {
+
+namespace mpl = boost::mpl;
+namespace proto = boost::proto;
+
+#define BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(tag, op) \
+ template<class LHS, class RHS> \
+ void operator()(tag, const LHS &lhs, const RHS &rhs) \
+ { \
+ if(proto::arity_of<LHS>::value > 0){ \
+ stream << '('; \
+ proto::eval(lhs, *this); \
+ stream << ')'; \
+ } \
+ else { \
+ proto::eval(lhs, *this); \
+ } \
+ \
+ stream << op; \
+ \
+ if(proto::arity_of<RHS>::value > 0){ \
+ stream << '('; \
+ proto::eval(rhs, *this); \
+ stream << ')'; \
+ } \
+ else { \
+ proto::eval(rhs, *this); \
+ } \
+ }
+
+// lambda expression context
+template<class Args>
+struct context : proto::callable_context<context<Args> >
+{
+ typedef void result_type;
+ typedef Args args_tuple;
+
+ // create a lambda context for kernel with args
+ context(boost::compute::detail::meta_kernel &kernel, const Args &args_)
+ : stream(kernel),
+ args(args_)
+ {
+ }
+
+ // handle terminals
+ template<class T>
+ void operator()(proto::tag::terminal, const T &x)
+ {
+ // terminal values in lambda expressions are always literals
+ stream << stream.lit(x);
+ }
+
+ // handle placeholders
+ template<int I>
+ void operator()(proto::tag::terminal, placeholder<I>)
+ {
+ stream << boost::get<I>(args);
+ }
+
+ // handle functions
+ #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG(z, n, unused) \
+ BOOST_PP_COMMA_IF(n) BOOST_PP_CAT(const Arg, n) BOOST_PP_CAT(&arg, n)
+
+ #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION(z, n, unused) \
+ template<class F, BOOST_PP_ENUM_PARAMS(n, class Arg)> \
+ void operator()( \
+ proto::tag::function, \
+ const F &function, \
+ BOOST_PP_REPEAT(n, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG, ~) \
+ ) \
+ { \
+ proto::value(function).apply(*this, BOOST_PP_ENUM_PARAMS(n, arg)); \
+ }
+
+ BOOST_PP_REPEAT_FROM_TO(1, BOOST_COMPUTE_MAX_ARITY, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION, ~)
+
+ #undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION
+
+ // operators
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::plus, '+')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::minus, '-')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::multiplies, '*')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::divides, '/')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::modulus, '%')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less, '<')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater, '>')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less_equal, "<=")
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater_equal, ">=")
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::equal_to, "==")
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::not_equal_to, "!=")
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_and, "&&")
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_or, "||")
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_and, '&')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_or, '|')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_xor, '^')
+ BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::assign, '=')
+
+ // subscript operator
+ template<class LHS, class RHS>
+ void operator()(proto::tag::subscript, const LHS &lhs, const RHS &rhs)
+ {
+ proto::eval(lhs, *this);
+ stream << '[';
+ proto::eval(rhs, *this);
+ stream << ']';
+ }
+
+ // ternary conditional operator
+ template<class Pred, class Arg1, class Arg2>
+ void operator()(proto::tag::if_else_, const Pred &p, const Arg1 &x, const Arg2 &y)
+ {
+ proto::eval(p, *this);
+ stream << '?';
+ proto::eval(x, *this);
+ stream << ':';
+ proto::eval(y, *this);
+ }
+
+ boost::compute::detail::meta_kernel &stream;
+ Args args;
+};
+
+namespace detail {
+
+template<class Expr, class Arg>
+struct invoked_unary_expression
+{
+ typedef typename ::boost::compute::result_of<Expr(Arg)>::type result_type;
+
+ invoked_unary_expression(const Expr &expr, const Arg &arg)
+ : m_expr(expr),
+ m_arg(arg)
+ {
+ }
+
+ Expr m_expr;
+ Arg m_arg;
+};
+
+template<class Expr, class Arg>
+boost::compute::detail::meta_kernel&
+operator<<(boost::compute::detail::meta_kernel &kernel,
+ const invoked_unary_expression<Expr, Arg> &expr)
+{
+ context<boost::tuple<Arg> > ctx(kernel, boost::make_tuple(expr.m_arg));
+ proto::eval(expr.m_expr, ctx);
+
+ return kernel;
+}
+
+template<class Expr, class Arg1, class Arg2>
+struct invoked_binary_expression
+{
+ typedef typename ::boost::compute::result_of<Expr(Arg1, Arg2)>::type result_type;
+
+ invoked_binary_expression(const Expr &expr,
+ const Arg1 &arg1,
+ const Arg2 &arg2)
+ : m_expr(expr),
+ m_arg1(arg1),
+ m_arg2(arg2)
+ {
+ }
+
+ Expr m_expr;
+ Arg1 m_arg1;
+ Arg2 m_arg2;
+};
+
+template<class Expr, class Arg1, class Arg2>
+boost::compute::detail::meta_kernel&
+operator<<(boost::compute::detail::meta_kernel &kernel,
+ const invoked_binary_expression<Expr, Arg1, Arg2> &expr)
+{
+ context<boost::tuple<Arg1, Arg2> > ctx(
+ kernel,
+ boost::make_tuple(expr.m_arg1, expr.m_arg2)
+ );
+ proto::eval(expr.m_expr, ctx);
+
+ return kernel;
+}
+
+} // end detail namespace
+
+// forward declare domain
+struct domain;
+
+// lambda expression wrapper
+template<class Expr>
+struct expression : proto::extends<Expr, expression<Expr>, domain>
+{
+ typedef proto::extends<Expr, expression<Expr>, domain> base_type;
+
+ BOOST_PROTO_EXTENDS_USING_ASSIGN(expression)
+
+ expression(const Expr &expr = Expr())
+ : base_type(expr)
+ {
+ }
+
+ // result_of protocol
+ template<class Signature>
+ struct result
+ {
+ };
+
+ template<class This>
+ struct result<This()>
+ {
+ typedef
+ typename ::boost::compute::lambda::result_of<Expr>::type type;
+ };
+
+ template<class This, class Arg>
+ struct result<This(Arg)>
+ {
+ typedef
+ typename ::boost::compute::lambda::result_of<
+ Expr,
+ typename boost::tuple<Arg>
+ >::type type;
+ };
+
+ template<class This, class Arg1, class Arg2>
+ struct result<This(Arg1, Arg2)>
+ {
+ typedef typename
+ ::boost::compute::lambda::result_of<
+ Expr,
+ typename boost::tuple<Arg1, Arg2>
+ >::type type;
+ };
+
+ template<class Arg>
+ detail::invoked_unary_expression<expression<Expr>, Arg>
+ operator()(const Arg &x) const
+ {
+ return detail::invoked_unary_expression<expression<Expr>, Arg>(*this, x);
+ }
+
+ template<class Arg1, class Arg2>
+ detail::invoked_binary_expression<expression<Expr>, Arg1, Arg2>
+ operator()(const Arg1 &x, const Arg2 &y) const
+ {
+ return detail::invoked_binary_expression<
+ expression<Expr>,
+ Arg1,
+ Arg2
+ >(*this, x, y);
+ }
+
+ // function<> conversion operator
+ template<class R, class A1>
+ operator function<R(A1)>() const
+ {
+ using ::boost::compute::detail::meta_kernel;
+
+ std::stringstream source;
+
+ ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
+
+ source << "inline " << type_name<R>() << " lambda"
+ << ::boost::compute::detail::generate_argument_list<R(A1)>('x')
+ << "{\n"
+ << " return " << meta_kernel::expr_to_string((*this)(arg1)) << ";\n"
+ << "}\n";
+
+ return make_function_from_source<R(A1)>("lambda", source.str());
+ }
+
+ template<class R, class A1, class A2>
+ operator function<R(A1, A2)>() const
+ {
+ using ::boost::compute::detail::meta_kernel;
+
+ std::stringstream source;
+
+ ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
+ ::boost::compute::detail::meta_kernel_variable<A1> arg2("y");
+
+ source << "inline " << type_name<R>() << " lambda"
+ << ::boost::compute::detail::generate_argument_list<R(A1, A2)>('x')
+ << "{\n"
+ << " return " << meta_kernel::expr_to_string((*this)(arg1, arg2)) << ";\n"
+ << "}\n";
+
+ return make_function_from_source<R(A1, A2)>("lambda", source.str());
+ }
+};
+
+// lambda expression domain
+struct domain : proto::domain<proto::generator<expression> >
+{
+};
+
+} // end lambda namespace
+} // end compute namespace
+} // end boost namespace
+
+#endif // BOOST_COMPUTE_LAMBDA_CONTEXT_HPP