summaryrefslogtreecommitdiff
path: root/caffe2/core
diff options
context:
space:
mode:
authorSebastian Messmer <messmer@fb.com>2019-02-01 12:44:55 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-01 13:52:01 -0800
commitaaa8ace48642a1a5774332084161e0f93f171e1f (patch)
tree4f468a87d962d1c83710d618e563fc9515286175 /caffe2/core
parenta40e8ce7c553d61024fdf4f8f2b7b13ff606e77b (diff)
downloadpytorch-aaa8ace48642a1a5774332084161e0f93f171e1f.tar.gz
pytorch-aaa8ace48642a1a5774332084161e0f93f171e1f.tar.bz2
pytorch-aaa8ace48642a1a5774332084161e0f93f171e1f.zip
Implement new c10 dispatcher (#16625)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16625 This is a squash of multiple PRs that refactored the old c10 dispatcher into a new one that follows the c10 dispatcher design doc. It is now unboxed and follows the Stack semantics from JIT. It also uses the runtime JIT schema instead of its own compile time schema definitions. Reviewed By: ezyang Differential Revision: D13907069 fbshipit-source-id: edcc4806ccd21474fdfb5a98516219b1956db13d
Diffstat (limited to 'caffe2/core')
-rw-r--r--caffe2/core/operator.h10
-rw-r--r--caffe2/core/operator_c10wrapper.h146
2 files changed, 85 insertions, 71 deletions
diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h
index 253428645d..a12767c456 100644
--- a/caffe2/core/operator.h
+++ b/caffe2/core/operator.h
@@ -203,6 +203,16 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
return XBlobGetMutableTensor(outputs_.at(idx), dims, options);
}
+ void SetOutputTensor(int idx, Tensor tensor) {
+ // also update the tensor in the hack
+ if (!isLegacyOperator()) {
+ output_tensors_[idx] = tensor.UnsafeSharedInstance();
+ }
+
+ // update the tensor in the workspace
+ BlobSetTensor(outputs_.at(idx), std::move(tensor));
+ }
+
inline Tensor*
OutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
if (isLegacyOperator()) {
diff --git a/caffe2/core/operator_c10wrapper.h b/caffe2/core/operator_c10wrapper.h
index fb0c458ce9..ab0230d2af 100644
--- a/caffe2/core/operator_c10wrapper.h
+++ b/caffe2/core/operator_c10wrapper.h
@@ -23,19 +23,15 @@ using extract_type_t = typename ParameterDef::type;
*
* REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(C10Add, C2MyAddOpName)
*
- * Note: This wrapper currently only supports C10 ops that have exactly one
- * output and take that in the last parameter as "Tensor* output".
- * TODO: Figure out a better way to handle output parameters
*/
template <
- class OpSchemaDef,
+ const c10::OperatorHandle& (*OperatorHandle)(),
class Context,
bool use_array_input,
+ size_t num_output_parameters,
class ParameterDefTuple>
class C10OperatorWrapper final : public Operator<Context> {
- using Schema = c10::OpSchema<OpSchemaDef>;
-
public:
static_assert(
c10::guts::is_instantiation_of<std::tuple, ParameterDefTuple>::value,
@@ -49,28 +45,39 @@ class C10OperatorWrapper final : public Operator<Context> {
C10OperatorWrapper(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
+ op_(OperatorHandle()),
kernel_(at::nullopt),
parameters_(parse_parameters_(
operator_def,
- c10::guts::make_index_sequence<num_parameters()>())) {}
+ c10::guts::make_index_sequence<num_parameters()>())) {
- static constexpr size_t num_inputs() {
- return Schema::signature::num_args - num_outputs() - num_parameters();
+ AT_ASSERT(operator_def.output_size() == op_.schema().returns().size());
+ AT_ASSERT(operator_def.input_size() == num_inputs());
}
- static constexpr size_t num_parameters() {
- return std::tuple_size<ParameterDefTuple>::value;
+ size_t num_inputs() {
+ return op_.schema().arguments().size() - num_output_parameters - num_parameters();
}
- static constexpr size_t num_outputs() {
- return Schema::signature::num_outputs;
+ static constexpr size_t num_parameters() {
+ return std::tuple_size<ParameterDefTuple>::value;
}
bool RunOnDevice() override {
- RunOnDevice_(
- c10::guts::make_index_sequence<num_inputs()>(),
- c10::guts::make_index_sequence<num_outputs()>(),
- c10::guts::make_index_sequence<num_parameters()>());
+ // due to caching the stack_, concurrent calling is not allowed.
+ // TODO thread_local might fix this
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ AT_ASSERT(stack_.size() == 0);
+
+ pushInputs_();
+ pushParameters_(guts::make_index_sequence<num_parameters()>());
+ pushOutputParameters_();
+
+ callKernel_();
+
+ popOutputs_();
+
return true;
}
@@ -91,56 +98,44 @@ class C10OperatorWrapper final : public Operator<Context> {
return Parameter::parse(ArgumentHelper(operator_def));
}
- template <
- size_t... InputIndex,
- size_t... OutputIndex,
- size_t... ParameterIndex>
- c10::guts::enable_if_t<
- details::true_t<InputIndex...>::value &&
- !use_array_input,
- void>
- RunOnDevice_(
- c10::guts::index_sequence<InputIndex...>,
- c10::guts::index_sequence<OutputIndex...>,
- c10::guts::index_sequence<ParameterIndex...>) {
- Stack stack;
- torch::jit::push(stack,
- IValue(at::Tensor(C10Tensor(Input(InputIndex))))...,
- IValue(std::get<ParameterIndex>(parameters_))...,
- IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))...
- );
- call_(&stack);
- // TODO Do we have to Write outputs from stack back into the workspace?
+ void pushInputs_() {
+ if (use_array_input) {
+ stack_.emplace_back(ivalue::TensorList::create(array_inputs_()));
+ } else {
+ for (size_t i = 0; i < num_inputs(); ++i) {
+ stack_.emplace_back(at::Tensor(C10Tensor(Input(i))));
+ }
+ }
}
- template <
- size_t... InputIndex,
- size_t... OutputIndex,
- size_t... ParameterIndex>
- c10::guts::enable_if_t<
- details::true_t<InputIndex...>::value &&
- use_array_input,
- void>
- RunOnDevice_(
- c10::guts::index_sequence<InputIndex...>,
- c10::guts::index_sequence<OutputIndex...>,
- c10::guts::index_sequence<ParameterIndex...>) {
- Stack stack;
- torch::jit::push(stack,
- IValue(ivalue::TensorList::create(array_inputs_())),
- IValue(std::get<ParameterIndex>(parameters_))...,
- IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))...
- );
- call_(&stack);
- // TODO Do we have to Write outputs from stack back into the workspace?
+ template<size_t... ParameterIndex>
+ void pushParameters_(guts::index_sequence<ParameterIndex...>) {
+ (void)std::initializer_list<int>{(
+ stack_.emplace_back(std::get<ParameterIndex>(parameters_))
+ , 0)...};
+ }
+
+ void pushOutputParameters_() {
+ for (size_t i = 0; i < num_output_parameters; ++i) {
+ stack_.emplace_back(at::Tensor(C10Tensor(*Output(i))));
+ }
}
- void call_(Stack* stack) {
+ void callKernel_() {
+ AT_ASSERT(stack_.size() == op_.schema().arguments().size());
if (!kernel_.has_value()) {
// TODO if kernel is already set, try re-dispatch to assert it goes to the same kernel
- kernel_ = c10::Dispatcher<OpSchemaDef>::lookup(stack);
+ kernel_ = c10::Dispatcher::singleton().lookup(op_, &stack_);
}
- kernel_->call(stack);
+ kernel_->call(&stack_);
+ }
+
+ void popOutputs_() {
+ AT_ASSERT(stack_.size() == op_.schema().returns().size());
+ for (size_t i = 0; i < op_.schema().returns().size(); ++i) {
+ OperatorBase::SetOutputTensor(i, Tensor(C10Tensor(std::move(stack_[i]).toTensor())));
+ }
+ stack_.clear();
}
std::vector<at::Tensor> array_inputs_() {
@@ -152,8 +147,15 @@ class C10OperatorWrapper final : public Operator<Context> {
return result;
}
+ c10::OperatorHandle op_;
c10::optional<OpKernel> kernel_;
+ // this is stored as a member here to avoid having to re-allocate a stack
+ // for each call. Between kernel calls, stack_.size() == 0, but capacity
+ // should not need to be grown anymore after the first call.
+ std::vector<IValue> stack_;
+ std::mutex mutex_;
+
ParameterTuple parameters_;
};
@@ -174,39 +176,41 @@ C10_DECLARE_REGISTRY(
// TODO Currently we only register the CPU variant. This is going to be fixed
// once the tensor detemplatization lands.
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OpSchemaDef, Name) \
- C10_REGISTER_CLASS( \
- C10OperatorRegistry, \
- Name, \
- C10OperatorWrapper<OpSchemaDef, CPUContext, false, std::tuple<>>)
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OperatorHandle, Name, NumOutputParameters) \
+ C10_REGISTER_CLASS( \
+ C10OperatorRegistry, \
+ Name, \
+ C10OperatorWrapper<OperatorHandle, CPUContext, false, NumOutputParameters, std::tuple<>>)
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( \
- OpSchemaDef, Name, ...) \
+ OperatorHandle, Name, NumOutputParameters, ...) \
C10_REGISTER_CLASS( \
C10OperatorRegistry, \
Name, \
C10OperatorWrapper< \
- OpSchemaDef, \
+ OperatorHandle, \
CPUContext, \
false, \
+ NumOutputParameters, \
std::tuple<__VA_ARGS__>>)
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT( \
- OpSchemaDef, Name) \
+ OperatorHandle, Name, NumOutputParameters) \
C10_REGISTER_CLASS( \
C10OperatorRegistry, \
Name, \
- C10OperatorWrapper<OpSchemaDef, CPUContext, true, std::tuple<>>)
+ C10OperatorWrapper<OperatorHandle, CPUContext, true, NumOutputParameters, std::tuple<>>)
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( \
- OpSchemaDef, Name, ...) \
+ OperatorHandle, Name, NumOutputParameters, ...) \
C10_REGISTER_CLASS( \
C10OperatorRegistry, \
Name, \
C10OperatorWrapper< \
- OpSchemaDef, \
+ OperatorHandle, \
CPUContext, \
true, \
+ NumOutputParameters, \
std::tuple<__VA_ARGS__>>)
} // namespace caffe2