summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2018-02-15 22:53:19 +0100
committerGitHub <noreply@github.com>2018-02-15 22:53:19 +0100
commitcb2fd39fdddcc42865e881fe566dae6548ac47b9 (patch)
treeb3cc54c1e07ea5114798c1eec8921bba4af58bd2 /torch
parenta27f0e4daa9f7dc1187d237fd38645c387407e9e (diff)
downloadpytorch-cb2fd39fdddcc42865e881fe566dae6548ac47b9.tar.gz
pytorch-cb2fd39fdddcc42865e881fe566dae6548ac47b9.tar.bz2
pytorch-cb2fd39fdddcc42865e881fe566dae6548ac47b9.zip
Add Python frontend to the JIT (#5190)
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/init.cpp2
-rw-r--r--torch/csrc/jit/script/lexer.cpp20
-rw-r--r--torch/csrc/jit/script/lexer.h1
-rw-r--r--torch/csrc/jit/script/python_tree_views.cpp155
-rw-r--r--torch/csrc/jit/script/python_tree_views.h8
-rw-r--r--torch/csrc/jit/script/tree.h1
-rw-r--r--torch/csrc/jit/script/tree_views.h28
-rw-r--r--torch/jit/frontend.py401
8 files changed, 615 insertions, 1 deletions
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 6c2fa17854..5f227f8162 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -15,6 +15,7 @@
#include "torch/csrc/jit/passes/onnx/peephole.h"
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/script/init.h"
+#include "torch/csrc/jit/script/python_tree_views.h"
namespace torch { namespace jit {
@@ -134,6 +135,7 @@ void initJITBindings(PyObject *module) {
initPythonTracerBindings(module);
python::initCompilerMixin(module);
script::initJitScriptBindings(module);
+ script::initTreeViewBindings(module);
}
}}
diff --git a/torch/csrc/jit/script/lexer.cpp b/torch/csrc/jit/script/lexer.cpp
index 2d167b4623..2c7d42bc38 100644
--- a/torch/csrc/jit/script/lexer.cpp
+++ b/torch/csrc/jit/script/lexer.cpp
@@ -1,10 +1,30 @@
#include "torch/csrc/jit/script/lexer.h"
#include <string>
+#include <unordered_map>
+#include <mutex>
namespace torch {
namespace jit {
namespace script {
+int stringToKind(std::string str) {
+ static std::once_flag init_flag;
+ static std::unordered_map<std::string, int> str_to_kind;
+ std::call_once(init_flag, []() {
+ for (char tok : std::string(valid_single_char_tokens))
+ str_to_kind[std::string(1, tok)] = tok;
+#define DEFINE_CASE(tok, _, str) \
+ if (std::string(str) != "") str_to_kind[str] = tok;
+ TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
+#undef DEFINE_CASE
+ });
+ try {
+ return str_to_kind.at(str);
+ } catch (std::out_of_range& err) {
+ throw std::out_of_range("unknown token in stringToKind");
+ }
+}
+
std::string kindToString(int kind) {
if (kind < 256)
return std::string(1, kind);
diff --git a/torch/csrc/jit/script/lexer.h b/torch/csrc/jit/script/lexer.h
index 6bf0c37baa..035105dac5 100644
--- a/torch/csrc/jit/script/lexer.h
+++ b/torch/csrc/jit/script/lexer.h
@@ -89,6 +89,7 @@ enum TokenKind {
};
std::string kindToString(int kind);
+int stringToKind(std::string str);
// nested hash tables that indicate char-by-char what is a valid token.
struct TokenTrie;
diff --git a/torch/csrc/jit/script/python_tree_views.cpp b/torch/csrc/jit/script/python_tree_views.cpp
new file mode 100644
index 0000000000..d25aec3fc9
--- /dev/null
+++ b/torch/csrc/jit/script/python_tree_views.cpp
@@ -0,0 +1,155 @@
+#include "torch/csrc/jit/script/python_tree_views.h"
+
+#include "torch/csrc/jit/script/tree_views.h"
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <sstream>
+
+namespace py = pybind11;
+
+namespace torch { namespace jit { namespace script {
+
+struct SourceRangeFactory {
+ SourceRangeFactory(std::string source)
+ : source_(std::make_shared<std::string>(std::move(source))) {
+ std::size_t pos = 0;
+ do {
+ line_len_prefix_sum_.push_back(pos);
+ pos++;
+ } while ((pos = source_->find('\n', pos)) != std::string::npos);
+ }
+ SourceRange create(int line, int start_col, int end_col) {
+ // Python has a weird convention where col_offset points to the column *before*
+ // the token starts.
+ start_col++;
+ end_col++;
+ // Also, lines are counted from 1.
+ line--;
+ auto line_start = line_len_prefix_sum_.at(line);
+ return SourceRange(source_, line_start + start_col, line_start + end_col);
+ }
+
+ std::shared_ptr<std::string> source_;
+ std::vector<std::size_t> line_len_prefix_sum_;
+};
+
+template<typename T>
+List<T> wrap_list(const SourceRange& fallback_pos, std::vector<T>&& vec) {
+ if (vec.empty())
+ return List<T>::create(fallback_pos, std::move(vec));
+ return List<T>::create(vec.front().range(), std::move(vec));
+}
+
+void initTreeViewBindings(PyObject *module) {
+ auto _C = py::handle(module).cast<py::module>();
+ auto m = _C.def_submodule("_jit_tree_views");
+
+ py::class_<SourceRange>(m, "SourceRange")
+ .def("highlight", [](const SourceRange& self) {
+ std::ostringstream stream;
+ self.highlight(stream);
+ return stream.str();
+ })
+ .def_property_readonly("start", &SourceRange::start)
+ .def_property_readonly("end", &SourceRange::end);
+ py::class_<SourceRangeFactory>(m, "SourceRangeFactory")
+ .def(py::init<std::string&&>())
+ .def("make_range", &SourceRangeFactory::create)
+ .def("make_raw_range", [](const SourceRangeFactory& self, size_t start, size_t end) {
+ return SourceRange(self.source_, start, end);
+ })
+ .def_property_readonly("source", [](const SourceRangeFactory& self) {
+ return *self.source_;
+ });
+
+ py::class_<TreeView>(m, "TreeView")
+ .def("range", &TreeView::range)
+ .def("__str__", [](const TreeView& tree) {
+ std::ostringstream stream;
+ stream << tree.get();
+ return stream.str();
+ });
+
+ py::class_<Ident, TreeView>(m, "Ident")
+ .def(py::init(&Ident::create));
+
+ py::class_<Param, TreeView>(m, "Param")
+ .def(py::init([](const Type& type, const Ident& name) {
+ return Param::create(name.range(), name, type);
+ }));
+ py::class_<Attribute, TreeView>(m, "Attribute")
+ .def(py::init([](const Ident& name, const Expr& value) {
+ return Attribute::create(name.range(), name, value);
+ }));
+
+
+ py::class_<Type, TreeView>(m, "Type");
+ py::class_<TensorType, Type>(m, "TensorType")
+ .def(py::init(&TensorType::create));
+
+ py::class_<Stmt, TreeView>(m, "Stmt");
+ py::class_<Expr, TreeView>(m, "Expr");
+ py::class_<Def, TreeView>(m, "Def")
+ .def(py::init([](const Ident& name,
+ std::vector<Param> params,
+ std::vector<Param> returns,
+ std::vector<Stmt> body) {
+ auto r = name.range();
+ return Def::create(r,
+ name,
+ wrap_list(r, std::move(params)),
+ wrap_list(r, std::move(returns)),
+ wrap_list(r, std::move(body)));
+ }));
+
+
+ py::class_<Assign, Stmt>(m, "Assign")
+ .def(py::init([](std::vector<Ident> lhs, std::string kind_str, const Expr& rhs) {
+ auto r = lhs.at(0).range();
+ auto kind = AssignKind(Compound::create(stringToKind(kind_str), r, {}));
+ return Assign::create(r, List<Ident>::create(r, std::move(lhs)), kind, rhs);
+ }));
+ py::class_<If, Stmt>(m, "If")
+ .def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> true_branch, std::vector<Stmt> false_branch) {
+ return If::create(range, cond,
+ wrap_list(range, std::move(true_branch)),
+ wrap_list(range, std::move(false_branch)));
+ }));
+ py::class_<While, Stmt>(m, "While")
+ .def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> body) {
+ return While::create(range, cond, wrap_list(range, std::move(body)));
+ }));
+ py::class_<ExprStmt, Stmt>(m, "ExprStmt")
+ .def(py::init([](const Expr& expr) {
+ return ExprStmt::create(expr.range(), expr);
+ }));
+
+ py::class_<Var, Expr>(m, "Var")
+ .def(py::init([](const Ident& name) {
+ return Var::create(name.range(), name);
+ }))
+ .def("name", [](const Var& var) { return var.name(); });
+ py::class_<BinOp, Expr>(m, "BinOp")
+ .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) {
+ return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
+ }));
+ // NB: we take range here, because unary ops precede their exprs, so we need to include them
+ py::class_<UnaryOp, Expr>(m, "UnaryOp")
+ .def(py::init([](const SourceRange& range, std::string kind, const Expr& expr) {
+ return UnaryOp::create(range, stringToKind(kind), expr);
+ }));
+ py::class_<Apply, Expr>(m, "Apply")
+ .def(py::init([](const Ident& name, std::vector<Expr> args, std::vector<Attribute> kwargs) {
+ auto r = name.range();
+ return Apply::create(name.range(), name,
+ wrap_list(r, std::move(args)), wrap_list(r, std::move(kwargs)));
+ }));
+ py::class_<TernaryIf, Expr>(m, "TernaryIf")
+ .def(py::init([](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
+ return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
+ }));
+}
+
+}}} // namespace torch::jit::script
diff --git a/torch/csrc/jit/script/python_tree_views.h b/torch/csrc/jit/script/python_tree_views.h
new file mode 100644
index 0000000000..06b06aba66
--- /dev/null
+++ b/torch/csrc/jit/script/python_tree_views.h
@@ -0,0 +1,8 @@
+#include <Python.h>
+
+namespace torch { namespace jit { namespace script {
+
+void initTreeViewBindings(PyObject *module);
+
+}}} // namespace torch::jit::script
+
diff --git a/torch/csrc/jit/script/tree.h b/torch/csrc/jit/script/tree.h
index 1cc5fcf74e..b7461d502f 100644
--- a/torch/csrc/jit/script/tree.h
+++ b/torch/csrc/jit/script/tree.h
@@ -2,6 +2,7 @@
#include <memory>
#include <vector>
+#include <functional>
#include "torch/csrc/jit/script/lexer.h"
diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h
index a28452663d..8745bc6dd9 100644
--- a/torch/csrc/jit/script/tree_views.h
+++ b/torch/csrc/jit/script/tree_views.h
@@ -2,6 +2,8 @@
#include "error_report.h"
#include "tree.h"
+#include <functional>
+
namespace torch {
namespace jit {
namespace script {
@@ -493,6 +495,9 @@ struct UnaryOp : public Expr {
throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid UnaryOp";
}
}
+ static UnaryOp create(const SourceRange& range, int kind, const Expr& expr) {
+ return UnaryOp(Compound::create(kind, range, {expr}));
+ }
};
struct Cast : public Expr {
@@ -601,7 +606,7 @@ struct Var : public Expr {
explicit Var(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_VAR);
};
- Ident name() {
+ Ident name() const {
return Ident(subtree(0));
}
static Var create(const SourceRange& range, const Ident& name) {
@@ -609,6 +614,27 @@ struct Var : public Expr {
}
};
+struct TernaryIf : public Expr {
+ explicit TernaryIf(const TreeRef& tree) : Expr(tree) {
+ tree_->matchNumSubtrees(TK_IF_EXPR, 3);
+ };
+ Expr cond() const {
+ return Expr(subtree(0));
+ }
+ Expr true_expr() const {
+ return Expr(subtree(1));
+ }
+ Expr false_expr() const {
+ return Expr(subtree(2));
+ }
+ static TernaryIf create(const SourceRange& range,
+ const Expr& cond,
+ const Expr& true_expr,
+ const Expr& false_expr) {
+ return TernaryIf(Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
+ };
+};
+
} // namespace script
} // namespace jit
} // namespace torch
diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py
new file mode 100644
index 0000000000..af37b390dc
--- /dev/null
+++ b/torch/jit/frontend.py
@@ -0,0 +1,401 @@
+import torch
+import sys
+import ast
+import inspect
+import string
+from textwrap import dedent
+from functools import partial
+from collections import namedtuple
+from torch._C._jit_tree_views import *
+
+PY2 = sys.version_info[0] == 2
+_reserved_prefix = '__jit'
+_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)
+
+# TODO: populate those
+pretty_node_names = {
+ ast.For: "for loops",
+ ast.Delete: "del statements",
+ ast.ClassDef: "class definitions",
+ ast.With: "with statements",
+ ast.Raise: "raise statements",
+ ast.Assert: "assertions",
+ ast.Import: "import statements",
+ ast.ImportFrom: "import statements",
+ ast.Global: "global variables",
+ ast.Break: "break statements",
+ ast.Continue: "continue statements",
+}
+
+node_start_tokens = {
+ ast.For: "for",
+ ast.Delete: "del",
+ ast.ClassDef: "class",
+ ast.With: "with",
+ ast.Raise: "raise",
+ ast.Assert: "assert",
+ ast.Import: "import",
+ ast.ImportFrom: "from",
+ ast.Global: "global",
+ ast.Break: "break",
+ ast.Continue: "continue",
+}
+
+if PY2:
+ pretty_node_names.update({
+ ast.Print: "print statements",
+ ast.TryExcept: "try blocks",
+ ast.TryFinally: "try blocks",
+ ast.Exec: "exec statements",
+ })
+
+ node_start_tokens.update({
+ ast.Print: "print",
+ ast.TryExcept: "try",
+ ast.TryFinally: "try",
+ ast.Exec: "exec",
+ })
+else:
+ pretty_node_names.update({
+ ast.AsyncFor: "async for loops",
+ ast.AsyncWith: "async with statements",
+ ast.Try: "try blocks",
+ ast.Nonlocal: "nonlocal variables",
+ })
+
+ node_start_tokens.update({
+ ast.AsyncFor: "async for",
+ ast.AsyncWith: "async with",
+ ast.Try: "try",
+ ast.Nonlocal: "nonlocal",
+ })
+
+if sys.version_info >= (3, 6):
+ pretty_node_names.update({
+ ast.AnnAssign: "annotated assignments",
+ })
+ # NB: no specific token for AnnAssign
+
+
+class FrontendError(Exception):
+ def __init__(self, source_range, msg):
+ self.source_range = source_range
+ self.msg = msg
+
+ def __str__(self):
+ result = self.msg
+ if self.source_range is not None:
+ result += '\n' + self.source_range.highlight()
+ return result
+
+
+class NotSupportedError(FrontendError):
+ pass
+
+
+class UnsupportedNodeError(NotSupportedError):
+ def __init__(self, ctx, offending_node):
+ # If we don't have a specific token, we default to length of 1
+ range_len = len(node_start_tokens.get(type(offending_node), ' '))
+ source_range = ctx.make_range(offending_node.lineno,
+ offending_node.col_offset,
+ offending_node.col_offset + range_len)
+ feature_name = pretty_node_names.get(node_type, node_type.__name__)
+ msg = "{} aren't supported".format(feature_name)
+ super(NotSupportedError, self).__init__(source_range, msg)
+
+
+class FrontendTypeError(FrontendError):
+ pass
+
+
+def get_jit_ast(fn):
+ source = dedent(inspect.getsource(fn))
+ py_ast = ast.parse(source)
+ if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
+ raise RuntimeError("expected a single top-level function")
+ return build_def(SourceRangeFactory(source), py_ast.body[0])
+
+
+class Builder(object):
+ def __call__(self, ctx, node):
+ method = getattr(self, 'build_' + node.__class__.__name__, None)
+ if method is None:
+ raise UnsupportedNodeError(ctx, node)
+ return method(ctx, node)
+
+
+class CountReturns(ast.NodeVisitor):
+ def __init__(self):
+ self.num_returns = 0
+
+ def visit_Return(self, ret):
+ self.num_returns += 1
+
+ @staticmethod
+ def get_count(py_def):
+ counter = CountReturns()
+ counter.visit(py_def)
+ return counter.num_returns
+
+
+_ret_err_msg = ("JIT-ed functions can only have a single return, "
+ "and it has to be the last statement in the body")
+
+
+def build_def(ctx, py_def):
+ returns = []
+ ret_body = []
+ body = py_def.body
+ num_returns = CountReturns.get_count(py_def)
+ # TODO: change TorchScript AST to have a Return statement
+ if num_returns == 1:
+ ret_stmt, body = body[-1], body[:-1]
+ if not isinstance(ret_stmt, ast.Return):
+ raise ValueError(_ret_err_msg)
+ ret_expr = ret_stmt.value
+ ret_vals = ret_expr.elts if isinstance(ret_expr, ast.Tuple) else [ret_expr]
+ for i, val in enumerate(ret_vals):
+ val_expr = build_expr(ctx, val)
+ val_name = _reserved_prefix + '_' + str(i)
+ r = val_expr.range()
+ returns.append(Param(TensorType(r), Ident(r, val_name)))
+ ret_body.append(Assign([Ident(r, val_name)], '=', val_expr))
+ elif num_returns > 1:
+ raise ValueError(_ret_err_msg)
+ r = ctx.make_range(py_def.lineno, py_def.col_offset,
+ py_def.col_offset + len("def"))
+ return Def(Ident(r, py_def.name),
+ build_param_list(ctx, py_def.args),
+ returns,
+ [build_stmt(ctx, stmt) for stmt in body] + ret_body)
+
+
+_vararg_kwarg_err = ("Compiled functions can't take variable number of arguments, "
+ "have default values for arguments, nor keyword-only arguments")
+
+
+def build_param_list(ctx, py_args):
+ if py_args.vararg is not None or py_args.kwarg is not None or py_args.defaults:
+ raise ValueError(_vararg_kwarg_err)
+ if not PY2 and (py_args.kw_defaults or py_args.kwonlyargs):
+ raise ValueError(_vararg_kwarg_err)
+ return [build_param(ctx, arg) for arg in py_args.args]
+
+
+def build_param(ctx, py_arg):
+ # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
+ # In Python2 py_arg is a Name (Expr subclass)
+ if getattr(py_arg, 'annotation', None) is not None:
+ raise ValueError("Compiled functions don't support annotations")
+ name = py_arg.id if PY2 else py_arg.arg
+ r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name))
+ return Param(TensorType(r), Ident(r, name))
+
+
+class StmtBuilder(Builder):
+ augassign_map = {
+ ast.Add: '+',
+ ast.Sub: '-',
+ ast.Mult: '*',
+ ast.Div: '/',
+ }
+
+ @staticmethod
+ def build_Expr(ctx, stmt):
+ return ExprStmt(build_expr(ctx, stmt.value))
+
+ @staticmethod
+ def get_assign_ident(ctx, expr):
+ var = build_expr(ctx, expr)
+ if not isinstance(var, Var):
+ raise NotSupportedError("the only expressions allowed on the left hand side of "
+ "assignments are variable names", var.range())
+ return var.name()
+
+ @staticmethod
+ def build_Assign(ctx, stmt):
+ return Assign([StmtBuilder.get_assign_ident(ctx, e) for e in stmt.targets],
+ '=',
+ build_expr(ctx, stmt.value))
+
+ @staticmethod
+ def build_AugAssign(ctx, stmt):
+ lhs = [StmtBuilder.get_assign_ident(ctx, stmt.target)]
+ rhs = build_expr(ctx, stmt.value)
+ op = type(stmt.op)
+ if op in StmtBuilder.augassign_map:
+ op_token = StmtBuilder.augassign_map[op]
+ else:
+ raise NotSupportedError(
+ find_before(ctx, rhs.range().start, '=', offsets=(-1, 0)),
+ "unsupported kind of augumented assignment: " + op.__name__)
+ return Assign(lhs, op_token, rhs)
+
+ @staticmethod
+ def build_While(ctx, stmt):
+ if stmt.orelse:
+ # TODO: try to recover the location of else:? Python doesn't give us useful
+ # annotations in this case
+ raise NotSupportedError(None, "else branches of while loops aren't supported")
+ r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while"))
+ return While(r, build_expr(ctx, stmt.test), [build_stmt(ctx, s) for s in stmt.body])
+
+ @staticmethod
+ def build_If(ctx, stmt):
+ r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
+ return If(r, build_expr(ctx, stmt.test),
+ [build_stmt(ctx, s) for s in stmt.body],
+ [build_stmt(ctx, s) for s in stmt.orelse])
+
+
+class ExprBuilder(Builder):
+ _MethodRef = namedtuple('MethodRef', ['self', 'name'])
+ binop_map = {
+ ast.Add: '+',
+ ast.Sub: '-',
+ ast.Mult: '*',
+ ast.Div: '/',
+ }
+
+ unop_map = {
+ ast.Not: 'not',
+ ast.USub: '-',
+ }
+
+ boolop_map = {
+ ast.And: 'and',
+ ast.Or: 'or',
+ }
+
+ cmpop_map = {
+ ast.Eq: '==',
+ ast.NotEq: '!=',
+ ast.LtE: '<=',
+ ast.Lt: '<',
+ ast.GtE: '>=',
+ ast.Gt: '>',
+ }
+
+ @staticmethod
+ def build_Attribute(ctx, expr):
+ # NB: the only attributes we support are for getting methods
+ value = build_expr(ctx, expr.value)
+ # <sigh> name is just a string, so it's not annotated in any way.
+ source = ctx.source
+ pos = find_after(ctx, value.range().end, '.').end # Start with the dot
+ while source[pos] in string.whitespace: # Skip whitespace
+ pos += 1
+ start_pos = pos
+ while source[pos] in _identifier_chars: # Find the identifier itself
+ pos += 1
+ name_range = ctx.make_raw_range(start_pos, pos)
+ return ExprBuilder._MethodRef(value, Ident(name_range, expr.attr))
+
+ @staticmethod
+ def build_Call(ctx, expr):
+ ref = build_expr(ctx, expr.func, allow_methods=True)
+ if type(ref) is not ExprBuilder._MethodRef:
+ ref_range = ref.range()
+ parenthesis_range = find_after(ctx, ref_range.end, '(')
+ raise FrontendTypeError(
+ ctx.make_raw_range(ref_range.start, parenthesis_range.end),
+ "trying to call a non-function object")
+ args = [build_expr(ctx, py_arg) for py_arg in expr.args]
+ kwargs = [Attribute(Ident(name), build_expr(ctx, value)) for name, value in expr.keywords]
+ return Apply(ref.name, [ref.self] + args, kwargs)
+
+ @staticmethod
+ def build_Name(ctx, expr):
+ r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
+ if expr.id.startswith(_reserved_prefix):
+ raise NotSupportedError(r, "names of variables used in JIT-ed functions "
+ "can't start with " + _reserved_prefix)
+ return Var(Ident(r, expr.id))
+
+ @staticmethod
+ def build_BinOp(ctx, expr):
+ lhs = build_expr(ctx, expr.left)
+ rhs = build_expr(ctx, expr.right)
+ op = type(expr.op)
+ op_token = ExprBuilder.binop_map.get(op)
+ if op_token is None:
+ err_range = ctx.make_range(lhs.range().end, rhs.range().start)
+ raise NotSupportedError(err_range, "unsupported binary operator: " + op.__name__)
+ return BinOp(op_token, lhs, rhs)
+
+ @staticmethod
+ def build_UnaryOp(ctx, expr):
+ sub_expr = build_expr(ctx, expr.operand)
+ op = type(expr.op)
+ op_token = ExprBuilder.unop_map.get(op)
+ r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(op_token))
+ if op_token is None:
+ err_range = ctx.make_raw_range(r.start, sub_expr.range().end)
+ raise NotSupportedError(err_range, "unsupported unary operator: " + op.__name__)
+ return UnaryOp(r, op_token, sub_expr)
+
+ @staticmethod
+ def build_BoolOp(ctx, expr):
+ if len(expr.values) < 2:
+ raise AssertionError("expected at least 2 values in BoolOp, but got " + str(len(expr.values)))
+ sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values]
+ op = type(expr.op)
+ op_token = ExprBuilder.boolop_map.get(op)
+ if op_token is None:
+ err_range = ctx.make_raw_range(sub_exprs[0].range().end, sub_exprs[1].range().start)
+ raise NotSupportedError(err_range, "unsupported boolean operator: " + op.__name__)
+ lhs = sub_exprs[0]
+ for rhs in sub_exprs[1:]:
+ lhs = BinOp(op_token, lhs, rhs)
+ return lhs
+
+ @staticmethod
+ def build_IfExp(ctx, expr):
+ return TernaryIf(build_expr(ctx, expr.test),
+ build_expr(ctx, expr.body),
+ build_expr(ctx, expr.orelse))
+
+ @staticmethod
+ def build_Compare(ctx, expr):
+ operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)]
+ result = None
+ for lhs, op_, rhs in zip(operands, expr.ops, operands[1:]):
+ op = type(op_)
+ op_token = ExprBuilder.cmpop_map.get(op)
+ if op_token is None:
+ err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
+ raise NotSupportedError(err_range, "unsupported comparison operator: " + op.__name__)
+ cmp_expr = BinOp(op_token, lhs, rhs)
+ if result is None:
+ result = cmp_expr
+ else:
+ result = BinOp('and', result, cmp_expr)
+ return result
+
+ @staticmethod
+ def build_Num(ctx, expr):
+ # TODO: fix this once we have a nice Number node in our AST
+ err_range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
+ raise NotSupportedError(err_range, "scalar constants aren't supported")
+
+ def __call__(self, ctx, expr, allow_methods=False):
+ result = super(ExprBuilder, self).__call__(ctx, expr)
+ if type(result) is ExprBuilder._MethodRef and not allow_methods:
+ err_range = ctx.make_raw_range(result.self.range().start, result.name.range().end)
+ raise FrontendTypeError(err_range, "taking attributes/function values isn't supported")
+ return result
+
+
+build_expr = ExprBuilder()
+build_stmt = StmtBuilder()
+
+
+def find_after(ctx, pos, substr, offsets=(0, 0)):
+ new_pos = pos + ctx.source[pos:].index(substr)
+ return ctx.make_raw_range(new_pos + offsets[0], new_pos + len(substr) + offsets[1])
+
+
+def find_before(ctx, pos, substr, offsets=(0, 0)):
+ new_pos = ctx.source[:pos].rindex(substr)
+ return ctx.make_raw_range(new_pos + offsets[0], new_pos + len(substr) + offsets[1])