diff options
Diffstat (limited to 'modules')
41 files changed, 5171 insertions, 0 deletions
diff --git a/modules/main.cc b/modules/main.cc new file mode 100644 index 0000000..86e9c36 --- /dev/null +++ b/modules/main.cc @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <assert.h> + +#include "Module.h" +#include "aitt_internal_definitions.h" + +extern "C" { + +// Function name Should be same with aitt::AittTransport::MODULE_ENTRY_NAME +API void *aitt_module_entry(const char *ip, AittDiscovery &discovery) +{ + assert(!strcmp(__func__, aitt::AittTransport::MODULE_ENTRY_NAME) + && "Entry point name is not matched"); + + std::string ip_address(ip); + Module *module = new Module(ip_address, discovery); + + AittTransport *tModule = dynamic_cast<AittTransport *>(module); + // NOTE: + // validate that the module creates valid object (which inherits AittTransport) + assert(tModule && "Transport Module is not created"); + + return tModule; +} + +} // extern "C" diff --git a/modules/tcp/CMakeLists.txt b/modules/tcp/CMakeLists.txt new file mode 100644 index 0000000..edac2fd --- /dev/null +++ b/modules/tcp/CMakeLists.txt @@ -0,0 +1,14 @@ +SET(AITT_TCP aitt-transport-tcp) + +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) + +ADD_LIBRARY(TCP_OBJ OBJECT TCP.cc TCPServer.cc) +ADD_LIBRARY(${AITT_TCP} SHARED ../main.cc Module.cc $<TARGET_OBJECTS:TCP_OBJ>) +TARGET_LINK_LIBRARIES(${AITT_TCP} ${AITT_TCP_NEEDS_LIBRARIES} Threads::Threads ${AITT_COMMON}) + +INSTALL(TARGETS ${AITT_TCP} DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +IF(BUILD_TESTING) + ADD_SUBDIRECTORY(samples) + ADD_SUBDIRECTORY(tests) +ENDIF(BUILD_TESTING) diff --git a/modules/tcp/Module.cc b/modules/tcp/Module.cc new file mode 100644 index 0000000..bc50d7d --- /dev/null +++ b/modules/tcp/Module.cc @@ -0,0 +1,513 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Module.h" + +#include <MQ.h> +#include <flatbuffers/flexbuffers.h> +#include <unistd.h> + +#include "aitt_internal.h" + +/* + * P2P Data Packet Definition + * TopicLength: 4 bytes + * TopicString: $TopicLength + */ + +Module::Module(const std::string &ip, AittDiscovery &discovery) : AittTransport(discovery), ip(ip) +{ + aittThread = std::thread(&Module::ThreadMain, this); + + discovery_cb = discovery.AddDiscoveryCB(AITT_TYPE_TCP, + std::bind(&Module::DiscoveryMessageCallback, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); + DBG("Discovery Callback : %p, %d", this, discovery_cb); +} + +Module::~Module(void) +{ + discovery.RemoveDiscoveryCB(discovery_cb); + + while (main_loop.Quit() == false) { + // wait when called before the thread has completely created. + usleep(1000); + } + + if (aittThread.joinable()) + aittThread.join(); +} + +void Module::ThreadMain(void) +{ + pthread_setname_np(pthread_self(), "TCPWorkerLoop"); + main_loop.Run(); +} + +void Module::Publish(const std::string &topic, const void *data, const size_t datalen, + const std::string &correlation, AittQoS qos, bool retain) +{ + // NOTE: + // Iterate discovered service table + // PublishMap + // map { + // "/customTopic/faceRecog": map { + // "$clientId": map { + // 11234: $handle, + // + // ... + // + // 21234: nullptr, + // }, + // }, + // } + std::lock_guard<std::mutex> auto_lock_publish(publishTableLock); + for (PublishMap::iterator it = publishTable.begin(); it != publishTable.end(); ++it) { + // NOTE: + // Find entries that have matched with the given topic + if (!aitt::MQ::CompareTopic(it->first, topic)) + continue; + + // NOTE: + // Iterate all hosts + for (HostMap::iterator hostIt = it->second.begin(); hostIt != it->second.end(); ++hostIt) { + // Iterate all ports, + // the current implementation only be able to have the ZERO or a SINGLE entry + // hostIt->first // clientId + for (PortMap::iterator portIt = hostIt->second.begin(); portIt != hostIt->second.end(); + ++portIt) { + // portIt->first // port + // portIt->second // handle + if (!portIt->second) { + std::string host; + { + ClientMap::iterator clientIt; + std::lock_guard<std::mutex> auto_lock_client(clientTableLock); + + clientIt = clientTable.find(hostIt->first); + if (clientIt != clientTable.end()) + host = clientIt->second; + + // NOTE: + // otherwise, it is a critical error + // The broken clientTable or subscribeTable + } + + std::unique_ptr<TCP> client(std::make_unique<TCP>(host, portIt->first)); + + // TODO: + // If the client gets disconnected, + // This channel entry must be cleared + // In order to do that, + // There should be an observer to monitor + // each connections and manipulate + // the discovered service table + portIt->second = std::move(client); + } + + if (!portIt->second) { + ERR("Failed to create a new client instance"); + continue; + } + + SendTopic(topic, portIt); + SendPayload(datalen, portIt, data); + } + } // connectionEntries + } // publishTable +} + +void Module::SendTopic(const std::string &topic, Module::PortMap::iterator &portIt) +{ + uint32_t topicLen = topic.length(); + size_t szData = sizeof(topicLen); + portIt->second->Send(static_cast<void *>(&topicLen), szData); + szData = topicLen; + portIt->second->Send(static_cast<const void *>(topic.c_str()), szData); +} + +void Module::SendPayload(const size_t &datalen, Module::PortMap::iterator &portIt, const void *data) +{ + uint32_t sendsize = datalen; + size_t szsize = sizeof(sendsize); + + try { + if (0 == datalen) { + // distinguish between connection problems and zero-size messages + INFO("Send zero-size Message"); + sendsize = UINT32_MAX; + } + portIt->second->Send(static_cast<void *>(&sendsize), szsize); + + int msgSize = datalen; + while (0 < msgSize) { + size_t sentSize = msgSize; + char *dataIdx = (char *)data + (sendsize - msgSize); + portIt->second->Send(dataIdx, sentSize); + if (sentSize > 0) { + msgSize -= sentSize; + } + } + } catch (std::exception &e) { + ERR("An exception(%s) occurs during Send().", e.what()); + } +} + +void Module::Publish(const std::string &topic, const void *data, const size_t datalen, AittQoS qos, + bool retain) +{ + Publish(topic, data, datalen, std::string(), qos, retain); +} + +void *Module::Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb, + void *cbdata, AittQoS qos) +{ + std::unique_ptr<TCP::Server> tcpServer; + + unsigned short port = 0; + tcpServer = std::make_unique<TCP::Server>("0.0.0.0", port); + TCPServerData *listen_info = new TCPServerData; + listen_info->impl = this; + listen_info->cb = cb; + listen_info->cbdata = cbdata; + listen_info->topic = topic; + auto handle = tcpServer->GetHandle(); + + main_loop.AddWatch(handle, AcceptConnection, listen_info); + + // 서비스 테이블에 토픽을 키워드로 프로토콜을 등록한다. + { + std::lock_guard<std::mutex> autoLock(subscribeTableLock); + subscribeTable.insert(SubscribeMap::value_type(topic, std::move(tcpServer))); + UpdateDiscoveryMsg(); + } + + return reinterpret_cast<void *>(handle); +} + +void *Module::Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb, + const void *data, const size_t datalen, void *cbdata, AittQoS qos) +{ + return nullptr; +} + +void *Module::Unsubscribe(void *handlePtr) +{ + int handle = static_cast<int>(reinterpret_cast<intptr_t>(handlePtr)); + TCPServerData *listen_info = dynamic_cast<TCPServerData *>(main_loop.RemoveWatch(handle)); + if (!listen_info) + return nullptr; + + { + std::lock_guard<std::mutex> autoLock(subscribeTableLock); + auto it = subscribeTable.find(listen_info->topic); + if (it == subscribeTable.end()) + throw std::runtime_error("Service is not registered: " + listen_info->topic); + + subscribeTable.erase(it); + + UpdateDiscoveryMsg(); + } + + void *cbdata = listen_info->cbdata; + listen_info->client_lock.lock(); + for (auto fd : listen_info->client_list) { + TCPData *connect_info = dynamic_cast<TCPData *>(main_loop.RemoveWatch(fd)); + delete connect_info; + } + listen_info->client_list.clear(); + listen_info->client_lock.unlock(); + delete listen_info; + + return cbdata; +} + +void Module::DiscoveryMessageCallback(const std::string &clientId, const std::string &status, + const void *msg, const int szmsg) +{ + // NOTE: + // Iterate discovered service table + // PublishMap + // map { + // "/customTopic/faceRecog": map { + // "clientId.uniq.abcd.123": map { + // 11234: pair { + // "protocol": 1, + // "handle": nullptr, + // }, + // + // ... + // + // 21234: pair { + // "protocol": 2, + // "handle": nullptr, + // } + // }, + // }, + // } + + if (!status.compare(AittDiscovery::WILL_LEAVE_NETWORK)) { + { + std::lock_guard<std::mutex> autoLock(clientTableLock); + // Delete from the { clientId : Host } mapping table + clientTable.erase(clientId); + } + + { + // NOTE: + // Iterate all topics in the publishTable holds discovered client information + std::lock_guard<std::mutex> autoLock(publishTableLock); + for (auto it = publishTable.begin(); it != publishTable.end(); ++it) + it->second.erase(clientId); + } + return; + } + + // serviceMessage (flexbuffers) + // map { + // "host": "192.168.1.11", + // "$topic": port, + // } + auto map = flexbuffers::GetRoot(static_cast<const uint8_t *>(msg), szmsg).AsMap(); + std::string host = map["host"].AsString().c_str(); + + // NOTE: + // Update the clientTable + { + std::lock_guard<std::mutex> autoLock(clientTableLock); + auto clientIt = clientTable.find(clientId); + if (clientIt == clientTable.end()) + clientTable.insert(ClientMap::value_type(clientId, host)); + else if (clientIt->second.compare(host)) + clientIt->second = host; + } + + auto topics = map.Keys(); + for (size_t idx = 0; idx < topics.size(); ++idx) { + std::string topic = topics[idx].AsString().c_str(); + + if (!topic.compare("host")) + continue; + + auto port = map[topic].AsUInt16(); + + { + std::lock_guard<std::mutex> autoLock(publishTableLock); + UpdatePublishTable(topic, clientId, port); + } + } +} + +void Module::UpdateDiscoveryMsg() +{ + flexbuffers::Builder fbb; + // flexbuffers + // { + // "host": "127.0.0.1", + // "/customTopic/aitt/faceRecog": $port, + // "/customTopic/aitt/ASR": 102020, + // + // ... + // + // "/customTopic/aitt/+": 20123, + // } + fbb.Map([this, &fbb]() { + fbb.String("host", ip); + + // SubscribeTable + // map { + // "/customTopic/mytopic": $serverHandle, + // ... + // } + for (auto it = subscribeTable.begin(); it != subscribeTable.end(); ++it) { + if (it->second) + fbb.UInt(it->first.c_str(), it->second->GetPort()); + else + fbb.UInt(it->first.c_str(), 0); // this is an error case + } + }); + fbb.Finish(); + + auto buf = fbb.GetBuffer(); + discovery.UpdateDiscoveryMsg(AITT_TYPE_TCP, buf.data(), buf.size()); +} + +void Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle, + MainLoopHandler::MainLoopData *user_data) +{ + TCPData *connect_info = dynamic_cast<TCPData *>(user_data); + RET_IF(connect_info == nullptr); + TCPServerData *parent_info = connect_info->parent; + RET_IF(parent_info == nullptr); + Module *impl = parent_info->impl; + RET_IF(impl == nullptr); + + if (result == MainLoopHandler::HANGUP) { + ERR("Disconnected"); + return impl->HandleClientDisconnect(handle); + } + + uint32_t szmsg = 0; + size_t szdata = sizeof(szmsg); + char *msg = nullptr; + std::string topic; + + try { + topic = impl->GetTopicName(connect_info); + if (topic.empty()) { + ERR("Unknown Topic"); + return impl->HandleClientDisconnect(handle); + } + + connect_info->client->Recv(static_cast<void *>(&szmsg), szdata); + if (szmsg == 0) { + ERR("Disconnected"); + return impl->HandleClientDisconnect(handle); + } + + if (UINT32_MAX == szmsg) { + // distinguish between connection problems and zero-size messages + INFO("Got zero-size Message"); + szmsg = 0; + } + + msg = static_cast<char *>(malloc(szmsg)); + int msgSize = szmsg; + while (0 < msgSize) { + size_t receivedSize = msgSize; + connect_info->client->Recv(static_cast<void *>(msg + (szmsg - msgSize)), receivedSize); + if (receivedSize > 0) { + msgSize -= receivedSize; + } + } + } catch (std::exception &e) { + ERR("An exception(%s) occurs during Recv()", e.what()); + } + + std::string correlation; + // TODO: + // Correlation data (string) should be filled + + parent_info->cb(topic, msg, szmsg, parent_info->cbdata, correlation); + free(msg); +} + +void Module::HandleClientDisconnect(int handle) +{ + TCPData *connect_info = dynamic_cast<TCPData *>(main_loop.RemoveWatch(handle)); + if (connect_info == nullptr) { + ERR("No watch data"); + return; + } + connect_info->parent->client_lock.lock(); + auto it = std::find(connect_info->parent->client_list.begin(), + connect_info->parent->client_list.end(), handle); + connect_info->parent->client_list.erase(it); + connect_info->parent->client_lock.unlock(); + + delete connect_info; +} + +std::string Module::GetTopicName(Module::TCPData *connect_info) +{ + uint32_t topic_len = 0; + size_t data_size = sizeof(topic_len); + connect_info->client->Recv(static_cast<void *>(&topic_len), data_size); + + if (AITT_TOPIC_NAME_MAX < topic_len) { + ERR("Invalid topic name length(%d)", topic_len); + return std::string(); + } + + char data[topic_len]; + data_size = topic_len; + connect_info->client->Recv(data, data_size); + if (data_size != topic_len) + ERR("Recv() Fail"); + + return std::string(data, data_size); +} + +void Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle, + MainLoopHandler::MainLoopData *user_data) +{ + // TODO: + // Update the discovery map + std::unique_ptr<TCP> client; + + TCPServerData *listen_info = dynamic_cast<TCPServerData *>(user_data); + Module *impl = listen_info->impl; + { + std::lock_guard<std::mutex> autoLock(impl->subscribeTableLock); + + auto clientIt = impl->subscribeTable.find(listen_info->topic); + if (clientIt == impl->subscribeTable.end()) + return; + + client = clientIt->second->AcceptPeer(); + } + + if (client == nullptr) { + ERR("Unable to accept a peer"); // NOTE: FATAL ERROR + return; + } + + int cHandle = client->GetHandle(); + listen_info->client_list.push_back(cHandle); + + TCPData *ecd = new TCPData; + ecd->parent = listen_info; + ecd->client = std::move(client); + + impl->main_loop.AddWatch(cHandle, ReceiveData, ecd); +} + +void Module::UpdatePublishTable(const std::string &topic, const std::string &clientId, + unsigned short port) +{ + auto topicIt = publishTable.find(topic); + if (topicIt == publishTable.end()) { + PortMap portMap; + portMap.insert(PortMap::value_type(port, nullptr)); + HostMap hostMap; + hostMap.insert(HostMap::value_type(clientId, std::move(portMap))); + publishTable.insert(PublishMap::value_type(topic, std::move(hostMap))); + return; + } + + auto hostIt = topicIt->second.find(clientId); + if (hostIt == topicIt->second.end()) { + PortMap portMap; + portMap.insert(PortMap::value_type(port, nullptr)); + topicIt->second.insert(HostMap::value_type(clientId, std::move(portMap))); + return; + } + + // NOTE: + // The current implementation only has a single port entry + // therefore, if the hostIt is not empty, there is the previous connection + if (!hostIt->second.empty()) { + auto portIt = hostIt->second.begin(); + + if (portIt->first == port) + return; // nothing changed. keep the current handle + + // otherwise, delete the connection handle + // to make a new connection with the new port + hostIt->second.clear(); + } + + hostIt->second.insert(PortMap::value_type(port, nullptr)); +} diff --git a/modules/tcp/Module.h b/modules/tcp/Module.h new file mode 100644 index 0000000..4011980 --- /dev/null +++ b/modules/tcp/Module.h @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include <AittTransport.h> +#include <MainLoopHandler.h> + +#include <map> +#include <memory> +#include <mutex> +#include <string> +#include <thread> + +#include "TCPServer.h" + +using AittTransport = aitt::AittTransport; +using MainLoopHandler = aitt::MainLoopHandler; +using AittDiscovery = aitt::AittDiscovery; + +class Module : public AittTransport { + public: + explicit Module(const std::string &ip, AittDiscovery &discovery); + virtual ~Module(void); + + void Publish(const std::string &topic, const void *data, const size_t datalen, + const std::string &correlation, AittQoS qos = AITT_QOS_AT_MOST_ONCE, + bool retain = false) override; + + void Publish(const std::string &topic, const void *data, const size_t datalen, + AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override; + + void *Subscribe(const std::string &topic, const SubscribeCallback &cb, void *cbdata = nullptr, + AittQoS qos = AITT_QOS_AT_MOST_ONCE) override; + + void *Subscribe(const std::string &topic, const SubscribeCallback &cb, const void *data, + const size_t datalen, void *cbdata = nullptr, + AittQoS qos = AITT_QOS_AT_MOST_ONCE) override; + void *Unsubscribe(void *handle) override; + + private: + struct TCPServerData : public MainLoopHandler::MainLoopData { + Module *impl; + SubscribeCallback cb; + void *cbdata; + std::string topic; + std::vector<int> client_list; + std::mutex client_lock; + }; + + struct TCPData : public MainLoopHandler::MainLoopData { + TCPServerData *parent; + std::unique_ptr<TCP> client; + }; + + // SubscribeTable + // map { + // "/customTopic/mytopic": $serverHandle, + // ... + // } + using SubscribeMap = std::map<std::string, std::unique_ptr<TCP::Server>>; + + // ClientTable + // map { + // $clientId: $host, + // "client.uniqId.123": "192.168.1.11" + // ... + // } + using ClientMap = std::map<std::string /* id */, std::string /* host */>; + + // NOTE: + // There could be multiple clientIds for the single host + // If several applications are run on the same device, each applicaion will get unique client + // Ids therefore we have to keep in mind that the clientId is not 1:1 matched for the IPAddress. + + // PublishTable + // map { + // "/customTopic/faceRecog": map { + // $clientId: map { + // 11234: $clientHandle, + // + // ... + // + // 21234: $clientHandle, + // }, + // }, + // } + // + // NOTE: + // TCP handle should be the unique_ptr, so if we delete the entry from the map, + // the handle must be released automatically + // in order to make the handle "unique_ptr", it should be a class object not the "void *" + using PortMap = std::map<unsigned short /* port */, std::unique_ptr<TCP>>; + using HostMap = std::map<std::string /* clientId */, PortMap>; + using PublishMap = std::map<std::string /* topic */, HostMap>; + + static void AcceptConnection(MainLoopHandler::MainLoopResult result, int handle, + MainLoopHandler::MainLoopData *watchData); + void DiscoveryMessageCallback(const std::string &clientId, const std::string &status, + const void *msg, const int szmsg); + void UpdateDiscoveryMsg(); + static void ReceiveData(MainLoopHandler::MainLoopResult result, int handle, + MainLoopHandler::MainLoopData *watchData); + void HandleClientDisconnect(int handle); + std::string GetTopicName(TCPData *connect_info); + void ThreadMain(void); + void SendPayload(const size_t &datalen, Module::PortMap::iterator &portIt, const void *data); + void SendTopic(const std::string &topic, Module::PortMap::iterator &portIt); + + void UpdatePublishTable(const std::string &topic, const std::string &host, unsigned short port); + + MainLoopHandler main_loop; + std::thread aittThread; + std::string ip; + int discovery_cb; + + PublishMap publishTable; + std::mutex publishTableLock; + SubscribeMap subscribeTable; + std::mutex subscribeTableLock; + ClientMap clientTable; + std::mutex clientTableLock; +}; diff --git a/modules/tcp/TCP.cc b/modules/tcp/TCP.cc new file mode 100644 index 0000000..3b6751e --- /dev/null +++ b/modules/tcp/TCP.cc @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "TCP.h" + +#include <arpa/inet.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <cstdlib> +#include <cstring> +#include <stdexcept> + +#include "aitt_internal.h" + +TCP::TCP(const std::string &host, unsigned short port) : handle(-1), addrlen(0), addr(nullptr) +{ + int ret = 0; + + do { + if (port == 0) { + ret = EINVAL; + break; + } + + handle = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0); + if (handle < 0) { + ERR("socket() Fail()"); + break; + } + + addrlen = sizeof(sockaddr_in); + addr = static_cast<sockaddr *>(calloc(1, addrlen)); + if (!addr) { + ERR("calloc() Fail()"); + break; + } + + sockaddr_in *inet_addr = reinterpret_cast<sockaddr_in *>(addr); + if (!inet_pton(AF_INET, host.c_str(), &inet_addr->sin_addr)) { + ret = EINVAL; + break; + } + + inet_addr->sin_port = htons(port); + inet_addr->sin_family = AF_INET; + + ret = connect(handle, addr, addrlen); + if (ret < 0) { + ERR("connect() Fail(%s, %d)", host.c_str(), port); + break; + } + + SetupOptions(); + return; + } while (0); + + if (ret <= 0) + ret = errno; + + free(addr); + if (handle >= 0 && close(handle) < 0) + ERR_CODE(errno, "close"); + throw std::runtime_error(strerror(ret)); +} + +TCP::TCP(int handle, sockaddr *addr, socklen_t szAddr) : handle(handle), addrlen(szAddr), addr(addr) +{ + SetupOptions(); +} + +TCP::~TCP(void) +{ + if (handle < 0) + return; + + free(addr); + if (close(handle) < 0) + ERR_CODE(errno, "close"); +} + +void TCP::SetupOptions(void) +{ + int on = 1; + + int ret = setsockopt(handle, IPPROTO_IP, TCP_NODELAY, &on, sizeof(on)); + if (ret < 0) { + ERR_CODE(errno, "delay option setting failed"); + } +} + +void TCP::Send(const void *data, size_t &szData) +{ + int ret = send(handle, data, szData, 0); + if (ret < 0) { + ERR("Fail to send data, handle = %d, size = %zu", handle, szData); + throw std::runtime_error(strerror(errno)); + } + + szData = ret; +} + +void TCP::Recv(void *data, size_t &szData) +{ + int ret = recv(handle, data, szData, 0); + if (ret < 0) { + ERR("Fail to recv data, handle = %d, size = %zu", handle, szData); + throw std::runtime_error(strerror(errno)); + } + + szData = ret; +} + +int TCP::GetHandle(void) +{ + return handle; +} + +void TCP::GetPeerInfo(std::string &host, unsigned short &port) +{ + char address[INET_ADDRSTRLEN] = { + 0, + }; + + if (!inet_ntop(AF_INET, &reinterpret_cast<sockaddr_in *>(this->addr)->sin_addr, address, + sizeof(address))) + throw std::runtime_error(strerror(errno)); + + port = ntohs(reinterpret_cast<sockaddr_in *>(this->addr)->sin_port); + host = address; +} + +unsigned short TCP::GetPort(void) +{ + sockaddr_in addr; + socklen_t addrlen = sizeof(addr); + + if (getsockname(handle, reinterpret_cast<sockaddr *>(&addr), &addrlen) < 0) + throw std::runtime_error(strerror(errno)); + + return ntohs(addr.sin_port); +} diff --git a/modules/tcp/TCP.h b/modules/tcp/TCP.h new file mode 100644 index 0000000..535819c --- /dev/null +++ b/modules/tcp/TCP.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include <sys/socket.h> +#include <sys/types.h> /* See NOTES */ + +#include <string> + +class TCP { + public: + class Server; + + TCP(const std::string &host, unsigned short port); + virtual ~TCP(void); + + void Send(const void *data, size_t &szData); + void Recv(void *data, size_t &szData); + int GetHandle(void); + unsigned short GetPort(void); + void GetPeerInfo(std::string &host, unsigned short &port); + + private: + TCP(int handle, sockaddr *addr, socklen_t addrlen); + void SetupOptions(void); + + int handle; + socklen_t addrlen; + sockaddr *addr; +}; diff --git a/modules/tcp/TCPServer.cc b/modules/tcp/TCPServer.cc new file mode 100644 index 0000000..55f8511 --- /dev/null +++ b/modules/tcp/TCPServer.cc @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "TCPServer.h" + +#include <arpa/inet.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <cstdlib> +#include <stdexcept> + +#include "aitt_internal.h" + +#define BACKLOG 10 // Accept only 10 simultaneously connections by default + +TCP::Server::Server(const std::string &host, unsigned short &port) + : handle(-1), addr(nullptr), addrlen(0) +{ + int ret = 0; + + do { + handle = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0); + if (handle < 0) + break; + + addrlen = sizeof(sockaddr_in); + addr = static_cast<sockaddr *>(calloc(1, sizeof(sockaddr_in))); + if (!addr) + break; + + sockaddr_in *inet_addr = reinterpret_cast<sockaddr_in *>(addr); + if (!inet_pton(AF_INET, host.c_str(), &inet_addr->sin_addr)) { + ret = EINVAL; + break; + } + + inet_addr->sin_port = htons(port); + inet_addr->sin_family = AF_INET; + + int on = 1; + ret = setsockopt(handle, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + if (ret < 0) + break; + + ret = bind(handle, addr, addrlen); + if (ret < 0) + break; + + if (!port) { + if (getsockname(handle, addr, &addrlen) < 0) + break; + port = ntohs(inet_addr->sin_port); + } + + ret = listen(handle, BACKLOG); + if (ret < 0) + break; + + return; + } while (0); + + if (ret <= 0) + ret = errno; + + free(addr); + + if (handle >= 0 && close(handle) < 0) + ERR_CODE(errno, "close"); + + throw std::runtime_error(strerror(ret)); +} + +TCP::Server::~Server(void) +{ + if (handle < 0) + return; + + free(addr); + if (close(handle) < 0) + ERR_CODE(errno, "close"); +} + +std::unique_ptr<TCP> TCP::Server::AcceptPeer(void) +{ + sockaddr *peerAddr; + socklen_t szAddr = sizeof(sockaddr_in); + int peerHandle; + + peerAddr = static_cast<sockaddr *>(calloc(1, szAddr)); + if (!peerAddr) + throw std::runtime_error(strerror(errno)); + + peerHandle = accept(handle, peerAddr, &szAddr); + if (peerHandle < 0) { + free(peerAddr); + throw std::runtime_error(strerror(errno)); + } + + return std::unique_ptr<TCP>(new TCP(peerHandle, peerAddr, szAddr)); +} + +int TCP::Server::GetHandle(void) +{ + return handle; +} + +unsigned short TCP::Server::GetPort(void) +{ + sockaddr_in addr; + socklen_t addrlen = sizeof(addr); + + if (getsockname(handle, reinterpret_cast<sockaddr *>(&addr), &addrlen) < 0) + throw std::runtime_error(strerror(errno)); + + return ntohs(addr.sin_port); +} diff --git a/modules/tcp/TCPServer.h b/modules/tcp/TCPServer.h new file mode 100644 index 0000000..3c82bc6 --- /dev/null +++ b/modules/tcp/TCPServer.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include <memory> +#include <string> + +#include "TCP.h" + +class TCP::Server { + public: + Server(const std::string &host, unsigned short &port); + virtual ~Server(void); + + std::unique_ptr<TCP> AcceptPeer(void); + + int GetHandle(void); + unsigned short GetPort(void); + + private: + int handle; + sockaddr *addr; + socklen_t addrlen; +}; diff --git a/modules/tcp/samples/CMakeLists.txt b/modules/tcp/samples/CMakeLists.txt new file mode 100644 index 0000000..8fd1b4b --- /dev/null +++ b/modules/tcp/samples/CMakeLists.txt @@ -0,0 +1,3 @@ +ADD_EXECUTABLE("aitt_tcp_test" tcp_test.cc $<TARGET_OBJECTS:TCP_OBJ>) +TARGET_LINK_LIBRARIES("aitt_tcp_test" ${PROJECT_NAME} Threads::Threads ${AITT_NEEDS_LIBRARIES}) +INSTALL(TARGETS "aitt_tcp_test" DESTINATION ${AITT_TEST_BINDIR}) diff --git a/modules/tcp/samples/tcp_test.cc b/modules/tcp/samples/tcp_test.cc new file mode 100644 index 0000000..d319e27 --- /dev/null +++ b/modules/tcp/samples/tcp_test.cc @@ -0,0 +1,235 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <TCP.h> +#include <TCPServer.h> +#include <getopt.h> +#include <glib.h> + +#include <functional> +#include <iostream> +#include <memory> +#include <string> + +//#define _LOG_WITH_TIMESTAMP +#include "aitt_internal.h" +#ifdef _LOG_WITH_TIMESTAMP +__thread __aitt__tls__ __aitt; +#endif + +#define HELLO_STRING "hello" +#define BYE_STRING "bye" +#define SEND_INTERVAL 1000 + +class AittTcpSample { + public: + AittTcpSample(const std::string &host, unsigned short &port) + : server(std::make_unique<TCP::Server>(host, port)) + { + } + virtual ~AittTcpSample(void) {} + + std::unique_ptr<TCP::Server> server; +}; + +int main(int argc, char *argv[]) +{ + const option opts[] = { + { + .name = "server", + .has_arg = 0, + .flag = nullptr, + .val = 's', + }, + { + .name = "host", + .has_arg = 1, + .flag = nullptr, + .val = 'h', + }, + { + .name = "port", + .has_arg = 1, + .flag = nullptr, + .val = 'p', + }, + }; + int c; + int idx; + bool isServer = false; + std::string host = "127.0.0.1"; + unsigned short port = 0; + + while ((c = getopt_long(argc, argv, "sh:up:", opts, &idx)) != -1) { + switch (c) { + case 's': + isServer = true; + break; + case 'h': + host = optarg; + break; + case 'p': + port = std::stoi(optarg); + break; + default: + break; + } + } + + INFO("Host[%s] port[%u]", host.c_str(), port); + + struct EventData { + GSource source_; + GPollFD fd; + AittTcpSample *sample; + }; + + guint timeoutId = 0; + GSource *src = nullptr; + EventData *ed = nullptr; + + GMainLoop *mainLoop = g_main_loop_new(nullptr, FALSE); + if (!mainLoop) { + ERR("Failed to create a main loop"); + return 1; + } + + // Handling the server/client events + if (isServer) { + GSourceFuncs srcs = { + [](GSource *src, gint *timeout) -> gboolean { + *timeout = 1; + return FALSE; + }, + [](GSource *src) -> gboolean { + EventData *ed = reinterpret_cast<EventData *>(src); + RETV_IF(ed == nullptr, FALSE); + + if ((ed->fd.revents & G_IO_IN) == G_IO_IN) + return TRUE; + if ((ed->fd.revents & G_IO_ERR) == G_IO_ERR) + return TRUE; + + return FALSE; + }, + [](GSource *src, GSourceFunc callback, gpointer user_data) -> gboolean { + EventData *ed = reinterpret_cast<EventData *>(src); + RETV_IF(ed == nullptr, FALSE); + + if ((ed->fd.revents & G_IO_ERR) == G_IO_ERR) { + ERR("Error!"); + return FALSE; + } + + std::unique_ptr<TCP> peer = ed->sample->server->AcceptPeer(); + + INFO("Assigned port: %u, %u", ed->sample->server->GetPort(), peer->GetPort()); + std::string peerHost; + unsigned short peerPort = 0; + peer->GetPeerInfo(peerHost, peerPort); + INFO("Peer Info: %s %u", peerHost.c_str(), peerPort); + + char buffer[10]; + void *ptr = static_cast<void *>(buffer); + size_t szData = sizeof(HELLO_STRING); + peer->Recv(ptr, szData); + INFO("Gots[%s]", buffer); + + szData = sizeof(BYE_STRING); + peer->Send(BYE_STRING, szData); + INFO("Reply to client[%s]", BYE_STRING); + + return TRUE; + }, + nullptr, + }; + + src = g_source_new(&srcs, sizeof(EventData)); + if (!src) { + g_main_loop_unref(mainLoop); + ERR("g_source_new failed"); + return 1; + } + + ed = reinterpret_cast<EventData *>(src); + + try { + ed->sample = new AittTcpSample(host, port); + } catch (std::exception &e) { + ERR("new: %s", e.what()); + g_source_unref(src); + g_main_loop_unref(mainLoop); + return 1; + } + + INFO("host: %s, port: %u", host.c_str(), port); + + ed->fd.fd = ed->sample->server->GetHandle(); + ed->fd.events = G_IO_IN | G_IO_ERR; + g_source_add_poll(src, &ed->fd); + guint id = g_source_attach(src, g_main_loop_get_context(mainLoop)); + g_source_unref(src); + if (id == 0) { + delete ed->sample; + g_source_destroy(src); + g_main_loop_unref(mainLoop); + return 1; + } + } else { + static struct Main { + const std::string &host; + unsigned short port; + } main_data = { + .host = host, + .port = port, + }; + // Now the server is ready. + // Let's create a new client and communicate with the server within every + // SEND_INTERTVAL + timeoutId = g_timeout_add( + SEND_INTERVAL, + [](gpointer data) -> gboolean { + Main *ctx = static_cast<Main *>(data); + std::unique_ptr<TCP> client(std::make_unique<TCP>(ctx->host, ctx->port)); + + INFO("Assigned client port: %u", client->GetPort()); + + INFO("Send[%s]", HELLO_STRING); + size_t szBuffer = sizeof(HELLO_STRING); + client->Send(HELLO_STRING, szBuffer); + + char buffer[10]; + void *ptr = static_cast<void *>(buffer); + szBuffer = sizeof(BYE_STRING); + client->Recv(ptr, szBuffer); + INFO("Replied with[%s]", buffer); + + // Send oneshot message, and disconnect from the server + return TRUE; + }, + &main_data); + } + + g_main_loop_run(mainLoop); + + if (src) { + delete ed->sample; + g_source_destroy(src); + } + if (timeoutId) + g_source_remove(timeoutId); + g_main_loop_unref(mainLoop); + return 0; +} diff --git a/modules/tcp/tests/CMakeLists.txt b/modules/tcp/tests/CMakeLists.txt new file mode 100644 index 0000000..bf1adf1 --- /dev/null +++ b/modules/tcp/tests/CMakeLists.txt @@ -0,0 +1,19 @@ +PKG_CHECK_MODULES(UT_NEEDS REQUIRED gmock_main) +INCLUDE_DIRECTORIES(${UT_NEEDS_INCLUDE_DIRS}) +LINK_DIRECTORIES(${UT_NEEDS_LIBRARY_DIRS}) + +SET(AITT_TCP_UT ${PROJECT_NAME}_tcp_ut) + +SET(AITT_TCP_UT_SRC TCP_test.cc TCPServer_test.cc) + +ADD_EXECUTABLE(${AITT_TCP_UT} ${AITT_TCP_UT_SRC} $<TARGET_OBJECTS:TCP_OBJ>) +TARGET_LINK_LIBRARIES(${AITT_TCP_UT} ${UT_NEEDS_LIBRARIES} Threads::Threads ${AITT_NEEDS_LIBRARIES}) +INSTALL(TARGETS ${AITT_TCP_UT} DESTINATION ${AITT_TEST_BINDIR}) + +ADD_TEST( + NAME + ${AITT_TCP_UT} + COMMAND + ${CMAKE_COMMAND} -E env + ${CMAKE_CURRENT_BINARY_DIR}/${AITT_TCP_UT} --gtest_filter=*_Anytime +) diff --git a/modules/tcp/tests/TCPServer_test.cc b/modules/tcp/tests/TCPServer_test.cc new file mode 100644 index 0000000..e8b48b1 --- /dev/null +++ b/modules/tcp/tests/TCPServer_test.cc @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "../TCPServer.h" + +#include <gtest/gtest.h> + +#include <condition_variable> +#include <cstring> +#include <memory> +#include <mutex> +#include <thread> + +#define TEST_SERVER_ADDRESS "127.0.0.1" +#define TEST_SERVER_INVALID_ADDRESS "287.0.0.1" +#define TEST_SERVER_PORT 8123 +#define TEST_SERVER_AVAILABLE_PORT 0 + +TEST(TCPServer, Positive_Create_Anytime) +{ + unsigned short port = TEST_SERVER_PORT; + std::unique_ptr<TCP::Server> tcp(std::make_unique<TCP::Server>(TEST_SERVER_ADDRESS, port)); + ASSERT_NE(tcp, nullptr); +} + +TEST(TCPServer, Negative_Create_Anytime) +{ + try { + unsigned short port = TEST_SERVER_PORT; + + std::unique_ptr<TCP::Server> tcp( + std::make_unique<TCP::Server>(TEST_SERVER_INVALID_ADDRESS, port)); + ASSERT_EQ(tcp, nullptr); + } catch (std::exception &e) { + ASSERT_STREQ(e.what(), strerror(EINVAL)); + } +} + +TEST(TCPServer, Positive_Create_AutoPort_Anytime) +{ + unsigned short port = TEST_SERVER_AVAILABLE_PORT; + std::unique_ptr<TCP::Server> tcp(std::make_unique<TCP::Server>(TEST_SERVER_ADDRESS, port)); + ASSERT_NE(tcp, nullptr); + ASSERT_NE(port, 0); +} + +TEST(TCPServer, Positive_GetPort_Anytime) +{ + unsigned short port = TEST_SERVER_PORT; + std::unique_ptr<TCP::Server> tcp(std::make_unique<TCP::Server>(TEST_SERVER_ADDRESS, port)); + ASSERT_NE(tcp, nullptr); + ASSERT_EQ(tcp->GetPort(), TEST_SERVER_PORT); +} + +TEST(TCPServer, Positive_GetHandle_Anytime) +{ + unsigned short port = TEST_SERVER_PORT; + std::unique_ptr<TCP::Server> tcp(std::make_unique<TCP::Server>(TEST_SERVER_ADDRESS, port)); + ASSERT_NE(tcp, nullptr); + ASSERT_GE(tcp->GetHandle(), 0); +} + +TEST(TCPServer, Positive_GetPort_AutoPort_Anytime) +{ + unsigned short port = TEST_SERVER_AVAILABLE_PORT; + std::unique_ptr<TCP::Server> tcp(std::make_unique<TCP::Server>(TEST_SERVER_ADDRESS, port)); + ASSERT_NE(tcp, nullptr); + ASSERT_EQ(tcp->GetPort(), port); +} + +TEST(TCPServer, Positive_AcceptPeer_Anytime) +{ + std::mutex m; + std::condition_variable ready_cv; + std::condition_variable connected_cv; + bool ready = false; + bool connected = false; + + unsigned short serverPort = TEST_SERVER_PORT; + std::thread serverThread( + [serverPort, &m, &ready, &connected, &ready_cv, &connected_cv](void) mutable -> void { + std::unique_ptr<TCP::Server> tcp( + std::make_unique<TCP::Server>(TEST_SERVER_ADDRESS, serverPort)); + { + std::lock_guard<std::mutex> lk(m); + ready = true; + } + ready_cv.notify_one(); + + std::unique_ptr<TCP> peer = tcp->AcceptPeer(); + { + std::lock_guard<std::mutex> lk(m); + connected = !!peer; + } + connected_cv.notify_one(); + }); + + { + std::unique_lock<std::mutex> lk(m); + ready_cv.wait(lk, [&ready] { return ready; }); + std::unique_ptr<TCP> tcp(std::make_unique<TCP>(TEST_SERVER_ADDRESS, serverPort)); + connected_cv.wait(lk, [&connected] { return connected; }); + } + + serverThread.join(); + + ASSERT_EQ(ready, true); + ASSERT_EQ(connected, true); +} diff --git a/modules/tcp/tests/TCP_test.cc b/modules/tcp/tests/TCP_test.cc new file mode 100644 index 0000000..604bd23 --- /dev/null +++ b/modules/tcp/tests/TCP_test.cc @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <gtest/gtest.h> + +#include <condition_variable> +#include <cstring> +#include <functional> +#include <memory> +#include <mutex> +#include <thread> + +#include "TCPServer.h" + +#define TEST_SERVER_ADDRESS "127.0.0.1" +#define TEST_SERVER_INVALID_ADDRESS "287.0.0.1" +#define TEST_SERVER_PORT 8123 +#define TEST_SERVER_AVAILABLE_PORT 0 +#define TEST_BUFFER_SIZE 256 +#define TEST_BUFFER_HELLO "Hello World" +#define TEST_BUFFER_BYE "Good Bye" + +class TCPTest : public testing::Test { + protected: + void SetUp() override + { + ready = false; + serverPort = TEST_SERVER_PORT; + customTest = [](void) {}; + + clientThread = std::thread([this](void) mutable -> void { + std::unique_lock<std::mutex> lk(m); + ready_cv.wait(lk, [this] { return ready; }); + client = std::make_unique<TCP>(TEST_SERVER_ADDRESS, serverPort); + + customTest(); + }); + } + + void RunServer(void) + { + tcp = std::make_unique<TCP::Server>(TEST_SERVER_ADDRESS, serverPort); + { + std::lock_guard<std::mutex> lk(m); + ready = true; + } + ready_cv.notify_one(); + + peer = tcp->AcceptPeer(); + } + + void TearDown() override { clientThread.join(); } + + protected: + std::mutex m; + std::condition_variable ready_cv; + bool ready; + unsigned short serverPort; + std::thread clientThread; + std::unique_ptr<TCP::Server> tcp; + std::unique_ptr<TCP> peer; + std::unique_ptr<TCP> client; + std::function<void(void)> customTest; +}; + +TEST(TCP, Negative_Create_InvalidPort_Anytime) +{ + try { + std::unique_ptr<TCP> tcp( + std::make_unique<TCP>(TEST_SERVER_ADDRESS, TEST_SERVER_AVAILABLE_PORT)); + ASSERT_EQ(tcp, nullptr); + } catch (std::exception &e) { + ASSERT_STREQ(e.what(), strerror(EINVAL)); + } +} + +TEST(TCP, Negative_Create_InvalidAddress_Anytime) +{ + try { + std::unique_ptr<TCP> tcp( + std::make_unique<TCP>(TEST_SERVER_INVALID_ADDRESS, TEST_SERVER_PORT)); + ASSERT_EQ(tcp, nullptr); + } catch (std::exception &e) { + ASSERT_STREQ(e.what(), strerror(EINVAL)); + } +} + +TEST_F(TCPTest, Positive_GetPeerInfo_Anytime) +{ + std::string peerHost; + unsigned short peerPort = 0; + + RunServer(); + + peer->GetPeerInfo(peerHost, peerPort); + ASSERT_STREQ(peerHost.c_str(), TEST_SERVER_ADDRESS); + ASSERT_GT(peerPort, 0); +} + +TEST_F(TCPTest, Positive_GetHandle_Anytime) +{ + RunServer(); + int handle = peer->GetHandle(); + ASSERT_GE(handle, 0); +} + +TEST_F(TCPTest, Positive_GetPort_Anytime) +{ + RunServer(); + unsigned short port = peer->GetPort(); + ASSERT_GT(port, 0); +} + +TEST_F(TCPTest, Positive_SendRecv_Anytime) +{ + char helloBuffer[TEST_BUFFER_SIZE]; + char byeBuffer[TEST_BUFFER_SIZE]; + + customTest = [this, &helloBuffer](void) mutable -> void { + size_t szData = sizeof(helloBuffer); + client->Recv(static_cast<void *>(helloBuffer), szData); + + szData = sizeof(TEST_BUFFER_BYE); + client->Send(TEST_BUFFER_BYE, szData); + }; + + RunServer(); + + size_t szMsg = sizeof(TEST_BUFFER_HELLO); + peer->Send(TEST_BUFFER_HELLO, szMsg); + + szMsg = sizeof(byeBuffer); + peer->Recv(static_cast<void *>(byeBuffer), szMsg); + + ASSERT_STREQ(helloBuffer, TEST_BUFFER_HELLO); + ASSERT_STREQ(byeBuffer, TEST_BUFFER_BYE); +} diff --git a/modules/webrtc/CMakeLists.txt b/modules/webrtc/CMakeLists.txt new file mode 100644 index 0000000..9452b2b --- /dev/null +++ b/modules/webrtc/CMakeLists.txt @@ -0,0 +1,24 @@ +SET(AITT_WEBRTC aitt-transport-webrtc) + +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) + +PKG_CHECK_MODULES(AITT_WEBRTC_NEEDS REQUIRED + capi-media-camera + capi-media-webrtc + json-glib-1.0 +) +INCLUDE_DIRECTORIES(${AITT_WEBRTC_NEEDS_INCLUDE_DIRS}) +LINK_DIRECTORIES(${AITT_WEBRTC_NEEDS_LIBRARY_DIRS}) + +FILE(GLOB AITT_WEBRTC_SRC *.cc) +list(REMOVE_ITEM AITT_WEBRTC_SRC ${CMAKE_CURRENT_SOURCE_DIR}/Module.cc) +ADD_LIBRARY(WEBRTC_OBJ OBJECT ${AITT_WEBRTC_SRC}) +ADD_LIBRARY(${AITT_WEBRTC} SHARED $<TARGET_OBJECTS:WEBRTC_OBJ> ../main.cc Module.cc) +TARGET_LINK_LIBRARIES(${AITT_WEBRTC} ${AITT_WEBRTC_NEEDS_LIBRARIES} ${AITT_COMMON}) +TARGET_COMPILE_OPTIONS(${AITT_WEBRTC} PUBLIC ${AITT_WEBRTC_NEEDS_CFLAGS_OTHER}) + +INSTALL(TARGETS ${AITT_WEBRTC} DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +IF(BUILD_TESTING) + ADD_SUBDIRECTORY(tests) +ENDIF(BUILD_TESTING) diff --git a/modules/webrtc/CameraHandler.cc b/modules/webrtc/CameraHandler.cc new file mode 100644 index 0000000..c3fc8ec --- /dev/null +++ b/modules/webrtc/CameraHandler.cc @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "CameraHandler.h" + +#include "aitt_internal.h" + +#define RETURN_DEFINED_NAME_AS_STRING(defined_constant) \ + case defined_constant: \ + return #defined_constant; + +CameraHandler::~CameraHandler(void) +{ + if (handle_) { + camera_state_e state = CAMERA_STATE_NONE; + + int ret = camera_get_state(handle_, &state); + if (ret != CAMERA_ERROR_NONE) { + ERR("camera_get_state() Fail(%s)", ErrorToString(ret)); + } + + if (state == CAMERA_STATE_PREVIEW) { + INFO("CameraHandler preview is not stopped (stop)"); + ret = camera_stop_preview(handle_); + if (ret != CAMERA_ERROR_NONE) { + ERR("camera_stop_preview() Fail(%s)", ErrorToString(ret)); + } + } + } + + if (handle_) + camera_destroy(handle_); +} + +int CameraHandler::Init(const MediaPacketPreviewCallback &preview_cb, void *user_data) +{ + int ret = camera_create(CAMERA_DEVICE_CAMERA0, &handle_); + if (ret != CAMERA_ERROR_NONE) { + ERR("camera_create() Fail(%s)", ErrorToString(ret)); + return -1; + } + SettingCamera(preview_cb, user_data); + + return 0; +} + +void CameraHandler::SettingCamera(const MediaPacketPreviewCallback &preview_cb, void *user_data) +{ + int ret = camera_set_media_packet_preview_cb(handle_, CameraPreviewCB, this); + if (ret != CAMERA_ERROR_NONE) { + ERR("camera_set_media_packet_preview_cb() Fail(%s)", ErrorToString(ret)); + return; + } + media_packet_preview_cb_ = preview_cb; + user_data_ = user_data; +} + +void CameraHandler::Deinit(void) +{ + if (!handle_) { + ERR("Handler is nullptr"); + return; + } + + is_started_ = false; + media_packet_preview_cb_ = nullptr; + user_data_ = nullptr; +} + +int CameraHandler::StartPreview(void) +{ + camera_state_e state; + int ret = camera_get_state(handle_, &state); + if (ret != CAMERA_ERROR_NONE) { + ERR("camera_get_state() Fail(%s)", ErrorToString(ret)); + return -1; + } + + if (state == CAMERA_STATE_PREVIEW) { + INFO("Preview is already started"); + is_started_ = true; + return 0; + } + + ret = camera_start_preview(handle_); + if (ret != CAMERA_ERROR_NONE) { + ERR("camera_start_preview() Fail(%s)", ErrorToString(ret)); + return -1; + } + + is_started_ = true; + + return 0; +} + +int CameraHandler::StopPreview(void) +{ + RETV_IF(handle_ == nullptr, -1); + is_started_ = false; + + return 0; +} + +void CameraHandler::CameraPreviewCB(media_packet_h media_packet, void *user_data) +{ + auto camera_handler = static_cast<CameraHandler *>(user_data); + if (!camera_handler) { + ERR("Invalid user_data"); + return; + } + + if (!camera_handler->is_started_) { + ERR("Preveiw is not started yet"); + return; + } + + if (!camera_handler->media_packet_preview_cb_) { + ERR("Preveiw cb is not set"); + return; + } + + camera_handler->media_packet_preview_cb_(media_packet, camera_handler->user_data_); +} + +const char *CameraHandler::ErrorToString(const int error) +{ + switch (error) { + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_NONE) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_INVALID_PARAMETER) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_INVALID_STATE) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_OUT_OF_MEMORY) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_DEVICE) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_INVALID_OPERATION) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_SECURITY_RESTRICTED) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_DEVICE_BUSY) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_DEVICE_NOT_FOUND) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_ESD) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_PERMISSION_DENIED) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_NOT_SUPPORTED) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_RESOURCE_CONFLICT) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_ERROR_SERVICE_DISCONNECTED) + } + + return "Unknown error"; +} + +const char *CameraHandler::StateToString(const camera_state_e state) +{ + switch (state) { + RETURN_DEFINED_NAME_AS_STRING(CAMERA_STATE_NONE) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_STATE_CREATED) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_STATE_PREVIEW) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_STATE_CAPTURING) + RETURN_DEFINED_NAME_AS_STRING(CAMERA_STATE_CAPTURED) + } + + return "Unknown state"; +} diff --git a/modules/webrtc/CameraHandler.h b/modules/webrtc/CameraHandler.h new file mode 100644 index 0000000..5c44828 --- /dev/null +++ b/modules/webrtc/CameraHandler.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include <camera.h> + +#include <functional> + +class CameraHandler { + public: + using MediaPacketPreviewCallback = std::function<void(media_packet_h, void *)>; + + ~CameraHandler(); + int Init(const MediaPacketPreviewCallback &preview_cb, void *user_data); + void Deinit(void); + int StartPreview(void); + int StopPreview(void); + + static const char *ErrorToString(const int error); + static const char *StateToString(const camera_state_e state); + + private: + void SettingCamera(const MediaPacketPreviewCallback &preview_cb, void *user_data); + static void CameraPreviewCB(media_packet_h media_packet, void *user_data); + + camera_h handle_; + bool is_started_; + MediaPacketPreviewCallback media_packet_preview_cb_; + void *user_data_; +}; diff --git a/modules/webrtc/Config.h b/modules/webrtc/Config.h new file mode 100644 index 0000000..63dbd4b --- /dev/null +++ b/modules/webrtc/Config.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <string> + +class Config { + public: + Config() : disable_ssl_(false), broker_port_(0), user_data_length_(0) {}; + Config(const std::string &id, const std::string &broker_ip, int broker_port, + const std::string &room_id, const std::string &source_id = std::string("")) + : local_id_(id), + room_id_(room_id), + source_id_(source_id), + disable_ssl_(false), + broker_ip_(broker_ip), + broker_port_(broker_port), + user_data_length_(0){}; + std::string GetLocalId(void) const { return local_id_; }; + void SetLocalId(const std::string &local_id) { local_id_ = local_id; }; + std::string GetRoomId(void) const { return room_id_; }; + void SetRoomId(const std::string &room_id) { room_id_ = room_id; }; + std::string GetSourceId(void) const { return source_id_; }; + void SetSourceId(const std::string &source_id) { source_id_ = source_id; }; + void SetSignalingServerUrl(const std::string &signaling_server_url) + { + signaling_server_url_ = signaling_server_url; + }; + std::string GetSignalingServerUrl(void) const { return signaling_server_url_; }; + void SetDisableSSl(bool disable_ssl) { disable_ssl_ = disable_ssl; }; + bool GetDisableSSl(void) const { return disable_ssl_; }; + std::string GetBrokerIp(void) const { return broker_ip_; }; + void SetBrokerIp(const std::string &broker_ip) { broker_ip_ = broker_ip; }; + int GetBrokerPort(void) const { return broker_port_; }; + void SetBrokerPort(int port) { broker_port_ = port; }; + unsigned int GetUserDataLength(void) const { return user_data_length_; }; + void SetUserDataLength(unsigned int user_data_length) { user_data_length_ = user_data_length; }; + + private: + std::string local_id_; + std::string room_id_; + std::string source_id_; + std::string signaling_server_url_; + bool disable_ssl_; + std::string broker_ip_; + int broker_port_; + unsigned int user_data_length_; +}; diff --git a/modules/webrtc/IfaceServer.h b/modules/webrtc/IfaceServer.h new file mode 100644 index 0000000..ad6d36d --- /dev/null +++ b/modules/webrtc/IfaceServer.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <functional> +#include <vector> + +class IfaceServer { + public: + enum class ConnectionState { + Disconnected, + Connecting, + Connected, + Registering, + Registered, + }; + + virtual ~IfaceServer(){}; + virtual void SetConnectionStateChangedCb( + std::function<void(ConnectionState)> connection_state_changed_cb) = 0; + virtual void UnsetConnectionStateChangedCb(void) = 0; + virtual int Connect(void) = 0; + virtual int Disconnect(void) = 0; + virtual bool IsConnected(void) = 0; + virtual int SendMessage(const std::string &peer_id, const std::string &message) = 0; +}; diff --git a/modules/webrtc/Module.cc b/modules/webrtc/Module.cc new file mode 100644 index 0000000..3c9e4f8 --- /dev/null +++ b/modules/webrtc/Module.cc @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Module.h" + +#include <flatbuffers/flexbuffers.h> + +#include "Config.h" +#include "aitt_internal.h" + +Module::Module(const std::string &ip, AittDiscovery &discovery) : AittTransport(discovery) +{ +} + +Module::~Module(void) +{ +} + +void Module::Publish(const std::string &topic, const void *data, const size_t datalen, + const std::string &correlation, AittQoS qos, bool retain) +{ + // TODO +} + +void Module::Publish(const std::string &topic, const void *data, const size_t datalen, AittQoS qos, + bool retain) +{ + std::lock_guard<std::mutex> publish_table_lock(publish_table_lock_); + + auto config = BuildConfigFromFb(data, datalen); + if (config.GetUserDataLength()) { + publish_table_[topic] = + std::make_shared<PublishStream>(topic, BuildConfigFromFb(data, datalen)); + + publish_table_[topic]->Start(); + } else { + auto publish_table_itr = publish_table_.find(topic); + if (publish_table_itr == publish_table_.end()) { + ERR("%s not found", topic.c_str()); + return; + } + auto publish_stream = publish_table_itr->second; + publish_stream->Stop(); + publish_table_.erase(publish_table_itr); + } +} + +void *Module::Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb, + void *cbdata, AittQoS qos) +{ + return nullptr; +} + +void *Module::Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb, + const void *data, const size_t datalen, void *cbdata, AittQoS qos) +{ + std::lock_guard<std::mutex> subscribe_table_lock(subscribe_table_lock_); + + subscribe_table_[topic] = + std::make_shared<SubscribeStream>(topic, BuildConfigFromFb(data, datalen)); + + subscribe_table_[topic]->Start(qos == AITT_QOS_EXACTLY_ONCE, cbdata); + + return subscribe_table_[topic].get(); +} + +Config Module::BuildConfigFromFb(const void *data, const size_t data_size) +{ + Config config; + auto webrtc_configs = + flexbuffers::GetRoot(static_cast<const uint8_t *>(data), data_size).AsMap(); + auto webrtc_config_keys = webrtc_configs.Keys(); + for (size_t idx = 0; idx < webrtc_config_keys.size(); ++idx) { + std::string key = webrtc_config_keys[idx].AsString().c_str(); + + if (key.compare("Id") == 0) + config.SetLocalId(webrtc_configs[key].AsString().c_str()); + else if (key.compare("RoomId") == 0) + config.SetRoomId(webrtc_configs[key].AsString().c_str()); + else if (key.compare("SourceId") == 0) + config.SetSourceId(webrtc_configs[key].AsString().c_str()); + else if (key.compare("BrokerIp") == 0) + config.SetBrokerIp(webrtc_configs[key].AsString().c_str()); + else if (key.compare("BrokerPort") == 0) + config.SetBrokerPort(webrtc_configs[key].AsInt32()); + else if (key.compare("UserDataLength") == 0) + config.SetUserDataLength(webrtc_configs[key].AsUInt32()); + else { + printf("Not supported key name: %s\n", key.c_str()); + } + } + + return config; +} + +void *Module::Unsubscribe(void *handlePtr) +{ + void *ret = nullptr; + std::string topic; + std::lock_guard<std::mutex> subscribe_table_lock(subscribe_table_lock_); + for (auto itr = subscribe_table_.begin(); itr != subscribe_table_.end(); ++itr) { + if (itr->second.get() == handlePtr) { + auto topic = itr->first; + break; + } + } + + if (topic.size() != 0) { + ret = subscribe_table_[topic]->Stop(); + subscribe_table_.erase(topic); + } + + return ret; +} diff --git a/modules/webrtc/Module.h b/modules/webrtc/Module.h new file mode 100644 index 0000000..ca31eb8 --- /dev/null +++ b/modules/webrtc/Module.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <AittTransport.h> +#include <MainLoopHandler.h> + +#include <map> +#include <memory> +#include <mutex> +#include <set> +#include <string> +#include <thread> + +#include "PublishStream.h" +#include "SubscribeStream.h" + +using AittTransport = aitt::AittTransport; +using MainLoopHandler = aitt::MainLoopHandler; +using AittDiscovery = aitt::AittDiscovery; + +class Module : public AittTransport { + public: + explicit Module(const std::string &ip, AittDiscovery &discovery); + virtual ~Module(void); + + // TODO: How about regarding topic as service name? + void Publish(const std::string &topic, const void *data, const size_t datalen, + const std::string &correlation, AittQoS qos = AITT_QOS_AT_MOST_ONCE, + bool retain = false) override; + + void Publish(const std::string &topic, const void *data, const size_t datalen, + AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override; + + // TODO: How about regarding topic as service name? + void *Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb, + void *cbdata = nullptr, AittQoS qos = AITT_QOS_AT_MOST_ONCE) override; + + void *Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb, + const void *data, const size_t datalen, void *cbdata = nullptr, + AittQoS qos = AITT_QOS_AT_MOST_ONCE) override; + + void *Unsubscribe(void *handle) override; + + private: + Config BuildConfigFromFb(const void *data, const size_t data_size); + + std::map<std::string, std::shared_ptr<PublishStream>> publish_table_; + std::mutex publish_table_lock_; + std::map<std::string, std::shared_ptr<SubscribeStream>> subscribe_table_; + std::mutex subscribe_table_lock_; +}; diff --git a/modules/webrtc/MqttServer.cc b/modules/webrtc/MqttServer.cc new file mode 100644 index 0000000..70a07ed --- /dev/null +++ b/modules/webrtc/MqttServer.cc @@ -0,0 +1,289 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MqttServer.h" + +#include "aitt_internal.h" + +#define MQTT_HANDLER_MSG_QOS 1 +#define MQTT_HANDLER_MGMT_QOS 2 + +MqttServer::MqttServer(const Config &config) : mq(config.GetLocalId(), true) +{ + broker_ip_ = config.GetBrokerIp(); + broker_port_ = config.GetBrokerPort(); + id_ = config.GetLocalId(); + room_id_ = config.GetRoomId(); + source_id_ = config.GetSourceId(); + is_publisher_ = (id_ == source_id_); + + DBG("ID[%s] BROKER IP[%s] BROKER PORT [%d] ROOM[%s] %s", id_.c_str(), broker_ip_.c_str(), + broker_port_, room_id_.c_str(), is_publisher_ ? "Publisher" : "Subscriber"); + + mq.SetConnectionCallback(std::bind(&MqttServer::ConnectCallBack, this, std::placeholders::_1)); +} + +MqttServer::~MqttServer() +{ + // Prevent to call below callbacks after destructoring + connection_state_changed_cb_ = nullptr; + room_message_arrived_cb_ = nullptr; +} + +void MqttServer::SetConnectionState(ConnectionState state) +{ + connection_state_ = state; + if (connection_state_changed_cb_) + connection_state_changed_cb_(state); +} + +void MqttServer::ConnectCallBack(int status) +{ + if (status == AITT_CONNECTED) + OnConnect(); + else + OnDisconnect(); +} + +void MqttServer::OnConnect() +{ + INFO("Connected to signalling server"); + + // Sometimes it seems that broker is silently disconnected/reconnected + if (GetConnectionState() != ConnectionState::Connecting) { + ERR("Invalid status"); + return; + } + + SetConnectionState(ConnectionState::Connected); + SetConnectionState(ConnectionState::Registering); + try { + RegisterWithServer(); + } catch (const std::runtime_error &e) { + ERR("%s", e.what()); + SetConnectionState(ConnectionState::Connected); + } +} + +void MqttServer::OnDisconnect() +{ + INFO("mosquitto disconnected"); + + SetConnectionState(ConnectionState::Disconnected); + // TODO +} + +void MqttServer::RegisterWithServer(void) +{ + if (connection_state_ != IfaceServer::ConnectionState::Registering) { + ERR("Invaild status(%d)", (int)connection_state_); + throw std::runtime_error("Invalid status"); + } + + // Notify Who is source? + std::string source_topic = room_id_ + std::string("/source"); + if (is_publisher_) { + mq.Publish(source_topic, id_.c_str(), id_.size(), AITT_QOS_EXACTLY_ONCE, true); + SetConnectionState(ConnectionState::Registered); + } else { + mq.Subscribe(source_topic, + std::bind(&MqttServer::HandleSourceTopic, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, + std::placeholders::_5), + nullptr, AITT_QOS_EXACTLY_ONCE); + } +} + +void MqttServer::HandleSourceTopic(aitt::MSG *msg, const std::string &topic, const void *data, + const size_t datalen, void *user_data) +{ + INFO("Source topic"); + if (connection_state_ != IfaceServer::ConnectionState::Registering) { + ERR("Invaild status(%d)", (int)connection_state_); + return; + } + + if (is_publisher_) { + ERR("Ignore"); + } else { + std::string message(static_cast<const char *>(data), datalen); + INFO("Set source ID %s", message.c_str()); + SetSourceId(message); + SetConnectionState(ConnectionState::Registered); + } +} + +bool MqttServer::IsConnected(void) +{ + INFO("%s", __func__); + + return connection_state_ == IfaceServer::ConnectionState::Registered; +} + +int MqttServer::Connect(void) +{ + std::string will_message = std::string("ROOM_PEER_LEFT ") + id_; + mq.SetWillInfo(room_id_, will_message.c_str(), will_message.size(), AITT_QOS_EXACTLY_ONCE, + false); + + SetConnectionState(ConnectionState::Connecting); + mq.Connect(broker_ip_, broker_port_, std::string(), std::string()); + + return 0; +} + +int MqttServer::Disconnect(void) +{ + if (is_publisher_) { + INFO("remove retained"); + std::string source_topic = room_id_ + std::string("/source"); + mq.Publish(source_topic, nullptr, 0, AITT_QOS_AT_LEAST_ONCE, true); + } + + std::string left_message = std::string("ROOM_PEER_LEFT ") + id_; + mq.Publish(room_id_, left_message.c_str(), left_message.size(), AITT_QOS_AT_LEAST_ONCE, false); + + mq.Disconnect(); + + room_id_ = std::string(""); + + SetConnectionState(ConnectionState::Disconnected); + return 0; +} + +int MqttServer::SendMessage(const std::string &peer_id, const std::string &msg) +{ + if (room_id_.empty()) { + ERR("Invaild status"); + return -1; + } + if (peer_id.size() == 0 || msg.size() == 0) { + ERR("Invalid parameter"); + return -1; + } + + std::string receiver_topic = room_id_ + std::string("/") + peer_id; + std::string server_formatted_msg = "ROOM_PEER_MSG " + id_ + " " + msg; + mq.Publish(receiver_topic, server_formatted_msg.c_str(), server_formatted_msg.size(), + AITT_QOS_AT_LEAST_ONCE); + + return 0; +} + +std::string MqttServer::GetConnectionStateStr(ConnectionState state) +{ + std::string state_str; + switch (state) { + case IfaceServer::ConnectionState::Disconnected: { + state_str = std::string("Disconnected"); + break; + } + case IfaceServer::ConnectionState::Connecting: { + state_str = std::string("Connecting"); + break; + } + case IfaceServer::ConnectionState::Connected: { + state_str = std::string("Connected"); + break; + } + case IfaceServer::ConnectionState::Registering: { + state_str = std::string("Registering"); + break; + } + case IfaceServer::ConnectionState::Registered: { + state_str = std::string("Registered"); + break; + } + } + + return state_str; +} + +void MqttServer::JoinRoom(const std::string &room_id) +{ + if (room_id.empty() || room_id != room_id_) { + ERR("Invaild room id"); + throw std::runtime_error(std::string("Invalid room_id")); + } + + // Subscribe PEER_JOIN PEER_LEFT + mq.Subscribe(room_id_, + std::bind(&MqttServer::HandleRoomTopic, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, + std::placeholders::_5), + nullptr, AITT_QOS_EXACTLY_ONCE); + + // Subscribe PEER_MSG + std::string receiving_topic = room_id + std::string("/") + id_; + mq.Subscribe(receiving_topic, + std::bind(&MqttServer::HandleMessageTopic, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, + std::placeholders::_5), + nullptr, AITT_QOS_AT_LEAST_ONCE); + + INFO("Subscribe room topics"); + + if (!is_publisher_) { + std::string join_message = std::string("ROOM_PEER_JOINED ") + id_; + mq.Publish(room_id_, join_message.c_str(), join_message.size(), AITT_QOS_EXACTLY_ONCE); + } +} + +void MqttServer::HandleRoomTopic(aitt::MSG *msg, const std::string &topic, const void *data, + const size_t datalen, void *user_data) +{ + std::string message(static_cast<const char *>(data), datalen); + INFO("Room topic(%s, %s)", topic.c_str(), message.c_str()); + + std::string peer_id; + if (message.compare(0, 16, "ROOM_PEER_JOINED") == 0) { + peer_id = message.substr(17, std::string::npos); + } else if (message.compare(0, 14, "ROOM_PEER_LEFT") == 0) { + peer_id = message.substr(15, std::string::npos); + } else { + ERR("Invalid type of Room message %s", message.c_str()); + return; + } + + if (peer_id == id_) { + ERR("ignore"); + return; + } + + if (is_publisher_) { + if (room_message_arrived_cb_) + room_message_arrived_cb_(message); + } else { + // TODO: ADHOC, will handle this by room + if (peer_id != source_id_) { + ERR("peer(%s) is Not source(%s)", peer_id.c_str(), source_id_.c_str()); + return; + } + + if (room_message_arrived_cb_) + room_message_arrived_cb_(message); + } +} + +void MqttServer::HandleMessageTopic(aitt::MSG *msg, const std::string &topic, const void *data, + const size_t datalen, void *user_data) +{ + INFO("Message topic"); + std::string message(static_cast<const char *>(data), datalen); + + if (room_message_arrived_cb_) + room_message_arrived_cb_(message); +} diff --git a/modules/webrtc/MqttServer.h b/modules/webrtc/MqttServer.h new file mode 100644 index 0000000..7f93192 --- /dev/null +++ b/modules/webrtc/MqttServer.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <MQ.h> + +#include "Config.h" +#include "IfaceServer.h" + +class MqttServer : public IfaceServer { + public: + explicit MqttServer(const Config &config); + virtual ~MqttServer(); + + void SetConnectionStateChangedCb( + std::function<void(ConnectionState)> connection_state_changed_cb) override + { + connection_state_changed_cb_ = connection_state_changed_cb; + }; + void UnsetConnectionStateChangedCb(void) override { connection_state_changed_cb_ = nullptr; }; + + bool IsConnected(void) override; + int Connect(void) override; + int Disconnect(void) override; + int SendMessage(const std::string &peer_id, const std::string &msg) override; + + static std::string GetConnectionStateStr(ConnectionState state); + void RegisterWithServer(void); + void JoinRoom(const std::string &room_id); + void SetConnectionState(ConnectionState state); + ConnectionState GetConnectionState(void) const { return connection_state_; }; + std::string GetId(void) const { return id_; }; + std::string GetSourceId(void) const { return source_id_; }; + void SetSourceId(const std::string &source_id) { source_id_ = source_id; }; + + void SetRoomMessageArrivedCb(std::function<void(const std::string &)> room_message_arrived_cb) + { + room_message_arrived_cb_ = room_message_arrived_cb; + }; + void UnsetRoomMessageArrivedCb(void) { room_message_arrived_cb_ = nullptr; } + + private: + static void MessageCallback(mosquitto *handle, void *mqtt_server, const mosquitto_message *msg, + const mosquitto_property *props); + void OnConnect(); + void OnDisconnect(); + void ConnectCallBack(int status); + void HandleRoomTopic(aitt::MSG *msg, const std::string &topic, const void *data, + const size_t datalen, void *user_data); + void HandleSourceTopic(aitt::MSG *msg, const std::string &topic, const void *data, + const size_t datalen, void *user_data); + void HandleMessageTopic(aitt::MSG *msg, const std::string &topic, const void *data, + const size_t datalen, void *user_data); + + std::string broker_ip_; + int broker_port_; + std::string id_; + std::string room_id_; + std::string source_id_; + bool is_publisher_; + aitt::MQ mq; + + ConnectionState connection_state_; + std::function<void(ConnectionState)> connection_state_changed_cb_; + std::function<void(const std::string &)> room_message_arrived_cb_; +}; diff --git a/modules/webrtc/PublishStream.cc b/modules/webrtc/PublishStream.cc new file mode 100644 index 0000000..f93ecea --- /dev/null +++ b/modules/webrtc/PublishStream.cc @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "PublishStream.h" + +#include <sys/time.h> + +#include "WebRtcEventHandler.h" +#include "aitt_internal.h" + +PublishStream::~PublishStream() +{ + // TODO: clear resources +} + +void PublishStream::Start(void) +{ + PrepareStream(); + SetSignalingServerCallbacks(); + SetRoomCallbacks(); +} + +void PublishStream::PrepareStream(void) +{ + std::lock_guard<std::mutex> prepared_stream_lock(prepared_stream_lock_); + prepared_stream_ = std::make_shared<WebRtcStream>(); + prepared_stream_->Create(true, false); + prepared_stream_->AttachCameraSource(); + auto on_stream_state_changed_prepared_cb = + std::bind(OnStreamStateChangedPrepared, std::placeholders::_1, std::ref(*this)); + prepared_stream_->GetEventHandler().SetOnStateChangedCb(on_stream_state_changed_prepared_cb); + prepared_stream_->Start(); +} + +void PublishStream::OnStreamStateChangedPrepared(WebRtcState::Stream state, PublishStream &stream) +{ + ERR("%s", __func__); + if (state == WebRtcState::Stream::NEGOTIATING) { + auto on_offer_created_prepared_cb = + std::bind(OnOfferCreatedPrepared, std::placeholders::_1, std::ref(stream)); + stream.prepared_stream_->CreateOfferAsync(on_offer_created_prepared_cb); + } +} + +void PublishStream::OnOfferCreatedPrepared(std::string sdp, PublishStream &stream) +{ + ERR("%s", __func__); + + stream.prepared_stream_->SetPreparedLocalDescription(sdp); + stream.prepared_stream_->SetLocalDescription(sdp); + try { + stream.server_->Connect(); + } catch (const std::exception &e) { + ERR("Failed to start Publish stream %s", e.what()); + } +} + +void PublishStream::SetSignalingServerCallbacks(void) +{ + auto on_signaling_server_connection_state_changed = + std::bind(OnSignalingServerConnectionStateChanged, std::placeholders::_1, + std::ref(*room_), std::ref(*server_)); + + server_->SetConnectionStateChangedCb(on_signaling_server_connection_state_changed); + + auto on_room_message_arrived = + std::bind(OnRoomMessageArrived, std::placeholders::_1, std::ref(*room_)); + + server_->SetRoomMessageArrivedCb(on_room_message_arrived); +} + +void PublishStream::OnSignalingServerConnectionStateChanged(IfaceServer::ConnectionState state, + WebRtcRoom &room, MqttServer &server) +{ + DBG("current state [%s]", MqttServer::GetConnectionStateStr(state).c_str()); + + if (state == IfaceServer::ConnectionState::Disconnected) { + ; // TODO: what to do when server is disconnected? + } else if (state == IfaceServer::ConnectionState::Registered) { + server.JoinRoom(room.getId()); + } +} + +void PublishStream::OnRoomMessageArrived(const std::string &message, WebRtcRoom &room) +{ + room.handleMessage(message); +} + +void PublishStream::SetRoomCallbacks() +{ + auto on_room_joined = std::bind(OnRoomJoined, std::ref(*this)); + + room_->SetRoomJoinedCb(on_room_joined); + + auto on_peer_joined = std::bind(OnPeerJoined, std::placeholders::_1, std::ref(*this)); + room_->SetPeerJoinedCb(on_peer_joined); + + auto on_peer_left = std::bind(OnPeerLeft, std::placeholders::_1, std::ref(*room_)); + room_->SetPeerLeftCb(on_peer_left); +} + +void PublishStream::OnRoomJoined(PublishStream &publish_stream) +{ + // TODO: Notify Room Joined? + DBG("%s on %p", __func__, &publish_stream); +} + +void PublishStream::OnPeerJoined(const std::string &peer_id, PublishStream &publish_stream) +{ + DBG("%s [%s]", __func__, peer_id.c_str()); + if (!publish_stream.room_->AddPeer(peer_id)) { + ERR("Failed to add peer"); + return; + } + + try { + WebRtcPeer &peer = publish_stream.room_->GetPeer(peer_id); + + std::unique_lock<std::mutex> prepared_stream_lock(publish_stream.prepared_stream_lock_); + auto prepared_stream = publish_stream.prepared_stream_; + publish_stream.prepared_stream_ = nullptr; + prepared_stream_lock.unlock(); + + try { + peer.SetWebRtcStream(prepared_stream); + publish_stream.SetWebRtcStreamCallbacks(peer); + publish_stream.server_->SendMessage(peer.getId(), + peer.GetWebRtcStream()->GetPreparedLocalDescription()); + prepared_stream->SetPreparedLocalDescription(""); + } catch (std::exception &e) { + ERR("Failed to start stream for peer %s", e.what()); + } + // TODO why we can't prepare more sources? + + } catch (std::exception &e) { + ERR("Wired %s", e.what()); + } +} + +void PublishStream::SetWebRtcStreamCallbacks(WebRtcPeer &peer) +{ + // TODO: set more webrtc callbacks + WebRtcEventHandler event_handlers; + auto on_stream_state_changed_cb = std::bind(OnStreamStateChanged, std::placeholders::_1, + std::ref(peer), std::ref(*server_)); + event_handlers.SetOnStateChangedCb(on_stream_state_changed_cb); + + auto on_signaling_state_notify_cb = std::bind(OnSignalingStateNotify, std::placeholders::_1, + std::ref(peer), std::ref(*server_)); + event_handlers.SetOnSignalingStateNotifyCb(on_signaling_state_notify_cb); + + auto on_ice_connection_state_notify = std::bind(OnIceConnectionStateNotify, + std::placeholders::_1, std::ref(peer), std::ref(*server_)); + event_handlers.SetOnIceConnectionStateNotifyCb(on_ice_connection_state_notify); + + peer.GetWebRtcStream()->SetEventHandler(event_handlers); +} + +void PublishStream::OnStreamStateChanged(WebRtcState::Stream state, WebRtcPeer &peer, + MqttServer &server) +{ + ERR("%s for %s", __func__, peer.getId().c_str()); +} + +void PublishStream::OnSignalingStateNotify(WebRtcState::Signaling state, WebRtcPeer &peer, + MqttServer &server) +{ + ERR("Singaling State: %s", WebRtcState::SignalingToStr(state).c_str()); + if (state == WebRtcState::Signaling::STABLE) { + auto ice_candidates = peer.GetWebRtcStream()->GetIceCandidates(); + for (const auto &candidate : ice_candidates) + server.SendMessage(peer.getId(), candidate); + } +} + +void PublishStream::OnIceConnectionStateNotify(WebRtcState::IceConnection state, WebRtcPeer &peer, + MqttServer &server) +{ + ERR("IceConnection State: %s", WebRtcState::IceConnectionToStr(state).c_str()); +} + +void PublishStream::OnPeerLeft(const std::string &peer_id, WebRtcRoom &room) +{ + DBG("%s [%s]", __func__, peer_id.c_str()); + if (!room.RemovePeer(peer_id)) + ERR("Failed to remove peer"); +} + +void PublishStream::Stop(void) +{ + try { + server_->Disconnect(); + } catch (const std::exception &e) { + ERR("Failed to disconnect server %s", e.what()); + } + + room_->ClearPeers(); +} diff --git a/modules/webrtc/PublishStream.h b/modules/webrtc/PublishStream.h new file mode 100644 index 0000000..1805528 --- /dev/null +++ b/modules/webrtc/PublishStream.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <memory> +#include <mutex> +#include <string> + +#include "Config.h" +#include "MqttServer.h" +#include "WebRtcRoom.h" +#include "WebRtcStream.h" + +class PublishStream { + // TODO: Notify & get status + public: + PublishStream() = delete; + PublishStream(const std::string &topic, const Config &config) + : topic_(topic), + config_(config), + server_(std::make_shared<MqttServer>(config)), + room_(std::make_shared<WebRtcRoom>(config.GetRoomId())), + prepared_stream_(nullptr){}; + ~PublishStream(); + + void Start(void); + void Stop(void); + void SetSignalingServerCallbacks(void); + void SetRoomCallbacks(void); + void SetWebRtcStreamCallbacks(WebRtcPeer &peer); + void PrepareStream(void); + + private: + static void OnStreamStateChangedPrepared(WebRtcState::Stream state, PublishStream &stream); + static void OnOfferCreatedPrepared(std::string sdp, PublishStream &stream); + static void OnSignalingServerConnectionStateChanged(IfaceServer::ConnectionState state, + WebRtcRoom &room, MqttServer &server); + static void OnRoomMessageArrived(const std::string &message, WebRtcRoom &room); + static void OnRoomJoined(PublishStream &publish_stream); + static void OnPeerJoined(const std::string &peer_id, PublishStream &publish_stream); + static void OnPeerLeft(const std::string &peer_id, WebRtcRoom &room); + static void OnStreamStateChanged(WebRtcState::Stream state, WebRtcPeer &peer, + MqttServer &server); + + static void OnSignalingStateNotify(WebRtcState::Signaling state, WebRtcPeer &peer, + MqttServer &server); + static void OnIceConnectionStateNotify(WebRtcState::IceConnection state, WebRtcPeer &peer, + MqttServer &server); + + private: + std::string topic_; + Config config_; + std::shared_ptr<MqttServer> server_; + std::shared_ptr<WebRtcRoom> room_; + std::mutex prepared_stream_lock_; + std::shared_ptr<WebRtcStream> prepared_stream_; +}; diff --git a/modules/webrtc/SubscribeStream.cc b/modules/webrtc/SubscribeStream.cc new file mode 100644 index 0000000..841cfa6 --- /dev/null +++ b/modules/webrtc/SubscribeStream.cc @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "SubscribeStream.h" + +#include "WebRtcEventHandler.h" +#include "aitt_internal.h" + +SubscribeStream::~SubscribeStream() +{ + // TODO Clear resources +} + +void SubscribeStream::Start(bool need_display, void *display_object) +{ + display_object_ = display_object; + is_track_added_ = need_display; + SetSignalingServerCallbacks(); + SetRoomCallbacks(); + try { + server_->Connect(); + } catch (const std::exception &e) { + ERR("Failed to start Subscribe stream %s", e.what()); + } +} + +void SubscribeStream::SetSignalingServerCallbacks(void) +{ + auto on_signaling_server_connection_state_changed = + std::bind(OnSignalingServerConnectionStateChanged, std::placeholders::_1, + std::ref(*room_), std::ref(*server_)); + + server_->SetConnectionStateChangedCb(on_signaling_server_connection_state_changed); + + auto on_room_message_arrived = + std::bind(OnRoomMessageArrived, std::placeholders::_1, std::ref(*room_)); + + server_->SetRoomMessageArrivedCb(on_room_message_arrived); +} + +void SubscribeStream::OnSignalingServerConnectionStateChanged(IfaceServer::ConnectionState state, + WebRtcRoom &room, MqttServer &server) +{ + // TODO VD doesn't show DBG level log + ERR("current state [%s]", MqttServer::GetConnectionStateStr(state).c_str()); + + if (state == IfaceServer::ConnectionState::Disconnected) { + ; // TODO: what to do when server is disconnected? + } else if (state == IfaceServer::ConnectionState::Registered) { + if (server.GetSourceId().size() != 0) + room.SetSourceId(server.GetSourceId()); + server.JoinRoom(room.getId()); + } +} + +void SubscribeStream::OnRoomMessageArrived(const std::string &message, WebRtcRoom &room) +{ + room.handleMessage(message); +} + +void SubscribeStream::SetRoomCallbacks(void) +{ + auto on_room_joined = std::bind(OnRoomJoined, std::ref(*this)); + + room_->SetRoomJoinedCb(on_room_joined); + + auto on_peer_joined = std::bind(OnPeerJoined, std::placeholders::_1, std::ref(*this)); + room_->SetPeerJoinedCb(on_peer_joined); + + auto on_peer_left = std::bind(OnPeerLeft, std::placeholders::_1, std::ref(*room_)); + room_->SetPeerLeftCb(on_peer_left); +} + +void SubscribeStream::OnRoomJoined(SubscribeStream &subscribe_stream) +{ + // TODO: Notify Room Joined? + ERR("%s on %p", __func__, &subscribe_stream); +} + +void SubscribeStream::OnPeerJoined(const std::string &peer_id, SubscribeStream &subscribe_stream) +{ + ERR("%s [%s]", __func__, peer_id.c_str()); + + if (peer_id.compare(subscribe_stream.room_->GetSourceId()) != 0) { + ERR("is not matched to source ID, ignored"); + return; + } + + if (!subscribe_stream.room_->AddPeer(peer_id)) { + ERR("Failed to add peer"); + return; + } + + try { + WebRtcPeer &peer = subscribe_stream.room_->GetPeer(peer_id); + + auto webrtc_subscribe_stream = peer.GetWebRtcStream(); + webrtc_subscribe_stream->Create(false, subscribe_stream.is_track_added_); + webrtc_subscribe_stream->Start(); + subscribe_stream.SetWebRtcStreamCallbacks(peer); + } catch (std::out_of_range &e) { + ERR("Wired %s", e.what()); + } +} + +void SubscribeStream::SetWebRtcStreamCallbacks(WebRtcPeer &peer) +{ + WebRtcEventHandler event_handlers; + + auto on_signaling_state_notify = std::bind(OnSignalingStateNotify, std::placeholders::_1, + std::ref(peer), std::ref(*server_)); + event_handlers.SetOnSignalingStateNotifyCb(on_signaling_state_notify); + + auto on_ice_connection_state_notify = std::bind(OnIceConnectionStateNotify, + std::placeholders::_1, std::ref(peer), std::ref(*server_)); + event_handlers.SetOnIceConnectionStateNotifyCb(on_ice_connection_state_notify); + + auto on_encoded_frame = std::bind(OnEncodedFrame, std::ref(peer)); + event_handlers.SetOnEncodedFrameCb(on_encoded_frame); + + auto on_track_added = + std::bind(OnTrackAdded, std::placeholders::_1, display_object_, std::ref(peer)); + event_handlers.SetOnTrakAddedCb(on_track_added); + + peer.GetWebRtcStream()->SetEventHandler(event_handlers); +} + +void SubscribeStream::OnSignalingStateNotify(WebRtcState::Signaling state, WebRtcPeer &peer, + MqttServer &server) +{ + ERR("Singaling State: %s", WebRtcState::SignalingToStr(state).c_str()); + if (state == WebRtcState::Signaling::HAVE_REMOTE_OFFER) { + auto on_answer_created_cb = + std::bind(OnAnswerCreated, std::placeholders::_1, std::ref(peer), std::ref(server)); + peer.GetWebRtcStream()->CreateAnswerAsync(on_answer_created_cb); + } +} + +void SubscribeStream::OnIceConnectionStateNotify(WebRtcState::IceConnection state, WebRtcPeer &peer, + MqttServer &server) +{ + ERR("IceConnection State: %s", WebRtcState::IceConnectionToStr(state).c_str()); + if (state == WebRtcState::IceConnection::CHECKING) { + auto ice_candidates = peer.GetWebRtcStream()->GetIceCandidates(); + for (const auto &candidate : ice_candidates) + server.SendMessage(peer.getId(), candidate); + } +} + +void SubscribeStream::OnAnswerCreated(std::string sdp, WebRtcPeer &peer, MqttServer &server) +{ + server.SendMessage(peer.getId(), sdp); + peer.GetWebRtcStream()->SetLocalDescription(sdp); +} + +void SubscribeStream::OnEncodedFrame(WebRtcPeer &peer) +{ + // TODO +} + +void SubscribeStream::OnTrackAdded(unsigned int id, void *display_object, WebRtcPeer &peer) +{ + peer.GetWebRtcStream()->SetDisplayObject(id, display_object); +} + +void SubscribeStream::OnPeerLeft(const std::string &peer_id, WebRtcRoom &room) +{ + /*TODO + ERR("%s [%s]", __func__, peer_id.c_str()); + if (peer_id.compare(room.getSourceId()) != 0) { + ERR("is not matched to source ID, ignored"); + return; + } + */ + if (!room.RemovePeer(peer_id)) + ERR("Failed to remove peer"); +} + +void *SubscribeStream::Stop(void) +{ + try { + server_->Disconnect(); + } catch (const std::exception &e) { + ERR("Failed to disconnect server %s", e.what()); + } + + room_->ClearPeers(); + + return display_object_; +} diff --git a/modules/webrtc/SubscribeStream.h b/modules/webrtc/SubscribeStream.h new file mode 100644 index 0000000..c8853f6 --- /dev/null +++ b/modules/webrtc/SubscribeStream.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <memory> +#include <mutex> + +#include "Config.h" +#include "MqttServer.h" +#include "WebRtcRoom.h" +#include "WebRtcStream.h" + +class SubscribeStream { + public: + SubscribeStream() = delete; + SubscribeStream(const std::string &topic, const Config &config) + : topic_(topic), + config_(config), + server_(std::make_shared<MqttServer>(config)), + room_(std::make_shared<WebRtcRoom>(config.GetRoomId())), + is_track_added_(false), + display_object_(nullptr){}; + ~SubscribeStream(); + + // TODO what will be final form of callback + void Start(bool need_display, void *display_object); + void *Stop(void); + void SetSignalingServerCallbacks(void); + void SetRoomCallbacks(void); + void SetWebRtcStreamCallbacks(WebRtcPeer &peer); + + private: + static void OnSignalingServerConnectionStateChanged(IfaceServer::ConnectionState state, + WebRtcRoom &room, MqttServer &server); + static void OnRoomMessageArrived(const std::string &message, WebRtcRoom &room); + static void OnRoomJoined(SubscribeStream &subscribe_stream); + static void OnPeerJoined(const std::string &peer_id, SubscribeStream &subscribe_stream); + static void OnPeerLeft(const std::string &peer_id, WebRtcRoom &room); + static void OnSignalingStateNotify(WebRtcState::Signaling state, WebRtcPeer &peer, + MqttServer &server); + static void OnIceConnectionStateNotify(WebRtcState::IceConnection state, WebRtcPeer &peer, + MqttServer &server); + static void OnAnswerCreated(std::string sdp, WebRtcPeer &peer, MqttServer &server); + static void OnEncodedFrame(WebRtcPeer &peer); + static void OnTrackAdded(unsigned int id, void *dispaly_object, WebRtcPeer &peer); + + private: + std::string topic_; + Config config_; + std::shared_ptr<MqttServer> server_; + std::shared_ptr<WebRtcRoom> room_; + bool is_track_added_; + void *display_object_; +}; diff --git a/modules/webrtc/WebRtcEventHandler.h b/modules/webrtc/WebRtcEventHandler.h new file mode 100644 index 0000000..c922672 --- /dev/null +++ b/modules/webrtc/WebRtcEventHandler.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <functional> +#include <string> + +#include "WebRtcState.h" + +class WebRtcEventHandler { + public: + // TODO Add error and state callbacks + void SetOnStateChangedCb(std::function<void(WebRtcState::Stream)> on_state_changed_cb) + { + on_state_changed_cb_ = on_state_changed_cb; + }; + void CallOnStateChangedCb(WebRtcState::Stream state) const + { + if (on_state_changed_cb_) + on_state_changed_cb_(state); + }; + void UnsetOnStateChangedCb(void) { on_state_changed_cb_ = nullptr; }; + + void SetOnSignalingStateNotifyCb( + std::function<void(WebRtcState::Signaling)> on_signaling_state_notify_cb) + { + on_signaling_state_notify_cb_ = on_signaling_state_notify_cb; + }; + void CallOnSignalingStateNotifyCb(WebRtcState::Signaling state) const + { + if (on_signaling_state_notify_cb_) + on_signaling_state_notify_cb_(state); + }; + void UnsetOnSignalingStateNotifyCb(void) { on_signaling_state_notify_cb_ = nullptr; }; + + void SetOnIceConnectionStateNotifyCb(std::function<void(WebRtcState::IceConnection)> on_ice_connection_state_notify_cb) + { + on_ice_connection_state_notify_cb_ = on_ice_connection_state_notify_cb; + }; + void CallOnIceConnectionStateNotifyCb(WebRtcState::IceConnection state) const + { + if (on_ice_connection_state_notify_cb_) + on_ice_connection_state_notify_cb_(state); + }; + void UnsetOnIceConnectionStateNotifyeCb(void) { on_ice_connection_state_notify_cb_ = nullptr; }; + + void SetOnEncodedFrameCb(std::function<void(void)> on_encoded_frame_cb) + { + on_encoded_frame_cb_ = on_encoded_frame_cb; + }; + void CallOnEncodedFrameCb(void) const + { + if (on_encoded_frame_cb_) + on_encoded_frame_cb_(); + }; + void UnsetEncodedFrameCb(void) { on_encoded_frame_cb_ = nullptr; }; + + void SetOnTrakAddedCb(std::function<void(unsigned int id)> on_track_added_cb) + { + on_track_added_cb_ = on_track_added_cb; + }; + void CallOnTrakAddedCb(unsigned int id) const + { + if (on_track_added_cb_) + on_track_added_cb_(id); + }; + void UnsetTrackAddedCb(void) { on_track_added_cb_ = nullptr; }; + + private: + std::function<void(void)> on_negotiation_needed_cb_; + std::function<void(WebRtcState::Stream)> on_state_changed_cb_; + std::function<void(WebRtcState::Signaling)> on_signaling_state_notify_cb_; + std::function<void(WebRtcState::IceConnection)> on_ice_connection_state_notify_cb_; + std::function<void(void)> on_encoded_frame_cb_; + std::function<void(unsigned int id)> on_track_added_cb_; +}; diff --git a/modules/webrtc/WebRtcMessage.cc b/modules/webrtc/WebRtcMessage.cc new file mode 100644 index 0000000..3d9af79 --- /dev/null +++ b/modules/webrtc/WebRtcMessage.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <json-glib/json-glib.h> + +#include "aitt_internal.h" + +#include "WebRtcMessage.h" + +WebRtcMessage::Type WebRtcMessage::getMessageType(const std::string &message) +{ + WebRtcMessage::Type type = WebRtcMessage::Type::UNKNOWN; + JsonParser *parser = json_parser_new(); + if (!json_parser_load_from_data(parser, message.c_str(), -1, NULL)) { + DBG("Unknown message '%s', ignoring", message.c_str()); + g_object_unref(parser); + return type; + } + + JsonNode *root = json_parser_get_root(parser); + if (!JSON_NODE_HOLDS_OBJECT(root)) { + DBG("Unknown json message '%s', ignoring", message.c_str()); + g_object_unref(parser); + return type; + } + + JsonObject *object = json_node_get_object(root); + /* Check type of JSON message */ + + if (json_object_has_member(object, "sdp")) { + type = WebRtcMessage::Type::SDP; + } else if (json_object_has_member(object, "ice")) { + type = WebRtcMessage::Type::ICE; + } else { + DBG("%s:UNKNOWN", __func__); + } + + g_object_unref(parser); + return type; +} diff --git a/modules/webrtc/WebRtcMessage.h b/modules/webrtc/WebRtcMessage.h new file mode 100644 index 0000000..6057a22 --- /dev/null +++ b/modules/webrtc/WebRtcMessage.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <string> + +class WebRtcMessage { + public: + enum class Type { + SDP, + ICE, + UNKNOWN, + }; + static WebRtcMessage::Type getMessageType(const std::string &message); +}; diff --git a/modules/webrtc/WebRtcPeer.cc b/modules/webrtc/WebRtcPeer.cc new file mode 100644 index 0000000..119f6e4 --- /dev/null +++ b/modules/webrtc/WebRtcPeer.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "WebRtcPeer.h" + +#include "WebRtcMessage.h" +#include "aitt_internal.h" + +WebRtcPeer::WebRtcPeer(const std::string &peer_id) + : local_id_(peer_id), webrtc_stream_(std::make_shared<WebRtcStream>()) +{ + DBG("%s", __func__); +} + +WebRtcPeer::~WebRtcPeer() +{ + webrtc_stream_ = nullptr; + DBG("%s removed", local_id_.c_str()); +} + +std::shared_ptr<WebRtcStream> WebRtcPeer::GetWebRtcStream(void) const +{ + return webrtc_stream_; +} + +void WebRtcPeer::SetWebRtcStream(std::shared_ptr<WebRtcStream> webrtc_stream) +{ + webrtc_stream_ = webrtc_stream; +} + +std::string WebRtcPeer::getId(void) const +{ + return local_id_; +} + +void WebRtcPeer::HandleMessage(const std::string &message) +{ + WebRtcMessage::Type type = WebRtcMessage::getMessageType(message); + if (type == WebRtcMessage::Type::SDP) + webrtc_stream_->SetRemoteDescription(message); + else if (type == WebRtcMessage::Type::ICE) + webrtc_stream_->AddIceCandidateFromMessage(message); + else + DBG("%s can't handle %s", __func__, message.c_str()); +} diff --git a/modules/webrtc/WebRtcPeer.h b/modules/webrtc/WebRtcPeer.h new file mode 100644 index 0000000..1ccb4e9 --- /dev/null +++ b/modules/webrtc/WebRtcPeer.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <memory> +#include <string> + +#include "WebRtcStream.h" + +class WebRtcPeer { + public: + explicit WebRtcPeer(const std::string &peer_id); + ~WebRtcPeer(); + + std::shared_ptr<WebRtcStream> GetWebRtcStream(void) const; + void SetWebRtcStream(std::shared_ptr<WebRtcStream> webrtc_stream); + std::string getId(void) const; + void HandleMessage(const std::string &message); + + private: + std::string local_id_; + std::shared_ptr<WebRtcStream> webrtc_stream_; +}; diff --git a/modules/webrtc/WebRtcRoom.cc b/modules/webrtc/WebRtcRoom.cc new file mode 100644 index 0000000..781b72b --- /dev/null +++ b/modules/webrtc/WebRtcRoom.cc @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "WebRtcRoom.h" + +#include <sstream> +#include <stdexcept> + +#include "aitt_internal.h" + +WebRtcRoom::~WebRtcRoom() +{ + //TODO How about removing handling webrtc_stream part from Room? + for (auto pair : peers_) { + auto peer = pair.second; + auto webrtc_stream = peer->GetWebRtcStream(); + webrtc_stream->Stop(); + webrtc_stream->Destroy(); + } +} + +static std::vector<std::string> split(const std::string &line, char delimiter) +{ + std::vector<std::string> result; + std::stringstream input(line); + + std::string buffer; + while (getline(input, buffer, delimiter)) { + result.push_back(buffer); + } + + return result; +} + +void WebRtcRoom::handleMessage(const std::string &msg) +{ + if (msg.compare("ROOM_OK ") == 0) + CallRoomJoinedCb(); + else if (msg.compare(0, 8, "ROOM_OK ") == 0) { + CallRoomJoinedCb(); + HandleRoomJoinedWithPeerList(msg.substr(8, std::string::npos)); + } else if (msg.compare(0, 16, "ROOM_PEER_JOINED") == 0) { + CallPeerJoinedCb(msg.substr(17, std::string::npos)); + } else if (msg.compare(0, 14, "ROOM_PEER_LEFT") == 0) { + CallPeerLeftCb(msg.substr(15, std::string::npos)); + } else if (msg.compare(0, 13, "ROOM_PEER_MSG") == 0) { + HandlePeerMessage(msg.substr(14, std::string::npos)); + } else { + DBG("Not defined"); + } + + return; +} + +void WebRtcRoom::HandleRoomJoinedWithPeerList(const std::string &peer_list) +{ + auto peer_ids = split(peer_list, ' '); + for (const auto &peer_id : peer_ids) { + CallPeerJoinedCb(peer_id); + } +} + +void WebRtcRoom::HandlePeerMessage(const std::string &msg) +{ + std::size_t pos = msg.find(' '); + if (pos == std::string::npos) { + DBG("%s can't handle %s", __func__, msg.c_str()); + return; + } + + auto peer_id = msg.substr(0, pos); + auto itr = peers_.find(peer_id); + if (itr == peers_.end()) { + ERR("%s is not in peer list", peer_id.c_str()); + //Opening backdoor here for source. What'll be crisis for this? + CallPeerJoinedCb(peer_id); + itr = peers_.find(peer_id); + RET_IF(itr == peers_.end()); + } + + itr->second->HandleMessage(msg.substr(pos + 1, std::string::npos)); +} + +bool WebRtcRoom::AddPeer(const std::string &peer_id) +{ + auto peer = std::make_shared<WebRtcPeer>(peer_id); + auto ret = peers_.insert(std::make_pair(peer_id, peer)); + + return ret.second; +} + +bool WebRtcRoom::RemovePeer(const std::string &peer_id) +{ + auto itr = peers_.find(peer_id); + if (itr == peers_.end()) { + DBG("There's no such peer"); + return false; + } + auto peer = itr->second; + + //TODO How about removing handling webrtc_stream part from Room? + auto webrtc_stream = peer->GetWebRtcStream(); + webrtc_stream->Stop(); + webrtc_stream->Destroy(); + + return peers_.erase(peer_id) == 1; +} + +WebRtcPeer &WebRtcRoom::GetPeer(const std::string &peer_id) +{ + auto itr = peers_.find(peer_id); + if (itr == peers_.end()) + throw std::out_of_range("There's no such peer"); + + return *itr->second; +} + +void WebRtcRoom::ClearPeers(void) +{ + //TODO How about removing handling webrtc_stream part from Room? + for (auto pair : peers_) { + auto peer = pair.second; + auto webrtc_stream = peer->GetWebRtcStream(); + webrtc_stream->Stop(); + webrtc_stream->Destroy(); + } + peers_.clear(); +} diff --git a/modules/webrtc/WebRtcRoom.h b/modules/webrtc/WebRtcRoom.h new file mode 100644 index 0000000..fabeb1e --- /dev/null +++ b/modules/webrtc/WebRtcRoom.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <functional> +#include <map> +#include <memory> +#include <string> + +#include "WebRtcPeer.h" + +class WebRtcRoom { + public: + enum class State { + JOINNING, + JOINED, + }; + WebRtcRoom() = delete; + WebRtcRoom(const std::string &room_id) : id_(room_id){}; + ~WebRtcRoom(); + void setRoomState(State current) { state_ = current; } + State getRoomState(void) const { return state_; }; + void handleMessage(const std::string &msg); + bool AddPeer(const std::string &peer_id); + bool RemovePeer(const std::string &peer_id); + void ClearPeers(void); + // You need to handle out_of_range exception if there's no matching peer; + WebRtcPeer &GetPeer(const std::string &peer_id); + std::string getId(void) const { return id_; }; + void SetSourceId(const std::string &source_id) { source_id_ = source_id; }; + std::string GetSourceId(void) const { return source_id_; }; + + void SetRoomJoinedCb(std::function<void(void)> on_room_joined_cb) + { + on_room_joined_cb_ = on_room_joined_cb; + }; + void CallRoomJoinedCb(void) const + { + if (on_room_joined_cb_) + on_room_joined_cb_(); + }; + void UnsetRoomJoinedCb(void) { on_room_joined_cb_ = nullptr; }; + void SetPeerJoinedCb(std::function<void(const std::string &peer_id)> on_peer_joined_cb) + { + on_peer_joined_cb_ = on_peer_joined_cb; + }; + void CallPeerJoinedCb(const std::string &peer_id) const + { + if (on_peer_joined_cb_) + on_peer_joined_cb_(peer_id); + }; + void UnsetPeerJoinedCb(void) { on_peer_joined_cb_ = nullptr; }; + void SetPeerLeftCb(std::function<void(const std::string &peer_id)> on_peer_left_cb) + { + on_peer_left_cb_ = on_peer_left_cb; + }; + void CallPeerLeftCb(const std::string &peer_id) const + { + if (on_peer_left_cb_) + on_peer_left_cb_(peer_id); + }; + void UnsetPeerLeftCb(void) { on_peer_left_cb_ = nullptr; }; + + private: + void HandleRoomJoinedWithPeerList(const std::string &peer_list); + void HandlePeerMessage(const std::string &msg); + + private: + std::string id_; + std::string source_id_; + std::map<std::string, std::shared_ptr<WebRtcPeer>> peers_; + State state_; + std::function<void(void)> on_room_joined_cb_; + std::function<void(const std::string &peer_id)> on_peer_joined_cb_; + std::function<void(const std::string &peer_id)> on_peer_left_cb_; +}; diff --git a/modules/webrtc/WebRtcState.cc b/modules/webrtc/WebRtcState.cc new file mode 100644 index 0000000..437460d --- /dev/null +++ b/modules/webrtc/WebRtcState.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "WebRtcState.h" + +WebRtcState::Stream WebRtcState::ToStreamState(webrtc_state_e state) +{ + switch (state) { + case WEBRTC_STATE_IDLE: { + return Stream::IDLE; + } + case WEBRTC_STATE_NEGOTIATING: { + return Stream::NEGOTIATING; + } + case WEBRTC_STATE_PLAYING: { + return Stream::PLAYING; + } + } + return Stream::IDLE; +} + +std::string WebRtcState::StreamToStr(WebRtcState::Stream state) +{ + switch (state) { + case Stream::IDLE: { + return std::string("IDLE"); + } + case Stream::NEGOTIATING: { + return std::string("NEGOTIATING"); + } + case Stream::PLAYING: { + return std::string("PLAYING"); + } + } + return std::string(""); +} + +WebRtcState::Signaling WebRtcState::ToSignalingState(webrtc_signaling_state_e state) +{ + switch (state) { + case WEBRTC_SIGNALING_STATE_STABLE: { + return Signaling::STABLE; + } + case WEBRTC_SIGNALING_STATE_HAVE_LOCAL_OFFER: { + return Signaling::HAVE_LOCAL_OFFER; + } + case WEBRTC_SIGNALING_STATE_HAVE_REMOTE_OFFER: { + return Signaling::HAVE_REMOTE_OFFER; + } + case WEBRTC_SIGNALING_STATE_HAVE_LOCAL_PRANSWER: { + return Signaling::HAVE_LOCAL_PRANSWER; + } + case WEBRTC_SIGNALING_STATE_HAVE_REMOTE_PRANSWER: { + return Signaling::HAVE_REMOTE_PRANSWER; + } + case WEBRTC_SIGNALING_STATE_CLOSED: { + return Signaling::CLOSED; + } + } + return Signaling::STABLE; +} + +std::string WebRtcState::SignalingToStr(WebRtcState::Signaling state) +{ + switch (state) { + case (WebRtcState::Signaling::STABLE): { + return std::string("STABLE"); + } + case (WebRtcState::Signaling::CLOSED): { + return std::string("CLOSED"); + } + case (WebRtcState::Signaling::HAVE_LOCAL_OFFER): { + return std::string("HAVE_LOCAL_OFFER"); + } + case (WebRtcState::Signaling::HAVE_REMOTE_OFFER): { + return std::string("HAVE_REMOTE_OFFER"); + } + case (WebRtcState::Signaling::HAVE_LOCAL_PRANSWER): { + return std::string("HAVE_LOCAL_PRANSWER"); + } + case (WebRtcState::Signaling::HAVE_REMOTE_PRANSWER): { + return std::string("HAVE_REMOTE_PRANSWER"); + } + } + return std::string(""); +} + +WebRtcState::IceGathering WebRtcState::ToIceGatheringState(webrtc_ice_gathering_state_e state) +{ + switch (state) { + case WEBRTC_ICE_GATHERING_STATE_COMPLETE: { + return IceGathering::COMPLETE; + } + case WEBRTC_ICE_GATHERING_STATE_GATHERING: { + return IceGathering::GATHERING; + } + case WEBRTC_ICE_GATHERING_STATE_NEW: { + return IceGathering::NEW; + } + } + return IceGathering::NEW; +} + +std::string WebRtcState::IceGatheringToStr(WebRtcState::IceGathering state) +{ + switch (state) { + case (WebRtcState::IceGathering::NEW): { + return std::string("NEW"); + } + case (WebRtcState::IceGathering::GATHERING): { + return std::string("GATHERING"); + } + case (WebRtcState::IceGathering::COMPLETE): { + return std::string("COMPLETE"); + } + } + return std::string(""); +} + +WebRtcState::IceConnection WebRtcState::ToIceConnectionState(webrtc_ice_connection_state_e state) +{ + switch (state) { + case WEBRTC_ICE_CONNECTION_STATE_CHECKING: { + return IceConnection::CHECKING; + } + case WEBRTC_ICE_CONNECTION_STATE_CLOSED: { + return IceConnection::CLOSED; + } + case WEBRTC_ICE_CONNECTION_STATE_COMPLETED: { + return IceConnection::COMPLETED; + } + case WEBRTC_ICE_CONNECTION_STATE_CONNECTED: { + return IceConnection::CONNECTED; + } + case WEBRTC_ICE_CONNECTION_STATE_DISCONNECTED: { + return IceConnection::DISCONNECTED; + } + case WEBRTC_ICE_CONNECTION_STATE_FAILED: { + return IceConnection::FAILED; + } + case WEBRTC_ICE_CONNECTION_STATE_NEW: { + return IceConnection::NEW; + } + } + return IceConnection::NEW; +} + +std::string WebRtcState::IceConnectionToStr(WebRtcState::IceConnection state) +{ + switch (state) { + case (WebRtcState::IceConnection::NEW): { + return std::string("NEW"); + } + case (WebRtcState::IceConnection::CHECKING): { + return std::string("CHECKING"); + } + case (WebRtcState::IceConnection::CONNECTED): { + return std::string("CONNECTED"); + } + case (WebRtcState::IceConnection::COMPLETED): { + return std::string("COMPLETED"); + } + case (WebRtcState::IceConnection::FAILED): { + return std::string("FAILED"); + } + case (WebRtcState::IceConnection::DISCONNECTED): { + return std::string("DISCONNECTED"); + } + case (WebRtcState::IceConnection::CLOSED): { + return std::string("CLOSED"); + } + } + return std::string(""); +} diff --git a/modules/webrtc/WebRtcState.h b/modules/webrtc/WebRtcState.h new file mode 100644 index 0000000..c6ad8d0 --- /dev/null +++ b/modules/webrtc/WebRtcState.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <functional> +#include <string> + +#include <webrtc.h> + +class WebRtcState { + public: + enum class Stream { + IDLE, + NEGOTIATING, + PLAYING, + }; + + enum class PeerConnection { + NEW, + CONNECTING, + CONNECTED, + DISCONNECTED, + FAILED, + CLOSED, + }; + + enum class Signaling { + STABLE, + CLOSED, + HAVE_LOCAL_OFFER, + HAVE_REMOTE_OFFER, + HAVE_LOCAL_PRANSWER, + HAVE_REMOTE_PRANSWER, + }; + + enum class IceGathering { + NEW, + GATHERING, + COMPLETE, + }; + + enum class IceConnection { + NEW, + CHECKING, + CONNECTED, + COMPLETED, + FAILED, + DISCONNECTED, + CLOSED, + }; + + public: + static Stream ToStreamState(webrtc_state_e state); + static std::string StreamToStr(WebRtcState::Stream state); + static Signaling ToSignalingState(webrtc_signaling_state_e state); + static std::string SignalingToStr(WebRtcState::Signaling state); + static IceGathering ToIceGatheringState(webrtc_ice_gathering_state_e state); + static std::string IceGatheringToStr(WebRtcState::IceGathering state); + static IceConnection ToIceConnectionState(webrtc_ice_connection_state_e state); + static std::string IceConnectionToStr(WebRtcState::IceConnection state); +}; diff --git a/modules/webrtc/WebRtcStream.cc b/modules/webrtc/WebRtcStream.cc new file mode 100644 index 0000000..ef717b0 --- /dev/null +++ b/modules/webrtc/WebRtcStream.cc @@ -0,0 +1,414 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "WebRtcStream.h" + +#include "aitt_internal.h" + +WebRtcStream::~WebRtcStream() +{ + Destroy(); + DBG("%s", __func__); +} + +bool WebRtcStream::Create(bool is_source, bool need_display) +{ + if (webrtc_handle_) { + ERR("Already created %p", webrtc_handle_); + return false; + } + + auto ret = webrtc_create(&webrtc_handle_); + if (ret != WEBRTC_ERROR_NONE) { + ERR("Failed to create webrtc handle"); + return false; + } + AttachSignals(is_source, need_display); + + return true; +} + +void WebRtcStream::Destroy(void) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return; + } + auto stop_ret = webrtc_stop(webrtc_handle_); + if (stop_ret != WEBRTC_ERROR_NONE) + ERR("Failed to stop webrtc handle"); + + auto ret = webrtc_destroy(webrtc_handle_); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to destroy webrtc handle"); + webrtc_handle_ = nullptr; +} + +bool WebRtcStream::Start(void) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + if (camera_handler_) + camera_handler_->StartPreview(); + + auto ret = webrtc_start(webrtc_handle_); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to start webrtc handle"); + + return ret == WEBRTC_ERROR_NONE; +} + +bool WebRtcStream::Stop(void) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + if (camera_handler_) + camera_handler_->StopPreview(); + + auto ret = webrtc_stop(webrtc_handle_); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to stop webrtc handle"); + + return ret == WEBRTC_ERROR_NONE; +} + +bool WebRtcStream::AttachCameraSource(void) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + + if (source_id_) { + ERR("source already attached"); + return false; + } + + auto ret = + webrtc_add_media_source(webrtc_handle_, WEBRTC_MEDIA_SOURCE_TYPE_CAMERA, &source_id_); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to add media source"); + + return ret == WEBRTC_ERROR_NONE; +} + +bool WebRtcStream::AttachCameraPreviewSource(void) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + + if (source_id_) { + ERR("source already attached"); + return false; + } + + camera_handler_ = std::make_unique<CameraHandler>(); + camera_handler_->Init(OnMediaPacketPreview, this); + + auto ret = webrtc_add_media_source(webrtc_handle_, WEBRTC_MEDIA_SOURCE_TYPE_MEDIA_PACKET, + &source_id_); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to add media source"); + + return ret == WEBRTC_ERROR_NONE; +} + +void WebRtcStream::OnMediaPacketPreview(media_packet_h media_packet, void *user_data) +{ + ERR("%s", __func__); + auto webrtc_stream = static_cast<WebRtcStream *>(user_data); + RET_IF(webrtc_stream == nullptr); + + if (webrtc_stream->is_source_overflow_) { + return; + } + if (webrtc_media_packet_source_push_packet(webrtc_stream->webrtc_handle_, + webrtc_stream->source_id_, media_packet) + != WEBRTC_ERROR_NONE) { + media_packet_destroy(media_packet); + } +} + +bool WebRtcStream::DetachCameraSource(void) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + + if (!source_id_) { + ERR("Camera source is not attached"); + return false; + } + + camera_handler_ = nullptr; + + auto ret = webrtc_remove_media_source(webrtc_handle_, source_id_); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to remove media source"); + + return ret == WEBRTC_ERROR_NONE; +} + +void WebRtcStream::SetDisplayObject(unsigned int id, void *object) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return; + } + + if (!object) { + ERR("Object is not specified"); + return; + } + + webrtc_set_display(webrtc_handle_, id, WEBRTC_DISPLAY_TYPE_EVAS, object); +} + +bool WebRtcStream::CreateOfferAsync(std::function<void(std::string)> on_created_cb) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + on_offer_created_cb_ = on_created_cb; + auto ret = webrtc_create_offer_async(webrtc_handle_, NULL, OnOfferCreated, this); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to create offer async"); + + return ret == WEBRTC_ERROR_NONE; +} + +void WebRtcStream::OnOfferCreated(webrtc_h webrtc, const char *description, void *user_data) +{ + RET_IF(!user_data); + + WebRtcStream *webrtc_stream = static_cast<WebRtcStream *>(user_data); + + if (webrtc_stream->on_offer_created_cb_) + webrtc_stream->on_offer_created_cb_(std::string(description)); +} + +bool WebRtcStream::CreateAnswerAsync(std::function<void(std::string)> on_created_cb) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + on_answer_created_cb_ = on_created_cb; + auto ret = webrtc_create_answer_async(webrtc_handle_, NULL, OnAnswerCreated, this); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to create answer async"); + + return ret == WEBRTC_ERROR_NONE; +} + +void WebRtcStream::OnAnswerCreated(webrtc_h webrtc, const char *description, void *user_data) +{ + if (!user_data) + return; + + WebRtcStream *webrtc_stream = static_cast<WebRtcStream *>(user_data); + if (webrtc_stream->on_answer_created_cb_) + webrtc_stream->on_answer_created_cb_(std::string(description)); +} + +bool WebRtcStream::SetLocalDescription(const std::string &description) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + auto ret = webrtc_set_local_description(webrtc_handle_, description.c_str()); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to set local description"); + + return ret == WEBRTC_ERROR_NONE; +} + +bool WebRtcStream::SetRemoteDescription(const std::string &description) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + + webrtc_state_e state; + auto get_state_ret = webrtc_get_state(webrtc_handle_, &state); + if (get_state_ret != WEBRTC_ERROR_NONE) { + ERR("Failed to get state"); + return false; + } + + if (state != WEBRTC_STATE_NEGOTIATING) { + remote_description_ = description; + ERR("Invalid state, will be registred at NEGOTIATING state"); + return true; + } + + auto ret = webrtc_set_remote_description(webrtc_handle_, description.c_str()); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to set remote description"); + + return ret == WEBRTC_ERROR_NONE; +} + +bool WebRtcStream::AddIceCandidateFromMessage(const std::string &ice_message) +{ + ERR("%s", __func__); + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return false; + } + auto ret = webrtc_add_ice_candidate(webrtc_handle_, ice_message.c_str()); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to set add ice candidate"); + + return ret == WEBRTC_ERROR_NONE; +} + +void WebRtcStream::AttachSignals(bool is_source, bool need_display) +{ + if (!webrtc_handle_) { + ERR("WebRTC handle is not created"); + return; + } + + int ret = WEBRTC_ERROR_NONE; + // TODO: ADHOC TV profile doesn't show DBG level log + ret = webrtc_set_error_cb(webrtc_handle_, OnError, this); + DBG("webrtc_set_error_cb %s", ret == WEBRTC_ERROR_NONE ? "Succeeded" : "failed"); + ret = webrtc_set_state_changed_cb(webrtc_handle_, OnStateChanged, this); + DBG("webrtc_set_state_changed_cb %s", ret == WEBRTC_ERROR_NONE ? "Succeeded" : "failed"); + ret = webrtc_set_signaling_state_change_cb(webrtc_handle_, OnSignalingStateChanged, this); + DBG("webrtc_set_signaling_state_change_cb %s", + ret == WEBRTC_ERROR_NONE ? "Succeeded" : "failed"); + ret = webrtc_set_ice_connection_state_change_cb(webrtc_handle_, OnIceConnectionStateChanged, + this); + DBG("webrtc_set_ice_connection_state_change_cb %s", + ret == WEBRTC_ERROR_NONE ? "Succeeded" : "failed"); + ret = webrtc_set_ice_candidate_cb(webrtc_handle_, OnIceCandiate, this); + DBG("webrtc_set_ice_candidate_cb %s", ret == WEBRTC_ERROR_NONE ? "Succeeded" : "failed"); + + if (!is_source && !need_display) { + ret = webrtc_set_encoded_video_frame_cb(webrtc_handle_, OnEncodedFrame, this); + ERR("webrtc_set_encoded_video_frame_cb %s", + ret == WEBRTC_ERROR_NONE ? "Succeeded" : "failed"); + } + + if (!is_source && need_display) { + ret = webrtc_set_track_added_cb(webrtc_handle_, OnTrackAdded, this); + ERR("webrtc_set_track_added_cb %s", ret == WEBRTC_ERROR_NONE ? "Succeeded" : "failed"); + } + + ret = webrtc_media_packet_source_set_buffer_state_changed_cb(webrtc_handle_, source_id_, + OnMediaPacketBufferStateChanged, this); + DBG("webrtc_media_packet_source_set_buffer_state_changed_cb %s", + ret == WEBRTC_ERROR_NONE ? "Succeeded" : "failed"); + + return; +} + +void WebRtcStream::OnError(webrtc_h webrtc, webrtc_error_e error, webrtc_state_e state, + void *user_data) +{ + // TODO + ERR("%s", __func__); +} + +void WebRtcStream::OnStateChanged(webrtc_h webrtc, webrtc_state_e previous, webrtc_state_e current, + void *user_data) +{ + ERR("%s", __func__); + auto webrtc_stream = static_cast<WebRtcStream *>(user_data); + RET_IF(webrtc_stream == nullptr); + + if (current == WEBRTC_STATE_NEGOTIATING && webrtc_stream->remote_description_.size() != 0) { + ERR("received remote description exists"); + auto ret = webrtc_set_remote_description(webrtc_stream->webrtc_handle_, + webrtc_stream->remote_description_.c_str()); + if (ret != WEBRTC_ERROR_NONE) + ERR("Failed to set remote description"); + webrtc_stream->remote_description_ = std::string(); + } + webrtc_stream->GetEventHandler().CallOnStateChangedCb(WebRtcState::ToStreamState(current)); +} + +void WebRtcStream::OnSignalingStateChanged(webrtc_h webrtc, webrtc_signaling_state_e state, + void *user_data) +{ + ERR("%s", __func__); + auto webrtc_stream = static_cast<WebRtcStream *>(user_data); + RET_IF(webrtc_stream == nullptr); + webrtc_stream->GetEventHandler().CallOnSignalingStateNotifyCb( + WebRtcState::ToSignalingState(state)); +} + +void WebRtcStream::OnIceConnectionStateChanged(webrtc_h webrtc, webrtc_ice_connection_state_e state, + void *user_data) +{ + ERR("%s %d", __func__, state); + auto webrtc_stream = static_cast<WebRtcStream *>(user_data); + RET_IF(webrtc_stream == nullptr); + + webrtc_stream->GetEventHandler().CallOnIceConnectionStateNotifyCb( + WebRtcState::ToIceConnectionState(state)); +} + +void WebRtcStream::OnIceCandiate(webrtc_h webrtc, const char *candidate, void *user_data) +{ + ERR("%s", __func__); + auto webrtc_stream = static_cast<WebRtcStream *>(user_data); + webrtc_stream->ice_candidates_.push_back(candidate); +} + +void WebRtcStream::OnEncodedFrame(webrtc_h webrtc, webrtc_media_type_e type, unsigned int track_id, + media_packet_h packet, void *user_data) +{ + ERR("%s", __func__); + // TODO +} + +void WebRtcStream::OnTrackAdded(webrtc_h webrtc, webrtc_media_type_e type, unsigned int id, + void *user_data) +{ + // type AUDIO(0), VIDEO(1) + INFO("Added Track : id(%d), type(%s)", id, type ? "Video" : "Audio"); + + ERR("%s", __func__); + auto webrtc_stream = static_cast<WebRtcStream *>(user_data); + RET_IF(webrtc_stream == nullptr); + + if (type == WEBRTC_MEDIA_TYPE_VIDEO) + webrtc_stream->GetEventHandler().CallOnTrakAddedCb(id); +} + +void WebRtcStream::OnMediaPacketBufferStateChanged(unsigned int source_id, + webrtc_media_packet_source_buffer_state_e state, void *user_data) +{ + ERR("%s", __func__); + auto webrtc_stream = static_cast<WebRtcStream *>(user_data); + RET_IF(webrtc_stream == nullptr); + + webrtc_stream->is_source_overflow_ = + (state == WEBRTC_MEDIA_PACKET_SOURCE_BUFFER_STATE_OVERFLOW); +} diff --git a/modules/webrtc/WebRtcStream.h b/modules/webrtc/WebRtcStream.h new file mode 100644 index 0000000..755c1ae --- /dev/null +++ b/modules/webrtc/WebRtcStream.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <functional> +#include <list> +#include <memory> +#include <mutex> +#include <string> + +// TODO: webrtc.h is very heavy header file. +// I think we need to decide whether to include this or not +#include <webrtc.h> + +#include "CameraHandler.h" +#include "WebRtcEventHandler.h" + +class WebRtcStream { + public: + ~WebRtcStream(); + bool Create(bool is_source, bool need_display); + void Destroy(void); + bool Start(void); + bool Stop(void); + bool AttachCameraSource(void); + bool AttachCameraPreviewSource(void); + static void OnMediaPacketPreview(media_packet_h media_packet, void *user_data); + bool DetachCameraSource(void); + void SetDisplayObject(unsigned int id, void *object); + void AttachSignals(bool is_source, bool need_display); + // Cautions : Event handler is not a pointer. So, change event_handle after Set Event handler + // doesn't affect event handler which is included int WebRtcStream + void SetEventHandler(WebRtcEventHandler event_handler) { event_handler_ = event_handler; }; + WebRtcEventHandler &GetEventHandler(void) { return event_handler_; }; + + bool CreateOfferAsync(std::function<void(std::string)> on_created_cb); + void CallOnOfferCreatedCb(std::string offer) + { + if (on_offer_created_cb_) + on_offer_created_cb_(offer); + } + bool CreateAnswerAsync(std::function<void(std::string)> on_created_cb); + void CallOnAnswerCreatedCb(std::string answer) + { + if (on_answer_created_cb_) + on_answer_created_cb_(answer); + } + void SetPreparedLocalDescription(const std::string &description) + { + local_description_ = description; + }; + std::string GetPreparedLocalDescription(void) const { return local_description_; }; + + bool SetLocalDescription(const std::string &description); + bool SetRemoteDescription(const std::string &description); + + bool AddIceCandidateFromMessage(const std::string &ice_message); + const std::vector<std::string> &GetIceCandidates() const { return ice_candidates_; }; + + std::string GetRemoteDescription(void) const { return remote_description_; }; + + private: + static void OnOfferCreated(webrtc_h webrtc, const char *description, void *user_data); + static void OnAnswerCreated(webrtc_h webrtc, const char *description, void *user_data); + static void OnError(webrtc_h webrtc, webrtc_error_e error, webrtc_state_e state, + void *user_data); + static void OnStateChanged(webrtc_h webrtc, webrtc_state_e previous, webrtc_state_e current, + void *user_data); + static void OnSignalingStateChanged(webrtc_h webrtc, webrtc_signaling_state_e state, + void *user_data); + static void OnIceConnectionStateChanged(webrtc_h webrtc, webrtc_ice_connection_state_e state, + void *user_data); + static void OnIceCandiate(webrtc_h webrtc, const char *candidate, void *user_data); + static void OnEncodedFrame(webrtc_h webrtc, webrtc_media_type_e type, unsigned int track_id, + media_packet_h packet, void *user_data); + static void OnTrackAdded(webrtc_h webrtc, webrtc_media_type_e type, unsigned int id, + void *user_data); + static void OnMediaPacketBufferStateChanged(unsigned int source_id, + webrtc_media_packet_source_buffer_state_e state, void *user_data); + + private: + webrtc_h webrtc_handle_; + std::shared_ptr<CameraHandler> camera_handler_; + // DO we need to make is_source_overflow_ as atomic? + bool is_source_overflow_; + unsigned int source_id_; + std::string local_description_; + std::string remote_description_; + std::vector<std::string> ice_candidates_; + std::function<void(std::string)> on_offer_created_cb_; + std::function<void(std::string)> on_answer_created_cb_; + WebRtcEventHandler event_handler_; +}; diff --git a/modules/webrtc/tests/CMakeLists.txt b/modules/webrtc/tests/CMakeLists.txt new file mode 100644 index 0000000..a1dd90f --- /dev/null +++ b/modules/webrtc/tests/CMakeLists.txt @@ -0,0 +1,21 @@ +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/..) + +PKG_CHECK_MODULES(UT_NEEDS REQUIRED gmock_main) +INCLUDE_DIRECTORIES(${UT_NEEDS_INCLUDE_DIRS}) +LINK_DIRECTORIES(${UT_NEEDS_LIBRARY_DIRS}) + +SET(AITT_WEBRTC_UT ${PROJECT_NAME}_webrtc_ut) +SET(AITT_WEBRTC_UT_SRC WEBRTC_test.cc) + +ADD_EXECUTABLE(${AITT_WEBRTC_UT} ${AITT_WEBRTC_UT_SRC} $<TARGET_OBJECTS:WEBRTC_OBJ>) +TARGET_LINK_LIBRARIES(${AITT_WEBRTC_UT} ${UT_NEEDS_LIBRARIES} ${AITT_WEBRTC_NEEDS_LIBRARIES} ${AITT_COMMON}) +INSTALL(TARGETS ${AITT_WEBRTC_UT} DESTINATION ${AITT_TEST_BINDIR}) + +ADD_TEST( + NAME + ${AITT_WEBRTC_UT} + COMMAND + ${CMAKE_COMMAND} -E env + LD_LIBRARY_PATH=../../../common:$ENV{LD_LIBRARY_PATH} + ${CMAKE_CURRENT_BINARY_DIR}/${AITT_WEBRTC_UT} --gtest_filter=*_Anytime +) diff --git a/modules/webrtc/tests/MockPublishStream.h b/modules/webrtc/tests/MockPublishStream.h new file mode 100644 index 0000000..ced285e --- /dev/null +++ b/modules/webrtc/tests/MockPublishStream.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <memory> +#include <string> + +#include "Config.h" +#include "MqttServer.h" +#include "WebRtcRoom.h" + +class MockPublishStream { + // TODO: Notify & get status + public: + MockPublishStream() = delete; + MockPublishStream(const std::string &topic, const Config &config) + : topic_(topic), config_(config), server_(std::make_shared<MqttServer>(config)) + { + config.SetSourceId(config.GetLocalId()); + config.SetRoomId(config::MQTT_ROOM_PREFIX() + topic); + room_ = std::make_shared<WebRtcRoom>(config.GetRoomId()); + }; + ~MockPublishStream(); + + void Start(void) + { + SetSignalingServerCallbacks(); + SetRoomCallbacks(); + server_->Connect(); + } + void Stop(void){ + // TODO + }; + + private: + void SetSignalingServerCallbacks(void) + { + auto on_signaling_server_connection_state_changed = + std::bind(OnSignalingServerConnectionStateChanged, std::placeholders::_1, + std::ref(*room_), std::ref(*server_)); + + server_->SetConnectionStateChangedCb(on_signaling_server_connection_state_changed); + + auto on_room_message_arrived = + std::bind(OnRoomMessageArrived, std::placeholders::_1, std::ref(*room_)); + + server_->SetRoomMessageArrivedCb(on_room_message_arrived); + }; + + void SetRoomCallbacks(void) + { + auto on_peer_joined = + std::bind(OnPeerJoined, std::placeholders::_1, std::ref(*server_), std::ref(*room_)); + room_->SetPeerJoinedCb(on_peer_joined); + + auto on_peer_left = std::bind(OnPeerLeft, std::placeholders::_1, std::ref(*room_)); + room_->SetPeerLeftCb(on_peer_left); + }; + + static void OnSignalingServerConnectionStateChanged(IfaceServer::ConnectionState state, + WebRtcRoom &room, MqttServer &server) + { + DBG("current state [%s]", SignalingServer::GetConnectionStateStr(state).c_str()); + + if (state == IfaceServer::ConnectionState::Registered) + server.JoinRoom(room.getId()); + }; + + static void OnRoomMessageArrived(const std::string &message, WebRtcRoom &room) + { + room.handleMessage(message); + }; + + static void OnPeerJoined(const std::string &peer_id, MqttServer &server, WebRtcRoom &room) + { + DBG("%s [%s]", __func__, peer_id.c_str()); + }; + + static void OnPeerLeft(const std::string &peer_id, WebRtcRoom &room) + { + DBG("%s [%s]", __func__, peer_id.c_str()); + }; + + private: + std::string topic_; + config config_; + std::shared_ptr<MqttServer> server_; + std::shared_ptr<WebRtcRoom> room_; +}; diff --git a/modules/webrtc/tests/MockSubscribeStream.h b/modules/webrtc/tests/MockSubscribeStream.h new file mode 100644 index 0000000..ea2736d --- /dev/null +++ b/modules/webrtc/tests/MockSubscribeStream.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <memory> + +#include "Config.h" +#include "MqttServer.h" + +class MockSubscribeStream { + public: + MockSubscribeStream() = delete; + MockSubscribeStream(const std::string &topic, const Config &config) + : topic_(topic), config_(config), server_(std::make_shared<SignalingServer>(config)) + { + config.SetRoomId(config::MQTT_ROOM_PREFIX() + topic); + room_ = std::make_shared<WebRtcRoom>(config.GetRoomId()); + }; + ~MockSubscribeStream(){}; + + void Start(void) + { + SetSignalingServerCallbacks(); + SetRoomCallbacks(); + server_->Connect(); + }; + + void Stop(void){ + //TODO + }; + + private: + void SetSignalingServerCallbacks(void) + { + auto on_signaling_server_connection_state_changed = + std::bind(OnSignalingServerConnectionStateChanged, std::placeholders::_1, + std::ref(*room_), std::ref(*server_)); + + server_->SetConnectionStateChangedCb(on_signaling_server_connection_state_changed); + + auto on_room_message_arrived = + std::bind(OnRoomMessageArrived, std::placeholders::_1, std::ref(*room_)); + + server_->SetRoomMessageArrivedCb(on_room_message_arrived); + }; + + void SetRoomCallbacks(void) + { + auto on_peer_joined = + std::bind(OnPeerJoined, std::placeholders::_1, std::ref(*server_), std::ref(*room_)); + room_->SetPeerJoinedCb(on_peer_joined); + + auto on_peer_left = std::bind(OnPeerLeft, std::placeholders::_1, std::ref(*room_)); + room_->SetPeerLeftCb(on_peer_left); + }; + + static void OnSignalingServerConnectionStateChanged(IfaceServer::ConnectionState state, + WebRtcRoom &room, MqttServer &server) + { + DBG("current state [%s]", SignalingServer::GetConnectionStateStr(state).c_str()); + + if (state == IfaceServer::ConnectionState::Registered) + server.JoinRoom(room.getId()); + }; + + static void OnRoomMessageArrived(const std::string &message, WebRtcRoom &room) + { + room.handleMessage(message); + }; + + static void OnPeerJoined(const std::string &peer_id, MqttServer &server, WebRtcRoom &room) + { + DBG("%s [%s]", __func__, peer_id.c_str()); + }; + + static void OnPeerLeft(const std::string &peer_id, WebRtcRoom &room) + { + DBG("%s [%s]", __func__, peer_id.c_str()); + }; + + private: + std::string topic_; + Config config_; + std::shared_ptr<MqttServer> server_; + std::shared_ptr<WebRtcRoom> room_; +}; diff --git a/modules/webrtc/tests/WEBRTC_test.cc b/modules/webrtc/tests/WEBRTC_test.cc new file mode 100644 index 0000000..5ee42c1 --- /dev/null +++ b/modules/webrtc/tests/WEBRTC_test.cc @@ -0,0 +1,598 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <glib.h> +#include <gtest/gtest.h> + +#include <chrono> +#include <set> +#include <thread> + +#include "AITTEx.h" +#include "Config.h" +#include "MqttServer.h" +#include "aitt_internal.h" + +#define DEFAULT_BROKER_IP "127.0.0.1" +#define DEFAULT_BROKER_PORT 1883 + +#define DEFAULT_WEBRTC_SRC_ID "webrtc_src" +#define DEFAULT_FIRST_SINK_ID "webrtc_first_sink" +#define DEFAULT_SECOND_SINK_ID "webrtc_second_sink" +#define DEFAULT_ROOM_ID AITT_MANAGED_TOPIC_PREFIX "webrtc/room/Room.webrtc" + +class MqttServerTest : public testing::Test { + protected: + void SetUp() override + { + webrtc_src_config_ = Config(DEFAULT_WEBRTC_SRC_ID, DEFAULT_BROKER_IP, DEFAULT_BROKER_PORT, + DEFAULT_ROOM_ID, DEFAULT_WEBRTC_SRC_ID); + webrtc_first_sink_config_ = Config(DEFAULT_FIRST_SINK_ID, DEFAULT_BROKER_IP, + DEFAULT_BROKER_PORT, DEFAULT_ROOM_ID); + webrtc_second_sink_config_ = Config(DEFAULT_SECOND_SINK_ID, DEFAULT_BROKER_IP, + DEFAULT_BROKER_PORT, DEFAULT_ROOM_ID); + + loop_ = g_main_loop_new(nullptr, FALSE); + } + + void TearDown() override { g_main_loop_unref(loop_); } + + protected: + Config webrtc_src_config_; + Config webrtc_first_sink_config_; + Config webrtc_second_sink_config_; + GMainLoop *loop_; +}; +static void onConnectionStateChanged(IfaceServer::ConnectionState state, MqttServer &server, + GMainLoop *loop) +{ + if (state == IfaceServer::ConnectionState::Registered) { + EXPECT_EQ(server.IsConnected(), true) << "should return Connected"; + g_main_loop_quit(loop); + } +} + +TEST_F(MqttServerTest, Positive_Connect_Anytime) +{ + try { + MqttServer server(webrtc_src_config_); + EXPECT_EQ(server.IsConnected(), false) << "Should return not connected"; + + auto on_connection_state_changed = + std::bind(onConnectionStateChanged, std::placeholders::_1, std::ref(server), loop_); + server.SetConnectionStateChangedCb(on_connection_state_changed); + + server.Connect(); + + g_main_loop_run(loop_); + + server.UnsetConnectionStateChangedCb(); + server.Disconnect(); + } catch (...) { + FAIL() << "Expected No throw"; + } +} +static int Positive_Connect_Src_Sinks_Anytime_connect_count; +static void onConnectionStateChangedPositive_Connect_Src_Sinks_Anytime( + IfaceServer::ConnectionState state, MqttServer &server, GMainLoop *loop) +{ + if (state == IfaceServer::ConnectionState::Registered) { + EXPECT_EQ(server.IsConnected(), true) << "should return Connected"; + ++Positive_Connect_Src_Sinks_Anytime_connect_count; + if (Positive_Connect_Src_Sinks_Anytime_connect_count == 3) { + g_main_loop_quit(loop); + } + } +} + +TEST_F(MqttServerTest, Positive_Connect_Src_Sinks_Anytime) +{ + try { + Positive_Connect_Src_Sinks_Anytime_connect_count = 0; + MqttServer src_server(webrtc_src_config_); + EXPECT_EQ(src_server.IsConnected(), false) << "Should return not connected"; + + auto on_src_connection_state_changed = + std::bind(onConnectionStateChangedPositive_Connect_Src_Sinks_Anytime, + std::placeholders::_1, std::ref(src_server), loop_); + src_server.SetConnectionStateChangedCb(on_src_connection_state_changed); + + src_server.Connect(); + + MqttServer first_sink_server(webrtc_first_sink_config_); + EXPECT_EQ(first_sink_server.IsConnected(), false) << "Should return not connected"; + + auto on_first_sink_connection_state_changed = + std::bind(onConnectionStateChangedPositive_Connect_Src_Sinks_Anytime, + std::placeholders::_1, std::ref(first_sink_server), loop_); + first_sink_server.SetConnectionStateChangedCb(on_first_sink_connection_state_changed); + + first_sink_server.Connect(); + + MqttServer second_sink_server(webrtc_second_sink_config_); + EXPECT_EQ(second_sink_server.IsConnected(), false) << "Should return not connected"; + + auto on_second_sink_connection_state_changed = + std::bind(onConnectionStateChangedPositive_Connect_Src_Sinks_Anytime, + std::placeholders::_1, std::ref(second_sink_server), loop_); + second_sink_server.SetConnectionStateChangedCb(on_second_sink_connection_state_changed); + + second_sink_server.Connect(); + + g_main_loop_run(loop_); + + src_server.UnsetConnectionStateChangedCb(); + first_sink_server.UnsetConnectionStateChangedCb(); + second_sink_server.UnsetConnectionStateChangedCb(); + src_server.Disconnect(); + first_sink_server.Disconnect(); + second_sink_server.Disconnect(); + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +TEST_F(MqttServerTest, Negative_Disconnect_Anytime) +{ + EXPECT_THROW( + { + try { + MqttServer server(webrtc_src_config_); + EXPECT_EQ(server.IsConnected(), false) << "Should return not connected"; + + server.Disconnect(); + + g_main_loop_run(loop_); + } catch (const aitt::AITTEx &e) { + // and this tests that it has the correct message + throw; + } + }, + aitt::AITTEx); +} + +TEST_F(MqttServerTest, Positive_Disconnect_Anytime) +{ + try { + MqttServer server(webrtc_src_config_); + EXPECT_EQ(server.IsConnected(), false); + + auto on_connection_state_changed = + std::bind(onConnectionStateChanged, std::placeholders::_1, std::ref(server), loop_); + server.SetConnectionStateChangedCb(on_connection_state_changed); + + server.Connect(); + + g_main_loop_run(loop_); + + server.UnsetConnectionStateChangedCb(); + server.Disconnect(); + + EXPECT_EQ(server.IsConnected(), false) << "Should return not connected"; + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +TEST_F(MqttServerTest, Negative_Register_Anytime) +{ + EXPECT_THROW( + { + try { + MqttServer server(webrtc_src_config_); + EXPECT_EQ(server.IsConnected(), false) << "Should return not connected"; + + server.RegisterWithServer(); + } catch (const std::runtime_error &e) { + // and this tests that it has the correct message + throw; + } + }, + std::runtime_error); +} + +TEST_F(MqttServerTest, Negative_JoinRoom_Invalid_Parameter_Anytime) +{ + EXPECT_THROW( + { + try { + MqttServer server(webrtc_src_config_); + EXPECT_EQ(server.IsConnected(), false) << "Should return not connected"; + + server.JoinRoom(std::string("InvalidRoomId")); + + } catch (const std::runtime_error &e) { + // and this tests that it has the correct message + throw; + } + }, + std::runtime_error); +} + +static void joinRoomOnRegisteredQuit(IfaceServer::ConnectionState state, MqttServer &server, + GMainLoop *loop) +{ + if (state != IfaceServer::ConnectionState::Registered) { + return; + } + + EXPECT_EQ(server.IsConnected(), true) << "should return Connected"; + try { + server.JoinRoom(DEFAULT_ROOM_ID); + g_main_loop_quit(loop); + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +TEST_F(MqttServerTest, Positive_JoinRoom_Anytime) +{ + try { + MqttServer server(webrtc_src_config_); + EXPECT_EQ(server.IsConnected(), false) << "Should return not connected"; + + auto join_room_on_registered = + std::bind(joinRoomOnRegisteredQuit, std::placeholders::_1, std::ref(server), loop_); + server.SetConnectionStateChangedCb(join_room_on_registered); + + server.Connect(); + + g_main_loop_run(loop_); + + server.UnsetConnectionStateChangedCb(); + server.Disconnect(); + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +static void joinRoomOnRegistered(IfaceServer::ConnectionState state, MqttServer &server) +{ + if (state != IfaceServer::ConnectionState::Registered) { + return; + } + + EXPECT_EQ(server.IsConnected(), true) << "should return Connected"; + try { + server.JoinRoom(DEFAULT_ROOM_ID); + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +static void onSrcMessage(const std::string &msg, MqttServer &server, GMainLoop *loop) +{ + if (msg.compare(0, 16, "ROOM_PEER_JOINED") == 0) { + std::string peer_id = msg.substr(17, std::string::npos); + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_FIRST_SINK_ID)), 0) + << "Not expected peer" << peer_id; + + } else if (msg.compare(0, 14, "ROOM_PEER_LEFT") == 0) { + std::string peer_id = msg.substr(15, std::string::npos); + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_FIRST_SINK_ID)), 0) + << "Not expected peer" << peer_id; + g_main_loop_quit(loop); + } else { + FAIL() << "Invalid type of Room message " << msg; + } +} + +static void onSinkMessage(const std::string &msg, MqttServer &server, GMainLoop *loop) +{ + if (msg.compare(0, 16, "ROOM_PEER_JOINED") == 0) { + std::string peer_id = msg.substr(17, std::string::npos); + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_WEBRTC_SRC_ID)), 0) + << "Not expected peer" << peer_id; + server.Disconnect(); + } else { + FAIL() << "Invalid type of Room message " << msg; + } +} + +TEST_F(MqttServerTest, Positive_src_sink) +{ + try { + MqttServer src_server(webrtc_src_config_); + auto join_room_on_registered_src = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(src_server)); + src_server.SetConnectionStateChangedCb(join_room_on_registered_src); + + auto on_src_message = + std::bind(onSrcMessage, std::placeholders::_1, std::ref(src_server), loop_); + src_server.SetRoomMessageArrivedCb(on_src_message); + src_server.Connect(); + + MqttServer sink_server(webrtc_first_sink_config_); + auto join_room_on_registered_sink = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(sink_server)); + sink_server.SetConnectionStateChangedCb(join_room_on_registered_sink); + + auto on_sink_message = + std::bind(onSinkMessage, std::placeholders::_1, std::ref(sink_server), loop_); + sink_server.SetRoomMessageArrivedCb(on_sink_message); + + sink_server.Connect(); + + g_main_loop_run(loop_); + + src_server.UnsetConnectionStateChangedCb(); + sink_server.UnsetConnectionStateChangedCb(); + src_server.Disconnect(); + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +TEST_F(MqttServerTest, Positive_sink_src) +{ + try { + MqttServer sink_server(webrtc_first_sink_config_); + auto join_room_on_registered_sink = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(sink_server)); + sink_server.SetConnectionStateChangedCb(join_room_on_registered_sink); + + auto on_sink_message = + std::bind(onSinkMessage, std::placeholders::_1, std::ref(sink_server), loop_); + sink_server.SetRoomMessageArrivedCb(on_sink_message); + + sink_server.Connect(); + + MqttServer src_server(webrtc_src_config_); + auto join_room_on_registered_src = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(src_server)); + src_server.SetConnectionStateChangedCb(join_room_on_registered_src); + + auto on_src_message = + std::bind(onSrcMessage, std::placeholders::_1, std::ref(src_server), loop_); + src_server.SetRoomMessageArrivedCb(on_src_message); + src_server.Connect(); + + g_main_loop_run(loop_); + + src_server.UnsetConnectionStateChangedCb(); + sink_server.UnsetConnectionStateChangedCb(); + src_server.Disconnect(); + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +static void onSrcMessageDisconnect(const std::string &msg, MqttServer &server, GMainLoop *loop) +{ + if (msg.compare(0, 16, "ROOM_PEER_JOINED") == 0) { + std::string peer_id = msg.substr(17, std::string::npos); + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_FIRST_SINK_ID)), 0) + << "Not expected peer" << peer_id; + server.Disconnect(); + + } else { + FAIL() << "Invalid type of Room message " << msg; + } +} + +static void onSinkMessageDisconnect(const std::string &msg, MqttServer &server, GMainLoop *loop) +{ + if (msg.compare(0, 16, "ROOM_PEER_JOINED") == 0) { + std::string peer_id = msg.substr(17, std::string::npos); + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_WEBRTC_SRC_ID)), 0) + << "Not expected peer" << peer_id; + } else if (msg.compare(0, 14, "ROOM_PEER_LEFT") == 0) { + std::string peer_id = msg.substr(15, std::string::npos); + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_WEBRTC_SRC_ID)), 0) + << "Not expected peer" << peer_id; + g_main_loop_quit(loop); + } else { + FAIL() << "Invalid type of Room message " << msg; + } +} + +TEST_F(MqttServerTest, Positive_src_sink_disconnect_src_first_Anytime) +{ + try { + MqttServer src_server(webrtc_src_config_); + auto join_room_on_registered_src = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(src_server)); + src_server.SetConnectionStateChangedCb(join_room_on_registered_src); + + auto on_src_message = + std::bind(onSrcMessageDisconnect, std::placeholders::_1, std::ref(src_server), loop_); + src_server.SetRoomMessageArrivedCb(on_src_message); + src_server.Connect(); + + MqttServer sink_server(webrtc_first_sink_config_); + auto join_room_on_registered_sink = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(sink_server)); + sink_server.SetConnectionStateChangedCb(join_room_on_registered_sink); + + auto on_sink_message = std::bind(onSinkMessageDisconnect, std::placeholders::_1, + std::ref(sink_server), loop_); + sink_server.SetRoomMessageArrivedCb(on_sink_message); + + sink_server.Connect(); + + g_main_loop_run(loop_); + + src_server.UnsetConnectionStateChangedCb(); + sink_server.UnsetConnectionStateChangedCb(); + sink_server.Disconnect(); + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +TEST_F(MqttServerTest, Positive_sink_src_disconnect_src_first_Anytime) +{ + try { + MqttServer sink_server(webrtc_first_sink_config_); + auto join_room_on_registered_sink = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(sink_server)); + sink_server.SetConnectionStateChangedCb(join_room_on_registered_sink); + + auto on_sink_message = std::bind(onSinkMessageDisconnect, std::placeholders::_1, + std::ref(sink_server), loop_); + sink_server.SetRoomMessageArrivedCb(on_sink_message); + + sink_server.Connect(); + + MqttServer src_server(webrtc_src_config_); + auto join_room_on_registered_src = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(src_server)); + src_server.SetConnectionStateChangedCb(join_room_on_registered_src); + + auto on_src_message = + std::bind(onSrcMessageDisconnect, std::placeholders::_1, std::ref(src_server), loop_); + src_server.SetRoomMessageArrivedCb(on_src_message); + src_server.Connect(); + + g_main_loop_run(loop_); + + src_server.UnsetConnectionStateChangedCb(); + sink_server.UnsetConnectionStateChangedCb(); + sink_server.Disconnect(); + } catch (...) { + FAIL() << "Expected No throw"; + } +} + +static int handled_sink; +static int expected_sink; + +std::set<std::string> sink_set; + +static void onSrcMessageThreeWay(const std::string &msg, MqttServer &server, GMainLoop *loop) +{ + if (msg.compare(0, 16, "ROOM_PEER_JOINED") == 0) { + auto peer_id = msg.substr(17, std::string::npos); + sink_set.insert(peer_id); + server.SendMessage(peer_id, "Three"); + + } else if (msg.compare(0, 14, "ROOM_PEER_LEFT") == 0) { + auto peer_id = msg.substr(15, std::string::npos); + + if (sink_set.find(peer_id) != sink_set.end()) + sink_set.erase(peer_id); + + if (sink_set.size() == 0 && handled_sink == expected_sink) + g_main_loop_quit(loop); + + } else if (msg.compare(0, 13, "ROOM_PEER_MSG") == 0) { + auto peer_msg = msg.substr(14, std::string::npos); + std::size_t pos = peer_msg.find(' '); + if (pos == std::string::npos) + FAIL() << "Invalid type of peer message" << msg; + + auto peer_id = peer_msg.substr(0, pos); + auto received_msg = peer_msg.substr(pos + 1, std::string::npos); + + if (received_msg.compare("Way") == 0) { + server.SendMessage(peer_id, "HandShake"); + ++handled_sink; + } else + FAIL() << "Can't understand message" << received_msg; + + } else { + FAIL() << "Invalid type of Room message " << msg; + } +} + +static void onSinkMessageThreeWay(const std::string &msg, MqttServer &server) +{ + if (msg.compare(0, 16, "ROOM_PEER_JOINED") == 0) { + auto peer_id = msg.substr(17, std::string::npos); + + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_WEBRTC_SRC_ID)), 0) + << "Not expected peer" << peer_id; + + } else if (msg.compare(0, 14, "ROOM_PEER_LEFT") == 0) { + auto peer_id = msg.substr(15, std::string::npos); + + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_WEBRTC_SRC_ID)), 0) + << "Not expected peer" << peer_id; + + server.Disconnect(); + + } else if (msg.compare(0, 13, "ROOM_PEER_MSG") == 0) { + auto peer_msg = msg.substr(14, std::string::npos); + std::size_t pos = peer_msg.find(' '); + if (pos == std::string::npos) + FAIL() << "Invalid type of peer message" << msg; + + auto peer_id = peer_msg.substr(0, pos); + auto received_msg = peer_msg.substr(pos + 1, std::string::npos); + + EXPECT_EQ(peer_id.compare(std::string(DEFAULT_WEBRTC_SRC_ID)), 0) + << "Not expected peer " << peer_id; + + if (received_msg.compare("Three") == 0) + server.SendMessage(peer_id, "Way"); + else if (received_msg.compare("HandShake") == 0) + server.Disconnect(); + else + FAIL() << "Can't understand message" << received_msg; + } else { + FAIL() << "Invalid type of Room message " << msg; + } +} + +TEST_F(MqttServerTest, Positive_SendMessageThreeWay_Src_Sinks1_Anytime) +{ + try { + handled_sink = 0; + expected_sink = 2; + MqttServer src_server(webrtc_src_config_); + + auto join_room_on_registered_src = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(src_server)); + src_server.SetConnectionStateChangedCb(join_room_on_registered_src); + + auto on_src_message = + std::bind(onSrcMessageThreeWay, std::placeholders::_1, std::ref(src_server), loop_); + src_server.SetRoomMessageArrivedCb(on_src_message); + src_server.Connect(); + + MqttServer first_sink_server(webrtc_first_sink_config_); + + auto join_room_on_registered_first_sink = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(first_sink_server)); + first_sink_server.SetConnectionStateChangedCb(join_room_on_registered_first_sink); + + auto on_first_sink_message = + std::bind(onSinkMessageThreeWay, std::placeholders::_1, std::ref(first_sink_server)); + first_sink_server.SetRoomMessageArrivedCb(on_first_sink_message); + first_sink_server.Connect(); + + MqttServer second_sink_server(webrtc_second_sink_config_); + + auto join_room_on_registered_second_sink = + std::bind(joinRoomOnRegistered, std::placeholders::_1, std::ref(second_sink_server)); + second_sink_server.SetConnectionStateChangedCb(join_room_on_registered_second_sink); + + auto on_second_sink_message = + std::bind(onSinkMessageThreeWay, std::placeholders::_1, std::ref(second_sink_server)); + second_sink_server.SetRoomMessageArrivedCb(on_second_sink_message); + + second_sink_server.Connect(); + + g_main_loop_run(loop_); + + src_server.UnsetConnectionStateChangedCb(); + first_sink_server.UnsetConnectionStateChangedCb(); + second_sink_server.UnsetConnectionStateChangedCb(); + src_server.Disconnect(); + } catch (...) { + FAIL() << "Expected No throw"; + } +} |