diff options
author | Hao Lu <hlu@fb.com> | 2017-09-11 16:25:01 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-09-11 16:32:40 -0700 |
commit | ebf7784840b8b23627f41d0906813cbffe0f7082 (patch) | |
tree | 9c391cffdf7e5113cf3c89f76ff2fe65b62e585d /caffe2/share | |
parent | 944115c915be8c7882a28a128d1bfd688b380093 (diff) | |
download | pytorch-ebf7784840b8b23627f41d0906813cbffe0f7082.tar.gz pytorch-ebf7784840b8b23627f41d0906813cbffe0f7082.tar.bz2 pytorch-ebf7784840b8b23627f41d0906813cbffe0f7082.zip |
Make android segmentation net run with MPSCNN
Summary: The android segmentation net was failing with MPSCNN because the some fused MPSCNNConvRelu ops become in-place after fusion.
Reviewed By: fricc33
Differential Revision: D5803245
fbshipit-source-id: 6808e9c3504389c113c7a16504d6554e83bdcc3e
Diffstat (limited to 'caffe2/share')
-rw-r--r-- | caffe2/share/caffe2/mobile/contrib/ios/mpscnn/mpscnn_graph.mm | 331 | ||||
-rw-r--r-- | caffe2/share/contrib/ios/mpscnn/mpscnn_graph.mm | 331 |
2 files changed, 662 insertions, 0 deletions
diff --git a/caffe2/share/caffe2/mobile/contrib/ios/mpscnn/mpscnn_graph.mm b/caffe2/share/caffe2/mobile/contrib/ios/mpscnn/mpscnn_graph.mm new file mode 100644 index 0000000000..43a6201d11 --- /dev/null +++ b/caffe2/share/caffe2/mobile/contrib/ios/mpscnn/mpscnn_graph.mm @@ -0,0 +1,331 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include "caffe2/core/operator.h" +#include "mpscnn.h" +#include "mpscnn_context.h" + +#import <Metal/Metal.h> +#import <MetalPerformanceShaders/MetalPerformanceShaders.h> +#import <UIKit/UIDevice.h> + +namespace caffe2 { +struct Analysis { + struct SSA { + using BlobVersions = std::unordered_map<std::string, size_t>; + BlobVersions inVersions; + BlobVersions outVersions; + }; + std::vector<SSA> ssa; + std::unordered_map<std::string, std::unordered_map<size_t, std::vector<size_t>>> inUsages; +}; + +Analysis analyzeNet(const NetDef& net) { + Analysis::SSA::BlobVersions frontier; + Analysis analysis; + + auto play = [&](size_t i, const OperatorDef& op) { + Analysis::SSA::BlobVersions inVersions; + for (const auto& s : op.input()) { + inVersions[s] = frontier[s]; + analysis.inUsages[s][frontier[s]].push_back(i); + } + Analysis::SSA::BlobVersions outVersions; + for (const auto& s : op.output()) { + if (frontier.find(s) != frontier.end()) { + frontier[s] += 1; + } + outVersions[s] = frontier[s]; + } + analysis.ssa.push_back(Analysis::SSA{inVersions, outVersions}); + }; + + for (auto i = 0; i < net.op_size(); ++i) { + play(i, net.op(i)); + } + return analysis; +} + +NetDef insertInputOutputCopyOps(const NetDef& def) { + // Do some validation of the outputs. For this version, we require: + // - a single input (first element of external_input()) is consumed by the NetDef + // - a single output (first element of external_output()) is produced by the NetDef. + // - the input is consumed by def.op(0), and this is the only consumer. + // - the output is produced by def.op(-1). + CAFFE_ENFORCE_GE(def.external_input_size(), 1); + CAFFE_ENFORCE_GE(def.external_output_size(), 1); + auto analysis = analyzeNet(def); + // enforce a single use of the input blob. + CAFFE_ENFORCE_GE(def.op_size(), 1); + const auto& inputBlob = def.external_input(0); + // Enforce that the input blob has a single usage - in the first operator. + CAFFE_ENFORCE(analysis.inUsages[inputBlob][0] == (std::vector<size_t>{0})); + // Enforce that the external_output(0) blob is produced by the last operator in this sequence. + const auto& outputBlob = def.external_output(0); + CAFFE_ENFORCE(analysis.ssa.back().outVersions.find(outputBlob) != + analysis.ssa.back().outVersions.end()); + const auto& outputBlobVersion = analysis.ssa.back().outVersions[outputBlob]; + // This should hold true by definition of the SSA analysis. + CAFFE_ENFORCE(analysis.inUsages[outputBlob].find(outputBlobVersion) == + analysis.inUsages[outputBlob].end()); + NetDef mdef; + mdef.CopyFrom(def); + mdef.clear_op(); + + { + auto& op = *(mdef.add_op()); + op.set_type("CopyToMPSCNN"); + op.add_input(def.external_input(0)); + op.add_output("__METAL_INPUT_COPY__"); + } + + for (auto i = 0; i < def.op_size(); ++i) { + const auto& ogOp = def.op(i); + auto op = mdef.add_op(); + op->CopyFrom(ogOp); + if (i == 0) { + CAFFE_ENFORCE_EQ(op->input(0), def.external_input(0)); + op->set_input(0, "__METAL_INPUT_COPY__"); + } + if (i == def.op_size() - 1) { + CAFFE_ENFORCE_EQ(op->output(0), def.external_output(0)); + op->set_output(0, "__METAL_OUTPUT_COPY__"); + } + } + { + auto& op = *(mdef.add_op()); + op.set_type("CopyFromMPSCNN"); + op.add_input("__METAL_OUTPUT_COPY__"); + op.add_output(def.external_output(0)); + } + return mdef; +} + +bool nextIsOnlyUserOfCurrent(const Analysis& analysis, + size_t currentIdx, + const OperatorDef& currentOp, + const OperatorDef& nextOp) { + CAFFE_ENFORCE_EQ(currentOp.output_size(), 1); + CAFFE_ENFORCE_GE(nextOp.input_size(), 1); + CAFFE_ENFORCE_EQ(currentOp.output(0), nextOp.input(0)); + const auto outputName = currentOp.output(0); + // Find the version of the output name we are currently looking at. + // This is guaranteed to exist by SSA analysis. + const auto currentOutputVersion = analysis.ssa.at(currentIdx).outVersions.at(outputName); + VLOG(2) << "Blob: " << outputName << ", idx: " << currentOutputVersion; + // Find the usages of this in the SSA analysis. + + // Has this blob every been used? + if (analysis.inUsages.find(outputName) == analysis.inUsages.end()) { + return false; + } + + // Has this version of the blob ever been used? + if (analysis.inUsages.at(outputName).find(currentOutputVersion) == + analysis.inUsages.at(outputName).end()) { + return false; + } + const auto currentOutputUsages = analysis.inUsages.at(outputName).at(currentOutputVersion); + VLOG(2) << "Blob: " << outputName << ", idx: " << currentOutputVersion + << ", usages[0]: " << currentOutputUsages[0]; + + return currentOutputUsages == std::vector<size_t>{currentIdx + 1}; +} +bool tryFuseAdjacentOps(const Analysis& analysis, + size_t currentIdx, + const OperatorDef& currentOp, + const OperatorDef& nextOp, + OperatorDef* fusedOp) { + // Check for possible invalid opportunities. + // Must be identical outputs, with either in-place usage for nextOp, *or* the only use of the + // output of currentOp is the consumption by nextOp. + if (currentOp.output_size() != 1 || !nextOp.input_size() || nextOp.output_size() != 1) { + return false; + } + + if (currentOp.output(0) != nextOp.input(0)) { + return false; + } + + if (!nextIsOnlyUserOfCurrent(analysis, currentIdx, currentOp, nextOp)) { + return false; + } + + // Can we autogenerate this at registration time instead? + static const std::map<std::pair<std::string, std::string>, std::string> fusionOpportunities = {{ + {{"MPSCNNConv", "MPSCNNRelu"}, "MPSCNNConvRelu"}, + {{"MPSCNNConv", "MPSCNNSigmoid"}, "MPSCNNConvSigmoid"}, + {{"MPSCNNFC", "MPSCNNRelu"}, "MPSCNNFCRelu"}, + {{"MPSCNNInstanceNorm", "MPSCNNPRelu"}, "MPSCNNInstanceNormPRelu"}, + }}; + auto it = fusionOpportunities.find({currentOp.type(), nextOp.type()}); + if (it == fusionOpportunities.end()) { + return false; + } + // MPSCNNConvRelu and MPSCNNConvSigmoid cannot be in-place + if (currentOp.type() == "MPSCNNConv" && currentOp.input(0) == nextOp.output(0)) { + return false; + } + LOG(INFO) << "Found a fusion between adjacent ops: (" << currentOp.type() << ", " << nextOp.type() + << ") -> " << it->second; + fusedOp->CopyFrom(currentOp); + fusedOp->set_type(it->second); + for (auto i = 1; i < nextOp.input_size(); ++i) { + fusedOp->add_input(nextOp.input(i)); + } + fusedOp->set_output(0, nextOp.output(0)); + return true; +} + +NetDef runMPSCNNFusion(const NetDef& def) { + CAFFE_ENFORCE_GE(def.op_size(), 1); + NetDef mdef; + mdef.CopyFrom(def); + mdef.clear_op(); + auto i = 0; + auto analysis = analyzeNet(def); + + while (i < def.op_size()) { + if (i == def.op_size() - 1) { + VLOG(2) << "Last operator, skipping"; + auto* op = mdef.add_op(); + op->CopyFrom(def.op(i)); + i += 1; + continue; + } + + const auto& currentOp = def.op(i); + const auto& nextOp = def.op(i + 1); + OperatorDef fusedOp; + if (tryFuseAdjacentOps(analysis, i, currentOp, nextOp, &fusedOp)) { + VLOG(2) << "Found an adjacent fusion at: " << i; + // We can fuse. + auto* op = mdef.add_op(); + op->CopyFrom(fusedOp); + i += 2; + continue; + } + VLOG(2) << "No fusion available"; + // Just emit the current type. + auto* op = mdef.add_op(); + op->CopyFrom(currentOp); + i += 1; + } + return mdef; +} + +NetDef rewriteForMetal(const NetDef& def) { + NetDef mdef; + mdef.CopyFrom(def); + + const auto& opKeyList = CPUOperatorRegistry()->Keys(); + const auto& opKeySet = std::set<std::string>(opKeyList.begin(), opKeyList.end()); + for (auto i = 0; i < mdef.op_size(); ++i) { + auto* op = mdef.mutable_op(i); + const auto mpscnnOp = std::string("MPSCNN") + op->type(); + CAFFE_ENFORCE(opKeySet.find(mpscnnOp) != opKeySet.end()); + op->set_type(mpscnnOp); + } + + mdef = runMPSCNNFusion(mdef); + static std::set<std::string> mpscnnInputOps = { + "CopyToMPSCNN", "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess"}; + static std::set<std::string> mpscnnOutputOps = { + "CopyFromMPSCNN", "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess"}; + + if (mpscnnInputOps.find(mdef.op(0).type()) == mpscnnInputOps.end() && + mpscnnOutputOps.find(mdef.op(mdef.op_size() - 1).type()) == mpscnnOutputOps.end()) { + mdef = insertInputOutputCopyOps(mdef); + } + CAFFE_ENFORCE_GE(mdef.op_size(), 2); + CAFFE_ENFORCE(mpscnnInputOps.find(mdef.op(0).type()) != mpscnnInputOps.end()); + CAFFE_ENFORCE(mpscnnOutputOps.find(mdef.op(mdef.op_size() - 1).type()) != mpscnnOutputOps.end()); + return mdef; +} + +void dumpDef(const NetDef& d) { + for (const auto& op : d.op()) { + LOG(INFO) << op.input(0) << " -> " << op.type() << " -> " << op.output(0); + } +} + +NetDef annotateDefWithReadCounts(const NetDef& net) { + // Now we have usage versions, we want to compute, for each blob version, the number of usages of + // each blob version. + // ReadCount + auto analysis = analyzeNet(net); + using ReadCount = std::unordered_map<std::string, size_t>; + std::vector<ReadCount> readCounts; + + auto computeReadCount = [&](size_t i, const OperatorDef& op) { + ReadCount rcs; + for (const auto bv : analysis.ssa[i].outVersions) { + const auto versionUsages = analysis.inUsages[bv.first][bv.second]; + rcs[bv.first] = versionUsages.size(); + } + readCounts.push_back(rcs); + }; + for (auto i = 0; i < net.op_size(); ++i) { + computeReadCount(i, net.op(i)); + } + + NetDef annotatedNet; + annotatedNet.CopyFrom(net); + for (auto i = 0; i < annotatedNet.op_size(); ++i) { + auto* op = annotatedNet.mutable_op(i); + // TODO - relax this? CAFFE_ENFORCE_EQ(op->output_size(), 1); + const auto& blob = op->output(0); + const size_t readCount = readCounts[i][blob]; + if (readCount > 1) { + auto* arg = op->add_arg(); + arg->set_name(kMPSCNNReadCountArg); + arg->set_i(readCount); + LOG(INFO) << "Op: " << i << ", ty: " << op->type() << ", blob: " << blob + << ", read count: " << readCount; + } + } + return annotatedNet; +} + +bool tryConvertToMPSCNN(const NetDef& initNet, const NetDef& predictNet, NetDef* metalPredictNet) { +// iOS 10.0 and above. +#define SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(v) \ + ([[[UIDevice currentDevice] systemVersion] compare:v options:NSNumericSearch] != \ + NSOrderedAscending) +#define SYSTEM_VERSION_EQUAL_TO(v) \ + ([[[UIDevice currentDevice] systemVersion] compare:v options:NSNumericSearch] == NSOrderedSame) + + if (!SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(@"10.2")) { + LOG(ERROR) << "MPSCNN is unstable for ios version under 10.2."; + return false; + } + +#undef SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO +#undef SYSTEM_VERSION_EQUAL_TO + + // The iOS GPU Family 3 v2 feature set. Introduced with the Apple A9 GPU and iOS 10.0. + // Don't instantiate the MPSCNNContext, as that compiles the kernel source. + if (![MTLCreateSystemDefaultDevice() supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily3_v2]) { + LOG(ERROR) << "The iOS GPU is less than an A9, so MPSCNN is not available"; + return false; + } + + try { + // Instantiating the net and catching failures allows us to + Workspace ws; + ws.RunNetOnce(initNet); + // Throws if unsupported operators are found. + *metalPredictNet = rewriteForMetal(predictNet); + *metalPredictNet = annotateDefWithReadCounts(*metalPredictNet); + // Throws if unsupported parameters are found. + ws.CreateNet(*metalPredictNet); + LOG(INFO) << "MPSCNN is successfully enabled"; + return true; + } catch (const std::exception& e) { + LOG(ERROR) << "Caught exception trying to convert NetDef to MPSCNN: " << e.what(); + return false; + } +} + +void mpscnnRecordExecutionFinish() { [getMPSCNNContext().commandQueue insertDebugCaptureBoundary]; } + +} diff --git a/caffe2/share/contrib/ios/mpscnn/mpscnn_graph.mm b/caffe2/share/contrib/ios/mpscnn/mpscnn_graph.mm new file mode 100644 index 0000000000..43a6201d11 --- /dev/null +++ b/caffe2/share/contrib/ios/mpscnn/mpscnn_graph.mm @@ -0,0 +1,331 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include "caffe2/core/operator.h" +#include "mpscnn.h" +#include "mpscnn_context.h" + +#import <Metal/Metal.h> +#import <MetalPerformanceShaders/MetalPerformanceShaders.h> +#import <UIKit/UIDevice.h> + +namespace caffe2 { +struct Analysis { + struct SSA { + using BlobVersions = std::unordered_map<std::string, size_t>; + BlobVersions inVersions; + BlobVersions outVersions; + }; + std::vector<SSA> ssa; + std::unordered_map<std::string, std::unordered_map<size_t, std::vector<size_t>>> inUsages; +}; + +Analysis analyzeNet(const NetDef& net) { + Analysis::SSA::BlobVersions frontier; + Analysis analysis; + + auto play = [&](size_t i, const OperatorDef& op) { + Analysis::SSA::BlobVersions inVersions; + for (const auto& s : op.input()) { + inVersions[s] = frontier[s]; + analysis.inUsages[s][frontier[s]].push_back(i); + } + Analysis::SSA::BlobVersions outVersions; + for (const auto& s : op.output()) { + if (frontier.find(s) != frontier.end()) { + frontier[s] += 1; + } + outVersions[s] = frontier[s]; + } + analysis.ssa.push_back(Analysis::SSA{inVersions, outVersions}); + }; + + for (auto i = 0; i < net.op_size(); ++i) { + play(i, net.op(i)); + } + return analysis; +} + +NetDef insertInputOutputCopyOps(const NetDef& def) { + // Do some validation of the outputs. For this version, we require: + // - a single input (first element of external_input()) is consumed by the NetDef + // - a single output (first element of external_output()) is produced by the NetDef. + // - the input is consumed by def.op(0), and this is the only consumer. + // - the output is produced by def.op(-1). + CAFFE_ENFORCE_GE(def.external_input_size(), 1); + CAFFE_ENFORCE_GE(def.external_output_size(), 1); + auto analysis = analyzeNet(def); + // enforce a single use of the input blob. + CAFFE_ENFORCE_GE(def.op_size(), 1); + const auto& inputBlob = def.external_input(0); + // Enforce that the input blob has a single usage - in the first operator. + CAFFE_ENFORCE(analysis.inUsages[inputBlob][0] == (std::vector<size_t>{0})); + // Enforce that the external_output(0) blob is produced by the last operator in this sequence. + const auto& outputBlob = def.external_output(0); + CAFFE_ENFORCE(analysis.ssa.back().outVersions.find(outputBlob) != + analysis.ssa.back().outVersions.end()); + const auto& outputBlobVersion = analysis.ssa.back().outVersions[outputBlob]; + // This should hold true by definition of the SSA analysis. + CAFFE_ENFORCE(analysis.inUsages[outputBlob].find(outputBlobVersion) == + analysis.inUsages[outputBlob].end()); + NetDef mdef; + mdef.CopyFrom(def); + mdef.clear_op(); + + { + auto& op = *(mdef.add_op()); + op.set_type("CopyToMPSCNN"); + op.add_input(def.external_input(0)); + op.add_output("__METAL_INPUT_COPY__"); + } + + for (auto i = 0; i < def.op_size(); ++i) { + const auto& ogOp = def.op(i); + auto op = mdef.add_op(); + op->CopyFrom(ogOp); + if (i == 0) { + CAFFE_ENFORCE_EQ(op->input(0), def.external_input(0)); + op->set_input(0, "__METAL_INPUT_COPY__"); + } + if (i == def.op_size() - 1) { + CAFFE_ENFORCE_EQ(op->output(0), def.external_output(0)); + op->set_output(0, "__METAL_OUTPUT_COPY__"); + } + } + { + auto& op = *(mdef.add_op()); + op.set_type("CopyFromMPSCNN"); + op.add_input("__METAL_OUTPUT_COPY__"); + op.add_output(def.external_output(0)); + } + return mdef; +} + +bool nextIsOnlyUserOfCurrent(const Analysis& analysis, + size_t currentIdx, + const OperatorDef& currentOp, + const OperatorDef& nextOp) { + CAFFE_ENFORCE_EQ(currentOp.output_size(), 1); + CAFFE_ENFORCE_GE(nextOp.input_size(), 1); + CAFFE_ENFORCE_EQ(currentOp.output(0), nextOp.input(0)); + const auto outputName = currentOp.output(0); + // Find the version of the output name we are currently looking at. + // This is guaranteed to exist by SSA analysis. + const auto currentOutputVersion = analysis.ssa.at(currentIdx).outVersions.at(outputName); + VLOG(2) << "Blob: " << outputName << ", idx: " << currentOutputVersion; + // Find the usages of this in the SSA analysis. + + // Has this blob every been used? + if (analysis.inUsages.find(outputName) == analysis.inUsages.end()) { + return false; + } + + // Has this version of the blob ever been used? + if (analysis.inUsages.at(outputName).find(currentOutputVersion) == + analysis.inUsages.at(outputName).end()) { + return false; + } + const auto currentOutputUsages = analysis.inUsages.at(outputName).at(currentOutputVersion); + VLOG(2) << "Blob: " << outputName << ", idx: " << currentOutputVersion + << ", usages[0]: " << currentOutputUsages[0]; + + return currentOutputUsages == std::vector<size_t>{currentIdx + 1}; +} +bool tryFuseAdjacentOps(const Analysis& analysis, + size_t currentIdx, + const OperatorDef& currentOp, + const OperatorDef& nextOp, + OperatorDef* fusedOp) { + // Check for possible invalid opportunities. + // Must be identical outputs, with either in-place usage for nextOp, *or* the only use of the + // output of currentOp is the consumption by nextOp. + if (currentOp.output_size() != 1 || !nextOp.input_size() || nextOp.output_size() != 1) { + return false; + } + + if (currentOp.output(0) != nextOp.input(0)) { + return false; + } + + if (!nextIsOnlyUserOfCurrent(analysis, currentIdx, currentOp, nextOp)) { + return false; + } + + // Can we autogenerate this at registration time instead? + static const std::map<std::pair<std::string, std::string>, std::string> fusionOpportunities = {{ + {{"MPSCNNConv", "MPSCNNRelu"}, "MPSCNNConvRelu"}, + {{"MPSCNNConv", "MPSCNNSigmoid"}, "MPSCNNConvSigmoid"}, + {{"MPSCNNFC", "MPSCNNRelu"}, "MPSCNNFCRelu"}, + {{"MPSCNNInstanceNorm", "MPSCNNPRelu"}, "MPSCNNInstanceNormPRelu"}, + }}; + auto it = fusionOpportunities.find({currentOp.type(), nextOp.type()}); + if (it == fusionOpportunities.end()) { + return false; + } + // MPSCNNConvRelu and MPSCNNConvSigmoid cannot be in-place + if (currentOp.type() == "MPSCNNConv" && currentOp.input(0) == nextOp.output(0)) { + return false; + } + LOG(INFO) << "Found a fusion between adjacent ops: (" << currentOp.type() << ", " << nextOp.type() + << ") -> " << it->second; + fusedOp->CopyFrom(currentOp); + fusedOp->set_type(it->second); + for (auto i = 1; i < nextOp.input_size(); ++i) { + fusedOp->add_input(nextOp.input(i)); + } + fusedOp->set_output(0, nextOp.output(0)); + return true; +} + +NetDef runMPSCNNFusion(const NetDef& def) { + CAFFE_ENFORCE_GE(def.op_size(), 1); + NetDef mdef; + mdef.CopyFrom(def); + mdef.clear_op(); + auto i = 0; + auto analysis = analyzeNet(def); + + while (i < def.op_size()) { + if (i == def.op_size() - 1) { + VLOG(2) << "Last operator, skipping"; + auto* op = mdef.add_op(); + op->CopyFrom(def.op(i)); + i += 1; + continue; + } + + const auto& currentOp = def.op(i); + const auto& nextOp = def.op(i + 1); + OperatorDef fusedOp; + if (tryFuseAdjacentOps(analysis, i, currentOp, nextOp, &fusedOp)) { + VLOG(2) << "Found an adjacent fusion at: " << i; + // We can fuse. + auto* op = mdef.add_op(); + op->CopyFrom(fusedOp); + i += 2; + continue; + } + VLOG(2) << "No fusion available"; + // Just emit the current type. + auto* op = mdef.add_op(); + op->CopyFrom(currentOp); + i += 1; + } + return mdef; +} + +NetDef rewriteForMetal(const NetDef& def) { + NetDef mdef; + mdef.CopyFrom(def); + + const auto& opKeyList = CPUOperatorRegistry()->Keys(); + const auto& opKeySet = std::set<std::string>(opKeyList.begin(), opKeyList.end()); + for (auto i = 0; i < mdef.op_size(); ++i) { + auto* op = mdef.mutable_op(i); + const auto mpscnnOp = std::string("MPSCNN") + op->type(); + CAFFE_ENFORCE(opKeySet.find(mpscnnOp) != opKeySet.end()); + op->set_type(mpscnnOp); + } + + mdef = runMPSCNNFusion(mdef); + static std::set<std::string> mpscnnInputOps = { + "CopyToMPSCNN", "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess"}; + static std::set<std::string> mpscnnOutputOps = { + "CopyFromMPSCNN", "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess"}; + + if (mpscnnInputOps.find(mdef.op(0).type()) == mpscnnInputOps.end() && + mpscnnOutputOps.find(mdef.op(mdef.op_size() - 1).type()) == mpscnnOutputOps.end()) { + mdef = insertInputOutputCopyOps(mdef); + } + CAFFE_ENFORCE_GE(mdef.op_size(), 2); + CAFFE_ENFORCE(mpscnnInputOps.find(mdef.op(0).type()) != mpscnnInputOps.end()); + CAFFE_ENFORCE(mpscnnOutputOps.find(mdef.op(mdef.op_size() - 1).type()) != mpscnnOutputOps.end()); + return mdef; +} + +void dumpDef(const NetDef& d) { + for (const auto& op : d.op()) { + LOG(INFO) << op.input(0) << " -> " << op.type() << " -> " << op.output(0); + } +} + +NetDef annotateDefWithReadCounts(const NetDef& net) { + // Now we have usage versions, we want to compute, for each blob version, the number of usages of + // each blob version. + // ReadCount + auto analysis = analyzeNet(net); + using ReadCount = std::unordered_map<std::string, size_t>; + std::vector<ReadCount> readCounts; + + auto computeReadCount = [&](size_t i, const OperatorDef& op) { + ReadCount rcs; + for (const auto bv : analysis.ssa[i].outVersions) { + const auto versionUsages = analysis.inUsages[bv.first][bv.second]; + rcs[bv.first] = versionUsages.size(); + } + readCounts.push_back(rcs); + }; + for (auto i = 0; i < net.op_size(); ++i) { + computeReadCount(i, net.op(i)); + } + + NetDef annotatedNet; + annotatedNet.CopyFrom(net); + for (auto i = 0; i < annotatedNet.op_size(); ++i) { + auto* op = annotatedNet.mutable_op(i); + // TODO - relax this? CAFFE_ENFORCE_EQ(op->output_size(), 1); + const auto& blob = op->output(0); + const size_t readCount = readCounts[i][blob]; + if (readCount > 1) { + auto* arg = op->add_arg(); + arg->set_name(kMPSCNNReadCountArg); + arg->set_i(readCount); + LOG(INFO) << "Op: " << i << ", ty: " << op->type() << ", blob: " << blob + << ", read count: " << readCount; + } + } + return annotatedNet; +} + +bool tryConvertToMPSCNN(const NetDef& initNet, const NetDef& predictNet, NetDef* metalPredictNet) { +// iOS 10.0 and above. +#define SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(v) \ + ([[[UIDevice currentDevice] systemVersion] compare:v options:NSNumericSearch] != \ + NSOrderedAscending) +#define SYSTEM_VERSION_EQUAL_TO(v) \ + ([[[UIDevice currentDevice] systemVersion] compare:v options:NSNumericSearch] == NSOrderedSame) + + if (!SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(@"10.2")) { + LOG(ERROR) << "MPSCNN is unstable for ios version under 10.2."; + return false; + } + +#undef SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO +#undef SYSTEM_VERSION_EQUAL_TO + + // The iOS GPU Family 3 v2 feature set. Introduced with the Apple A9 GPU and iOS 10.0. + // Don't instantiate the MPSCNNContext, as that compiles the kernel source. + if (![MTLCreateSystemDefaultDevice() supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily3_v2]) { + LOG(ERROR) << "The iOS GPU is less than an A9, so MPSCNN is not available"; + return false; + } + + try { + // Instantiating the net and catching failures allows us to + Workspace ws; + ws.RunNetOnce(initNet); + // Throws if unsupported operators are found. + *metalPredictNet = rewriteForMetal(predictNet); + *metalPredictNet = annotateDefWithReadCounts(*metalPredictNet); + // Throws if unsupported parameters are found. + ws.CreateNet(*metalPredictNet); + LOG(INFO) << "MPSCNN is successfully enabled"; + return true; + } catch (const std::exception& e) { + LOG(ERROR) << "Caught exception trying to convert NetDef to MPSCNN: " << e.what(); + return false; + } +} + +void mpscnnRecordExecutionFinish() { [getMPSCNNContext().commandQueue insertDebugCaptureBoundary]; } + +} |