modules: add thrift module

Add glue code for the thrift module. This includes:
* workarounds for Zephyr's missing C++ facilities
* thrift config.h

This code was merged from the following repository
at the commit specified below, with minor formatting
and coding-style modifications.

https://github.com/zephyrproject-rtos/gsoc-2022-thrift
e12e014d295918cc5ba0b4c507d1bf595a2f539a

Signed-off-by: Chris Friedt <cfriedt@meta.com>
This commit is contained in:
Chris Friedt 2023-01-23 07:45:25 -05:00 committed by Stephanos Ioannidis
parent c5ff1bebe2
commit 0c00a3ea79
21 changed files with 3559 additions and 0 deletions

View file

@ -36,6 +36,7 @@ source "modules/Kconfig.st"
source "modules/Kconfig.stm32"
source "modules/Kconfig.syst"
source "modules/Kconfig.telink"
source "modules/thrift/Kconfig"
source "modules/Kconfig.tinycrypt"
source "modules/Kconfig.vega"
source "modules/Kconfig.wurthelektronik"
@ -95,6 +96,9 @@ comment "zcbor module not available."
comment "CHRE module not available."
depends on !ZEPHYR_CHRE_MODULE
comment "THRIFT module not available."
depends on !ZEPHYR_THRIFT_MODULE
# This ensures that symbols are available in Kconfig for dependency checking
# and referencing, while keeping the settings themselves unavailable when the
# modules are not present in the workspace

View file

@ -0,0 +1,43 @@
# Copyright 2022 Meta
# SPDX-License-Identifier: Apache-2.0
if(CONFIG_THRIFT)
set(THRIFT_UPSTREAM ${ZEPHYR_THRIFT_MODULE_DIR})
zephyr_library()
zephyr_include_directories(src)
zephyr_include_directories(include)
zephyr_include_directories(${THRIFT_UPSTREAM}/lib/cpp/src)
zephyr_library_sources(
src/_stat.c
src/thrift/server/TFDServer.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/protocol/TProtocol.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/server/TConnectedClient.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/server/TSimpleServer.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/SocketCommon.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TBufferTransports.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TFDTransport.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TTransportException.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TServerSocket.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TSocket.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/TApplicationException.cpp
${THRIFT_UPSTREAM}/lib/cpp/src/thrift/TOutput.cpp
# Replace with upstream equivalents when Zephyr's std::thread, etc, are fixed
src/thrift/concurrency/Mutex.cpp
src/thrift/server/TServerFramework.cpp
)
zephyr_library_sources_ifdef(CONFIG_THRIFT_SSL_SOCKET
# Replace with upstream equivalents when Zephyr's std::thread, etc, are fixed
src/thrift/transport/TSSLSocket.cpp
src/thrift/transport/TSSLServerSocket.cpp
)
# needed because std::iterator was deprecated with -std=c++17
zephyr_library_compile_options(-Wno-deprecated-declarations)
endif(CONFIG_THRIFT)

31
modules/thrift/Kconfig Normal file
View file

@ -0,0 +1,31 @@
# Copyright 2022 Meta
# SPDX-License-Identifier: Apache-2.0
config ZEPHYR_THRIFT_MODULE
bool
menuconfig THRIFT
bool "Support for Thrift [EXPERIMENTAL]"
select EXPERIMENTAL
depends on CPP
depends on STD_CPP17
depends on CPP_EXCEPTIONS
depends on POSIX_API
help
Enable this option to support Apache Thrift
if THRIFT
config THRIFT_SSL_SOCKET
bool "TSSLSocket support for Thrift"
depends on MBEDTLS
depends on MBEDTLS_PEM_CERTIFICATE_FORMAT
depends on NET_SOCKETS_SOCKOPT_TLS
help
Enable this option to support TSSLSocket for Thrift
module = THRIFT
module-str = THRIFT
source "subsys/logging/Kconfig.template.log_config"
endif # THRIFT

View file

@ -0,0 +1,33 @@
# Copyright 2022 Meta
# SPDX-License-Identifier: Apache-2.0
find_program(THRIFT_EXECUTABLE thrift)
if(NOT THRIFT_EXECUTABLE)
message(FATAL_ERROR "The 'thrift' command was not found")
endif()
function(thrift
target # CMake target (for dependencies / headers)
lang # The language for generated sources
lang_opts # Language options (e.g. ':no_skeleton')
out_dir # Output directory for generated files
# (do not include 'gen-cpp', etc)
source_file # The .thrift source file
options # Additional thrift options
# Generated files in ${ARGN}
)
file(MAKE_DIRECTORY ${out_dir})
add_custom_command(
OUTPUT ${ARGN}
COMMAND
${THRIFT_EXECUTABLE}
--gen ${lang}${lang_opts}
-o ${out_dir}
${source_file}
${options}
DEPENDS ${source_file}
)
target_include_directories(${target} PRIVATE ${out_dir}/gen-${lang})
endfunction()

View file

@ -0,0 +1,20 @@
/*
* Copyright 2022 Meta
*
* SPDX-License-Identifier: Apache-2.0
*/
#include <errno.h>
#include <sys/stat.h>
#include <zephyr/kernel.h>
int stat(const char *restrict path, struct stat *restrict buf)
{
ARG_UNUSED(path);
ARG_UNUSED(buf);
errno = ENOTSUP;
return -1;
}

View file

@ -0,0 +1,44 @@
/*
* Copyright 2022 Meta
*
* SPDX-License-Identifier: Apache-2.0
*/
#include <thrift/concurrency/Mutex.h>
namespace apache
{
namespace thrift
{
namespace concurrency
{
Mutex::Mutex()
{
}
void Mutex::lock() const
{
}
bool Mutex::trylock() const
{
return false;
}
bool Mutex::timedlock(int64_t milliseconds) const
{
return false;
}
void Mutex::unlock() const
{
}
void *Mutex::getUnderlyingImpl() const
{
return nullptr;
}
} // namespace concurrency
} // namespace thrift
} // namespace apache

View file

@ -0,0 +1,183 @@
/*
* Copyright (c) 2023 Meta
*
* SPDX-License-Identifier: Apache-2.0
*/
/* config.h. Generated from config.hin by configure. */
/* config.hin. Generated from configure.ac by autoheader. */
#ifndef ZEPHYR_MODULES_THRIFT_SRC_THRIFT_CONFIG_H_
#define ZEPHYR_MODULES_THRIFT_SRC_THRIFT_CONFIG_H_
/* Possible value for SIGNED_RIGHT_SHIFT_IS */
#define ARITHMETIC_RIGHT_SHIFT 1
/* Define to 1 if you have the <arpa/inet.h> header file. */
#define HAVE_ARPA_INET_H 1
/* Define to 1 if you have the `clock_gettime' function. */
#define HAVE_CLOCK_GETTIME 1
/* define if the compiler supports basic C++11 syntax */
#define HAVE_CXX11 1
/* Define to 1 if you have the <fcntl.h> header file. */
#define HAVE_FCNTL_H 1
/* Define to 1 if you have the `gethostbyname' function. */
#define HAVE_GETHOSTBYNAME 1
/* Define to 1 if you have the `gettimeofday' function. */
#define HAVE_GETTIMEOFDAY 1
/* Define to 1 if you have the `inet_ntoa' function. */
#define HAVE_INET_NTOA 1
/* Define to 1 if you have the <inttypes.h> header file. */
#define HAVE_INTTYPES_H 1
/* Define to 1 if you have the <limits.h> header file. */
#define HAVE_LIMITS_H 1
/* Define to 1 if your system has a GNU libc compatible `malloc' function, and to 0 otherwise. */
#define HAVE_MALLOC 1
/* Define to 1 if you have the `memmove' function. */
#define HAVE_MEMMOVE 1
/* Define to 1 if you have the <memory.h> header file. */
#define HAVE_MEMORY_H 1
/* Define to 1 if you have the `memset' function. */
#define HAVE_MEMSET 1
/* Define to 1 if you have the `mkdir' function. */
#define HAVE_MKDIR 1
/* Define to 1 if you have the <netdb.h> header file. */
#define HAVE_NETDB_H 1
/* Define to 1 if you have the <netinet/in.h> header file. */
#define HAVE_NETINET_IN_H 1
/* Define to 1 if you have the <poll.h> header file. */
#define HAVE_POLL_H 1
/* Define to 1 if you have the <pthread.h> header file. */
#define HAVE_PTHREAD_H 1
/* Define to 1 if the system has the type `ptrdiff_t'. */
#define HAVE_PTRDIFF_T 1
/* Define to 1 if your system has a GNU libc compatible `realloc' function, and to 0 otherwise. */
#define HAVE_REALLOC 1
/* Define to 1 if you have the <sched.h> header file. */
#define HAVE_SCHED_H 1
/* Define to 1 if you have the `select' function. */
#define HAVE_SELECT 1
/* Define to 1 if you have the `socket' function. */
#define HAVE_SOCKET 1
/* Define to 1 if stdbool.h conforms to C99. */
#define HAVE_STDBOOL_H 1
/* Define to 1 if you have the <stddef.h> header file. */
#define HAVE_STDDEF_H 1
/* Define to 1 if you have the <stdint.h> header file. */
#define HAVE_STDINT_H 1
/* Define to 1 if you have the <stdlib.h> header file. */
#define HAVE_STDLIB_H 1
/* Define to 1 if you have the `strchr' function. */
#define HAVE_STRCHR 1
/* Define to 1 if you have the `strdup' function. */
#define HAVE_STRDUP 1
/* Define to 1 if you have the `strerror' function. */
#define HAVE_STRERROR 1
/* Define to 1 if you have the `strerror_r' function. */
#define HAVE_STRERROR_R 1
/* Define to 1 if you have the `strftime' function. */
#define HAVE_STRFTIME 1
/* Define to 1 if you have the <strings.h> header file. */
#define HAVE_STRINGS_H 1
/* Define to 1 if you have the <string.h> header file. */
#define HAVE_STRING_H 1
/* Define to 1 if you have the `strstr' function. */
#define HAVE_STRSTR 1
/* Define to 1 if you have the `strtol' function. */
#define HAVE_STRTOL 1
/* Define to 1 if you have the `strtoul' function. */
#define HAVE_STRTOUL 1
/* Define to 1 if you have the <sys/ioctl.h> header file. */
#define HAVE_SYS_IOCTL_H 1
/* Define to 1 if you have the <sys/resource.h> header file. */
#define HAVE_SYS_RESOURCE_H 1
/* Define to 1 if you have the <sys/select.h> header file. */
#define HAVE_SYS_SELECT_H 1
/* Define to 1 if you have the <sys/socket.h> header file. */
#define HAVE_SYS_SOCKET_H 1
/* Define to 1 if you have the <sys/stat.h> header file. */
#define HAVE_SYS_STAT_H 1
/* Define to 1 if you have the <sys/time.h> header file. */
#define HAVE_SYS_TIME_H 1
/* Define to 1 if you have the <sys/types.h> header file. */
#define HAVE_SYS_TYPES_H 1
/* Define to 1 if you have the <unistd.h> header file. */
#define HAVE_UNISTD_H 1
/* Define to 1 if you have the `vprintf' function. */
#define HAVE_VPRINTF 1
/* define if zlib is available */
/* #undef HAVE_ZLIB */
/* Possible value for SIGNED_RIGHT_SHIFT_IS */
#define LOGICAL_RIGHT_SHIFT 2
/* Define as the return type of signal handlers (`int' or `void'). */
#define RETSIGTYPE void
/* Define to the type of arg 1 for `select'. */
#define SELECT_TYPE_ARG1 int
/* Define to the type of args 2, 3 and 4 for `select'. */
#define SELECT_TYPE_ARG234 (fd_set *)
/* Define to the type of arg 5 for `select'. */
#define SELECT_TYPE_ARG5 (struct timeval *)
/* Indicates the effect of the right shift operator on negative signed integers */
#define SIGNED_RIGHT_SHIFT_IS 1
/* Define to 1 if you have the ANSI C header files. */
#define STDC_HEADERS 1
/* Define to 1 if you can safely include both <sys/time.h> and <time.h>. */
#define TIME_WITH_SYS_TIME 1
/* Possible value for SIGNED_RIGHT_SHIFT_IS */
#define UNKNOWN_RIGHT_SHIFT 3
#endif /* ZEPHYR_MODULES_THRIFT_SRC_THRIFT_CONFIG_H_ */

View file

@ -0,0 +1,286 @@
/*
* Copyright (c) 2006- Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_
#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ 1
#include <thrift/protocol/TProtocol.h>
#include <thrift/protocol/TVirtualProtocol.h>
#include <memory>
namespace apache
{
namespace thrift
{
namespace protocol
{
/**
* The default binary protocol for thrift. Writes all data in a very basic
* binary format, essentially just spitting out the raw bytes.
*
*/
template <class Transport_, class ByteOrder_ = TNetworkBigEndian>
class TBinaryProtocolT : public TVirtualProtocol<TBinaryProtocolT<Transport_, ByteOrder_>>
{
public:
static const int32_t VERSION_MASK = ((int32_t)0xffff0000);
static const int32_t VERSION_1 = ((int32_t)0x80010000);
// VERSION_2 (0x80020000) was taken by TDenseProtocol (which has since been removed)
TBinaryProtocolT(std::shared_ptr<Transport_> trans)
: TVirtualProtocol<TBinaryProtocolT<Transport_, ByteOrder_>>(trans),
trans_(trans.get()), string_limit_(0), container_limit_(0), strict_read_(false),
strict_write_(true)
{
}
TBinaryProtocolT(std::shared_ptr<Transport_> trans, int32_t string_limit,
int32_t container_limit, bool strict_read, bool strict_write)
: TVirtualProtocol<TBinaryProtocolT<Transport_, ByteOrder_>>(trans),
trans_(trans.get()), string_limit_(string_limit),
container_limit_(container_limit), strict_read_(strict_read),
strict_write_(strict_write)
{
}
void setStringSizeLimit(int32_t string_limit)
{
string_limit_ = string_limit;
}
void setContainerSizeLimit(int32_t container_limit)
{
container_limit_ = container_limit;
}
void setStrict(bool strict_read, bool strict_write)
{
strict_read_ = strict_read;
strict_write_ = strict_write;
}
/**
* Writing functions.
*/
/*ol*/ uint32_t writeMessageBegin(const std::string &name, const TMessageType messageType,
const int32_t seqid);
/*ol*/ uint32_t writeMessageEnd();
inline uint32_t writeStructBegin(const char *name);
inline uint32_t writeStructEnd();
inline uint32_t writeFieldBegin(const char *name, const TType fieldType,
const int16_t fieldId);
inline uint32_t writeFieldEnd();
inline uint32_t writeFieldStop();
inline uint32_t writeMapBegin(const TType keyType, const TType valType,
const uint32_t size);
inline uint32_t writeMapEnd();
inline uint32_t writeListBegin(const TType elemType, const uint32_t size);
inline uint32_t writeListEnd();
inline uint32_t writeSetBegin(const TType elemType, const uint32_t size);
inline uint32_t writeSetEnd();
inline uint32_t writeBool(const bool value);
inline uint32_t writeByte(const int8_t byte);
inline uint32_t writeI16(const int16_t i16);
inline uint32_t writeI32(const int32_t i32);
inline uint32_t writeI64(const int64_t i64);
inline uint32_t writeDouble(const double dub);
template <typename StrType> inline uint32_t writeString(const StrType &str);
inline uint32_t writeBinary(const std::string &str);
/**
* Reading functions
*/
/*ol*/ uint32_t readMessageBegin(std::string &name, TMessageType &messageType,
int32_t &seqid);
/*ol*/ uint32_t readMessageEnd();
inline uint32_t readStructBegin(std::string &name);
inline uint32_t readStructEnd();
inline uint32_t readFieldBegin(std::string &name, TType &fieldType, int16_t &fieldId);
inline uint32_t readFieldEnd();
inline uint32_t readMapBegin(TType &keyType, TType &valType, uint32_t &size);
inline uint32_t readMapEnd();
inline uint32_t readListBegin(TType &elemType, uint32_t &size);
inline uint32_t readListEnd();
inline uint32_t readSetBegin(TType &elemType, uint32_t &size);
inline uint32_t readSetEnd();
inline uint32_t readBool(bool &value);
// Provide the default readBool() implementation for std::vector<bool>
using TVirtualProtocol<TBinaryProtocolT<Transport_, ByteOrder_>>::readBool;
inline uint32_t readByte(int8_t &byte);
inline uint32_t readI16(int16_t &i16);
inline uint32_t readI32(int32_t &i32);
inline uint32_t readI64(int64_t &i64);
inline uint32_t readDouble(double &dub);
template <typename StrType> inline uint32_t readString(StrType &str);
inline uint32_t readBinary(std::string &str);
int getMinSerializedSize(TType type) override;
void checkReadBytesAvailable(TSet &set) override
{
trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
}
void checkReadBytesAvailable(TList &list) override
{
trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
}
void checkReadBytesAvailable(TMap &map) override
{
int elmSize =
getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
trans_->checkReadBytesAvailable(map.size_ * elmSize);
}
protected:
template <typename StrType> uint32_t readStringBody(StrType &str, int32_t sz);
Transport_ *trans_;
int32_t string_limit_;
int32_t container_limit_;
// Enforce presence of version identifier
bool strict_read_;
bool strict_write_;
};
typedef TBinaryProtocolT<TTransport> TBinaryProtocol;
typedef TBinaryProtocolT<TTransport, TNetworkLittleEndian> TLEBinaryProtocol;
/**
* Constructs binary protocol handlers
*/
template <class Transport_, class ByteOrder_ = TNetworkBigEndian>
class TBinaryProtocolFactoryT : public TProtocolFactory
{
public:
TBinaryProtocolFactoryT()
: string_limit_(0), container_limit_(0), strict_read_(false), strict_write_(true)
{
}
TBinaryProtocolFactoryT(int32_t string_limit, int32_t container_limit, bool strict_read,
bool strict_write)
: string_limit_(string_limit), container_limit_(container_limit),
strict_read_(strict_read), strict_write_(strict_write)
{
}
~TBinaryProtocolFactoryT() override = default;
void setStringSizeLimit(int32_t string_limit)
{
string_limit_ = string_limit;
}
void setContainerSizeLimit(int32_t container_limit)
{
container_limit_ = container_limit;
}
void setStrict(bool strict_read, bool strict_write)
{
strict_read_ = strict_read;
strict_write_ = strict_write;
}
std::shared_ptr<TProtocol> getProtocol(std::shared_ptr<TTransport> trans) override
{
std::shared_ptr<Transport_> specific_trans =
std::dynamic_pointer_cast<Transport_>(trans);
TProtocol *prot;
if (specific_trans) {
prot = new TBinaryProtocolT<Transport_, ByteOrder_>(
specific_trans, string_limit_, container_limit_, strict_read_,
strict_write_);
} else {
prot = new TBinaryProtocolT<TTransport, ByteOrder_>(
trans, string_limit_, container_limit_, strict_read_,
strict_write_);
}
return std::shared_ptr<TProtocol>(prot);
}
private:
int32_t string_limit_;
int32_t container_limit_;
bool strict_read_;
bool strict_write_;
};
typedef TBinaryProtocolFactoryT<TTransport> TBinaryProtocolFactory;
typedef TBinaryProtocolFactoryT<TTransport, TNetworkLittleEndian> TLEBinaryProtocolFactory;
} // namespace protocol
} // namespace thrift
} // namespace apache
#include <thrift/protocol/TBinaryProtocol.tcc>
#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_

View file

@ -0,0 +1,119 @@
/*
* Copyright (c) 2006- Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_SERVER_TCONNECTEDCLIENT_H_
#define _THRIFT_SERVER_TCONNECTEDCLIENT_H_ 1
#include <memory>
#include <thrift/TProcessor.h>
#include <thrift/protocol/TProtocol.h>
#include <thrift/server/TServer.h>
#include <thrift/transport/TTransport.h>
namespace apache
{
namespace thrift
{
namespace server
{
/**
* This represents a client connected to a TServer. The
* processing loop for a client must provide some required
* functionality common to all implementations so it is
* encapsulated here.
*/
class TConnectedClient
{
public:
/**
* Constructor.
*
* @param[in] processor the TProcessor
* @param[in] inputProtocol the input TProtocol
* @param[in] outputProtocol the output TProtocol
* @param[in] eventHandler the server event handler
* @param[in] client the TTransport representing the client
*/
TConnectedClient(
const std::shared_ptr<apache::thrift::TProcessor> &processor,
const std::shared_ptr<apache::thrift::protocol::TProtocol> &inputProtocol,
const std::shared_ptr<apache::thrift::protocol::TProtocol> &outputProtocol,
const std::shared_ptr<apache::thrift::server::TServerEventHandler> &eventHandler,
const std::shared_ptr<apache::thrift::transport::TTransport> &client);
/**
* Destructor.
*/
~TConnectedClient();
/**
* Drive the client until it is done.
* The client processing loop is:
*
* [optional] call eventHandler->createContext once
* [optional] call eventHandler->processContext per request
* call processor->process per request
* handle expected transport exceptions:
* END_OF_FILE means the client is gone
* INTERRUPTED means the client was interrupted
* by TServerTransport::interruptChildren()
* handle unexpected transport exceptions by logging
* handle standard exceptions by logging
* handle unexpected exceptions by logging
* cleanup()
*/
void run();
protected:
/**
* Cleanup after a client. This happens if the client disconnects,
* or if the server is stopped, or if an exception occurs.
*
* The cleanup processing is:
* [optional] call eventHandler->deleteContext once
* close the inputProtocol's TTransport
* close the outputProtocol's TTransport
* close the client
*/
virtual void cleanup();
private:
std::shared_ptr<apache::thrift::TProcessor> processor_;
std::shared_ptr<apache::thrift::protocol::TProtocol> inputProtocol_;
std::shared_ptr<apache::thrift::protocol::TProtocol> outputProtocol_;
std::shared_ptr<apache::thrift::server::TServerEventHandler> eventHandler_;
std::shared_ptr<apache::thrift::transport::TTransport> client_;
/**
* Context acquired from the eventHandler_ if one exists.
*/
void *opaqueContext_;
};
} // namespace server
} // namespace thrift
} // namespace apache
#endif // #ifndef _THRIFT_SERVER_TCONNECTEDCLIENT_H_

View file

@ -0,0 +1,218 @@
/*
* Copyright 2022 Meta
*
* SPDX-License-Identifier: Apache-2.0
*/
#include <array>
#include <system_error>
#include <errno.h>
#include <zephyr/posix/poll.h>
#include <zephyr/posix/sys/eventfd.h>
#include <zephyr/posix/unistd.h>
#include <thrift/transport/TFDTransport.h>
#include <zephyr/kernel.h>
#include <zephyr/logging/log.h>
#include "thrift/server/TFDServer.h"
LOG_MODULE_REGISTER(TFDServer, LOG_LEVEL_INF);
using namespace std;
namespace apache
{
namespace thrift
{
namespace transport
{
class xport : public TVirtualTransport<xport>
{
public:
xport(int fd) : xport(fd, eventfd(0, EFD_SEMAPHORE))
{
}
xport(int fd, int efd) : fd(fd), efd(efd)
{
__ASSERT(fd >= 0, "invalid fd %d", fd);
__ASSERT(efd >= 0, "invalid efd %d", efd);
LOG_DBG("created xport with fd %d and efd %d", fd, efd);
}
~xport()
{
close();
}
virtual uint32_t read_virt(uint8_t *buf, uint32_t len) override
{
int r;
array<pollfd, 2> pollfds = {
(pollfd){
.fd = fd,
.events = POLLIN,
.revents = 0,
},
(pollfd){
.fd = efd,
.events = POLLIN,
.revents = 0,
},
};
if (!isOpen()) {
return 0;
}
r = poll(&pollfds.front(), pollfds.size(), -1);
if (r == -1) {
LOG_ERR("failed to poll fds %d, %d: %d", fd, efd, errno);
throw system_error(errno, system_category(), "poll");
}
for (auto &pfd : pollfds) {
if (pfd.revents & POLLNVAL) {
LOG_DBG("fd %d is invalid", pfd.fd);
return 0;
}
}
if (pollfds[0].revents & POLLIN) {
r = ::read(fd, buf, len);
if (r == -1) {
LOG_ERR("failed to read %d bytes from fd %d: %d", len, fd, errno);
system_error(errno, system_category(), "read");
}
__ASSERT_NO_MSG(r > 0);
return uint32_t(r);
}
__ASSERT_NO_MSG(pollfds[1].revents & POLLIN);
return 0;
}
virtual void write_virt(const uint8_t *buf, uint32_t len) override
{
if (!isOpen()) {
throw TTransportException(TTransportException::END_OF_FILE);
}
for (int r = 0; len > 0; buf += r, len -= r) {
r = ::write(fd, buf, len);
if (r == -1) {
LOG_ERR("writing %u bytes to fd %d failed: %d", len, fd, errno);
throw system_error(errno, system_category(), "write");
}
__ASSERT_NO_MSG(r > 0);
}
}
void interrupt()
{
if (!isOpen()) {
return;
}
constexpr uint64_t x = 0xb7e;
int r = ::write(efd, &x, sizeof(x));
if (r == -1) {
LOG_ERR("writing %zu bytes to fd %d failed: %d", sizeof(x), efd, errno);
throw system_error(errno, system_category(), "write");
}
__ASSERT_NO_MSG(r > 0);
LOG_DBG("interrupted xport with fd %d and efd %d", fd, efd);
// there is no interrupt() method in the parent class, but the intent of
// interrupt() is to prevent future communication on this transport. The
// most reliable way we have of doing this is to close it :-)
close();
}
void close() override
{
if (isOpen()) {
::close(efd);
LOG_DBG("closed xport with fd %d and efd %d", fd, efd);
efd = -1;
// we only have a copy of fd and do not own it
fd = -1;
}
}
bool isOpen() const override
{
return fd >= 0 && efd >= 0;
}
protected:
int fd;
int efd;
};
TFDServer::TFDServer(int fd) : fd(fd)
{
}
TFDServer::~TFDServer()
{
interruptChildren();
interrupt();
}
bool TFDServer::isOpen() const
{
return fd >= 0;
}
shared_ptr<TTransport> TFDServer::acceptImpl()
{
if (!isOpen()) {
throw TTransportException(TTransportException::INTERRUPTED);
}
children.push_back(shared_ptr<TTransport>(new xport(fd)));
return children.back();
}
THRIFT_SOCKET TFDServer::getSocketFD()
{
return fd;
}
void TFDServer::close()
{
// we only have a copy of fd and do not own it
fd = -1;
}
void TFDServer::interrupt()
{
close();
}
void TFDServer::interruptChildren()
{
for (auto c : children) {
auto child = reinterpret_cast<xport *>(c.get());
child->interrupt();
}
children.clear();
}
} // namespace transport
} // namespace thrift
} // namespace apache

View file

@ -0,0 +1,52 @@
/*
* Copyright 2022 Meta
*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef _THRIFT_SERVER_TFDSERVER_H_
#define _THRIFT_SERVER_TFDSERVER_H_ 1
#include <memory>
#include <vector>
#include <thrift/transport/TServerTransport.h>
namespace apache
{
namespace thrift
{
namespace transport
{
class TFDServer : public TServerTransport
{
public:
/**
* Constructor.
*
* @param fd file descriptor of the socket
*/
TFDServer(int fd);
virtual ~TFDServer();
virtual bool isOpen() const override;
virtual THRIFT_SOCKET getSocketFD() override;
virtual void close() override;
virtual void interrupt() override;
virtual void interruptChildren() override;
protected:
TFDServer() : TFDServer(-1){};
virtual std::shared_ptr<TTransport> acceptImpl() override;
int fd;
std::vector<std::shared_ptr<TTransport>> children;
};
} // namespace transport
} // namespace thrift
} // namespace apache
#endif /* _THRIFT_SERVER_TFDSERVER_H_ */

View file

@ -0,0 +1,338 @@
/*
* Copyright (c) 2006- Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_SERVER_TSERVER_H_
#define _THRIFT_SERVER_TSERVER_H_ 1
#include <thrift/TProcessor.h>
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/transport/TServerTransport.h>
#include <memory>
namespace apache
{
namespace thrift
{
namespace server
{
using apache::thrift::TProcessor;
using apache::thrift::protocol::TBinaryProtocolFactory;
using apache::thrift::protocol::TProtocol;
using apache::thrift::protocol::TProtocolFactory;
using apache::thrift::transport::TServerTransport;
using apache::thrift::transport::TTransport;
using apache::thrift::transport::TTransportFactory;
/**
* Virtual interface class that can handle events from the server core. To
* use this you should subclass it and implement the methods that you care
* about. Your subclass can also store local data that you may care about,
* such as additional "arguments" to these methods (stored in the object
* instance's state).
*/
class TServerEventHandler
{
public:
virtual ~TServerEventHandler() = default;
/**
* Called before the server begins.
*/
virtual void preServe()
{
}
/**
* Called when a new client has connected and is about to being processing.
*/
virtual void *createContext(std::shared_ptr<TProtocol> input,
std::shared_ptr<TProtocol> output)
{
(void)input;
(void)output;
return nullptr;
}
/**
* Called when a client has finished request-handling to delete server
* context.
*/
virtual void deleteContext(void *serverContext, std::shared_ptr<TProtocol> input,
std::shared_ptr<TProtocol> output)
{
(void)serverContext;
(void)input;
(void)output;
}
/**
* Called when a client is about to call the processor.
*/
virtual void processContext(void *serverContext, std::shared_ptr<TTransport> transport)
{
(void)serverContext;
(void)transport;
}
protected:
/**
* Prevent direct instantiation.
*/
TServerEventHandler() = default;
};
/**
* Thrift server.
*
*/
class TServer
{
public:
~TServer() = default;
virtual void serve() = 0;
virtual void stop()
{
}
// Allows running the server as a Runnable thread
void run()
{
serve();
}
std::shared_ptr<TProcessorFactory> getProcessorFactory()
{
return processorFactory_;
}
std::shared_ptr<TServerTransport> getServerTransport()
{
return serverTransport_;
}
std::shared_ptr<TTransportFactory> getInputTransportFactory()
{
return inputTransportFactory_;
}
std::shared_ptr<TTransportFactory> getOutputTransportFactory()
{
return outputTransportFactory_;
}
std::shared_ptr<TProtocolFactory> getInputProtocolFactory()
{
return inputProtocolFactory_;
}
std::shared_ptr<TProtocolFactory> getOutputProtocolFactory()
{
return outputProtocolFactory_;
}
std::shared_ptr<TServerEventHandler> getEventHandler()
{
return eventHandler_;
}
protected:
TServer(const std::shared_ptr<TProcessorFactory> &processorFactory)
: processorFactory_(processorFactory)
{
setInputTransportFactory(
std::shared_ptr<TTransportFactory>(new TTransportFactory()));
setOutputTransportFactory(
std::shared_ptr<TTransportFactory>(new TTransportFactory()));
setInputProtocolFactory(
std::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory()));
setOutputProtocolFactory(
std::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory()));
}
TServer(const std::shared_ptr<TProcessor> &processor)
: processorFactory_(new TSingletonProcessorFactory(processor))
{
setInputTransportFactory(
std::shared_ptr<TTransportFactory>(new TTransportFactory()));
setOutputTransportFactory(
std::shared_ptr<TTransportFactory>(new TTransportFactory()));
setInputProtocolFactory(
std::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory()));
setOutputProtocolFactory(
std::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory()));
}
TServer(const std::shared_ptr<TProcessorFactory> &processorFactory,
const std::shared_ptr<TServerTransport> &serverTransport)
: processorFactory_(processorFactory), serverTransport_(serverTransport)
{
setInputTransportFactory(
std::shared_ptr<TTransportFactory>(new TTransportFactory()));
setOutputTransportFactory(
std::shared_ptr<TTransportFactory>(new TTransportFactory()));
setInputProtocolFactory(
std::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory()));
setOutputProtocolFactory(
std::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory()));
}
TServer(const std::shared_ptr<TProcessor> &processor,
const std::shared_ptr<TServerTransport> &serverTransport)
: processorFactory_(new TSingletonProcessorFactory(processor)),
serverTransport_(serverTransport)
{
setInputTransportFactory(
std::shared_ptr<TTransportFactory>(new TTransportFactory()));
setOutputTransportFactory(
std::shared_ptr<TTransportFactory>(new TTransportFactory()));
setInputProtocolFactory(
std::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory()));
setOutputProtocolFactory(
std::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory()));
}
TServer(const std::shared_ptr<TProcessorFactory> &processorFactory,
const std::shared_ptr<TServerTransport> &serverTransport,
const std::shared_ptr<TTransportFactory> &transportFactory,
const std::shared_ptr<TProtocolFactory> &protocolFactory)
: processorFactory_(processorFactory), serverTransport_(serverTransport),
inputTransportFactory_(transportFactory),
outputTransportFactory_(transportFactory), inputProtocolFactory_(protocolFactory),
outputProtocolFactory_(protocolFactory)
{
}
TServer(const std::shared_ptr<TProcessor> &processor,
const std::shared_ptr<TServerTransport> &serverTransport,
const std::shared_ptr<TTransportFactory> &transportFactory,
const std::shared_ptr<TProtocolFactory> &protocolFactory)
: processorFactory_(new TSingletonProcessorFactory(processor)),
serverTransport_(serverTransport), inputTransportFactory_(transportFactory),
outputTransportFactory_(transportFactory), inputProtocolFactory_(protocolFactory),
outputProtocolFactory_(protocolFactory)
{
}
TServer(const std::shared_ptr<TProcessorFactory> &processorFactory,
const std::shared_ptr<TServerTransport> &serverTransport,
const std::shared_ptr<TTransportFactory> &inputTransportFactory,
const std::shared_ptr<TTransportFactory> &outputTransportFactory,
const std::shared_ptr<TProtocolFactory> &inputProtocolFactory,
const std::shared_ptr<TProtocolFactory> &outputProtocolFactory)
: processorFactory_(processorFactory), serverTransport_(serverTransport),
inputTransportFactory_(inputTransportFactory),
outputTransportFactory_(outputTransportFactory),
inputProtocolFactory_(inputProtocolFactory),
outputProtocolFactory_(outputProtocolFactory)
{
}
TServer(const std::shared_ptr<TProcessor> &processor,
const std::shared_ptr<TServerTransport> &serverTransport,
const std::shared_ptr<TTransportFactory> &inputTransportFactory,
const std::shared_ptr<TTransportFactory> &outputTransportFactory,
const std::shared_ptr<TProtocolFactory> &inputProtocolFactory,
const std::shared_ptr<TProtocolFactory> &outputProtocolFactory)
: processorFactory_(new TSingletonProcessorFactory(processor)),
serverTransport_(serverTransport), inputTransportFactory_(inputTransportFactory),
outputTransportFactory_(outputTransportFactory),
inputProtocolFactory_(inputProtocolFactory),
outputProtocolFactory_(outputProtocolFactory)
{
}
/**
* Get a TProcessor to handle calls on a particular connection.
*
* This method should only be called once per connection (never once per
* call). This allows the TProcessorFactory to return a different processor
* for each connection if it desires.
*/
std::shared_ptr<TProcessor> getProcessor(std::shared_ptr<TProtocol> inputProtocol,
std::shared_ptr<TProtocol> outputProtocol,
std::shared_ptr<TTransport> transport)
{
TConnectionInfo connInfo;
connInfo.input = inputProtocol;
connInfo.output = outputProtocol;
connInfo.transport = transport;
return processorFactory_->getProcessor(connInfo);
}
// Class variables
std::shared_ptr<TProcessorFactory> processorFactory_;
std::shared_ptr<TServerTransport> serverTransport_;
std::shared_ptr<TTransportFactory> inputTransportFactory_;
std::shared_ptr<TTransportFactory> outputTransportFactory_;
std::shared_ptr<TProtocolFactory> inputProtocolFactory_;
std::shared_ptr<TProtocolFactory> outputProtocolFactory_;
std::shared_ptr<TServerEventHandler> eventHandler_;
public:
void setInputTransportFactory(std::shared_ptr<TTransportFactory> inputTransportFactory)
{
inputTransportFactory_ = inputTransportFactory;
}
void setOutputTransportFactory(std::shared_ptr<TTransportFactory> outputTransportFactory)
{
outputTransportFactory_ = outputTransportFactory;
}
void setInputProtocolFactory(std::shared_ptr<TProtocolFactory> inputProtocolFactory)
{
inputProtocolFactory_ = inputProtocolFactory;
}
void setOutputProtocolFactory(std::shared_ptr<TProtocolFactory> outputProtocolFactory)
{
outputProtocolFactory_ = outputProtocolFactory;
}
void setServerEventHandler(std::shared_ptr<TServerEventHandler> eventHandler)
{
eventHandler_ = eventHandler;
}
};
/**
* Helper function to increase the max file descriptors limit
* for the current process and all of its children.
* By default, tries to increase it to as much as 2^24.
*/
#ifdef HAVE_SYS_RESOURCE_H
int increase_max_fds(int max_fds = (1 << 24));
#endif
} // namespace server
} // namespace thrift
} // namespace apache
#endif // #ifndef _THRIFT_SERVER_TSERVER_H_

View file

@ -0,0 +1,255 @@
/*
* Copyright (c) 2006- Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <algorithm>
#include <functional>
#include <stdexcept>
#include <stdint.h>
#include <thrift/server/TServerFramework.h>
namespace apache
{
namespace thrift
{
namespace server
{
// using apache::thrift::concurrency::Synchronized;
using apache::thrift::protocol::TProtocol;
using apache::thrift::protocol::TProtocolFactory;
using apache::thrift::transport::TServerTransport;
using apache::thrift::transport::TTransport;
using apache::thrift::transport::TTransportException;
using apache::thrift::transport::TTransportFactory;
using std::bind;
using std::shared_ptr;
using std::string;
TServerFramework::TServerFramework(const shared_ptr<TProcessorFactory> &processorFactory,
const shared_ptr<TServerTransport> &serverTransport,
const shared_ptr<TTransportFactory> &transportFactory,
const shared_ptr<TProtocolFactory> &protocolFactory)
: TServer(processorFactory, serverTransport, transportFactory, protocolFactory),
clients_(0), hwm_(0), limit_(INT64_MAX)
{
}
TServerFramework::TServerFramework(const shared_ptr<TProcessor> &processor,
const shared_ptr<TServerTransport> &serverTransport,
const shared_ptr<TTransportFactory> &transportFactory,
const shared_ptr<TProtocolFactory> &protocolFactory)
: TServer(processor, serverTransport, transportFactory, protocolFactory), clients_(0),
hwm_(0), limit_(INT64_MAX)
{
}
TServerFramework::TServerFramework(const shared_ptr<TProcessorFactory> &processorFactory,
const shared_ptr<TServerTransport> &serverTransport,
const shared_ptr<TTransportFactory> &inputTransportFactory,
const shared_ptr<TTransportFactory> &outputTransportFactory,
const shared_ptr<TProtocolFactory> &inputProtocolFactory,
const shared_ptr<TProtocolFactory> &outputProtocolFactory)
: TServer(processorFactory, serverTransport, inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory),
clients_(0), hwm_(0), limit_(INT64_MAX)
{
}
TServerFramework::TServerFramework(const shared_ptr<TProcessor> &processor,
const shared_ptr<TServerTransport> &serverTransport,
const shared_ptr<TTransportFactory> &inputTransportFactory,
const shared_ptr<TTransportFactory> &outputTransportFactory,
const shared_ptr<TProtocolFactory> &inputProtocolFactory,
const shared_ptr<TProtocolFactory> &outputProtocolFactory)
: TServer(processor, serverTransport, inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory),
clients_(0), hwm_(0), limit_(INT64_MAX)
{
}
TServerFramework::~TServerFramework() = default;
template <typename T> static void releaseOneDescriptor(const string &name, T &pTransport)
{
if (pTransport) {
try {
pTransport->close();
} catch (const TTransportException &ttx) {
string errStr =
string("TServerFramework " + name + " close failed: ") + ttx.what();
GlobalOutput(errStr.c_str());
}
}
}
void TServerFramework::serve()
{
shared_ptr<TTransport> client;
shared_ptr<TTransport> inputTransport;
shared_ptr<TTransport> outputTransport;
shared_ptr<TProtocol> inputProtocol;
shared_ptr<TProtocol> outputProtocol;
// Start the server listening
serverTransport_->listen();
// Run the preServe event to indicate server is now listening
// and that it is safe to connect.
if (eventHandler_) {
eventHandler_->preServe();
}
// Fetch client from server
for (;;) {
try {
// Dereference any resources from any previous client creation
// such that a blocking accept does not hold them indefinitely.
outputProtocol.reset();
inputProtocol.reset();
outputTransport.reset();
inputTransport.reset();
client.reset();
// If we have reached the limit on the number of concurrent
// clients allowed, wait for one or more clients to drain before
// accepting another.
{
// Synchronized sync(mon_);
while (clients_ >= limit_) {
// mon_.wait();
}
}
client = serverTransport_->accept();
inputTransport = inputTransportFactory_->getTransport(client);
outputTransport = outputTransportFactory_->getTransport(client);
if (!outputProtocolFactory_) {
inputProtocol = inputProtocolFactory_->getProtocol(inputTransport,
outputTransport);
outputProtocol = inputProtocol;
} else {
inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
outputProtocol =
outputProtocolFactory_->getProtocol(outputTransport);
}
newlyConnectedClient(shared_ptr<TConnectedClient>(
new TConnectedClient(
getProcessor(inputProtocol, outputProtocol, client),
inputProtocol, outputProtocol, eventHandler_, client),
bind(&TServerFramework::disposeConnectedClient, this,
std::placeholders::_1)));
} catch (TTransportException &ttx) {
releaseOneDescriptor("inputTransport", inputTransport);
releaseOneDescriptor("outputTransport", outputTransport);
releaseOneDescriptor("client", client);
if (ttx.getType() == TTransportException::TIMED_OUT ||
ttx.getType() == TTransportException::CLIENT_DISCONNECT) {
// Accept timeout and client disconnect - continue processing.
continue;
} else if (ttx.getType() == TTransportException::END_OF_FILE ||
ttx.getType() == TTransportException::INTERRUPTED) {
// Server was interrupted. This only happens when stopping.
break;
} else {
// All other transport exceptions are logged.
// State of connection is unknown. Done.
string errStr = string("TServerTransport died: ") + ttx.what();
GlobalOutput(errStr.c_str());
break;
}
}
}
releaseOneDescriptor("serverTransport", serverTransport_);
}
int64_t TServerFramework::getConcurrentClientLimit() const
{
// Synchronized sync(mon_);
return limit_;
}
int64_t TServerFramework::getConcurrentClientCount() const
{
// Synchronized sync(mon_);
return clients_;
}
int64_t TServerFramework::getConcurrentClientCountHWM() const
{
// Synchronized sync(mon_);
return hwm_;
}
void TServerFramework::setConcurrentClientLimit(int64_t newLimit)
{
if (newLimit < 1) {
throw std::invalid_argument("newLimit must be greater than zero");
}
// Synchronized sync(mon_);
limit_ = newLimit;
if (limit_ - clients_ > 0) {
// mon_.notify();
}
}
void TServerFramework::stop()
{
// Order is important because serve() releases serverTransport_ when it is
// interrupted, which closes the socket that interruptChildren uses.
serverTransport_->interruptChildren();
serverTransport_->interrupt();
}
void TServerFramework::newlyConnectedClient(const shared_ptr<TConnectedClient> &pClient)
{
{
// Synchronized sync(mon_);
++clients_;
hwm_ = (std::max)(hwm_, clients_);
}
onClientConnected(pClient);
}
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdelete-non-virtual-dtor"
void TServerFramework::disposeConnectedClient(TConnectedClient *pClient)
{
onClientDisconnected(pClient);
delete pClient;
// Synchronized sync(mon_);
if (limit_ - --clients_ > 0) {
// mon_.notify();
}
}
#pragma GCC diagnostic pop
} // namespace server
} // namespace thrift
} // namespace apache

View file

@ -0,0 +1,197 @@
/*
* Copyright (c) 2006- Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_SERVER_TSERVERFRAMEWORK_H_
#define _THRIFT_SERVER_TSERVERFRAMEWORK_H_ 1
#include <memory>
#include <stdint.h>
#include <thrift/TProcessor.h>
#include <thrift/server/TConnectedClient.h>
#include <thrift/server/TServer.h>
#include <thrift/transport/TServerTransport.h>
#include <thrift/transport/TTransport.h>
namespace apache
{
namespace thrift
{
namespace server
{
/**
* TServerFramework provides a single consolidated processing loop for
* servers. By having a single processing loop, behavior between servers
* is more predictable and maintenance cost is lowered. Implementations
* of TServerFramework must provide a method to deal with a client that
* connects and one that disconnects.
*
* While this functionality could be rolled directly into TServer, and
* probably should be, it would break the TServer interface contract so
* to maintain backwards compatibility for third party servers, no TServers
* were harmed in the making of this class.
*/
class TServerFramework : public TServer
{
public:
TServerFramework(
const std::shared_ptr<apache::thrift::TProcessorFactory> &processorFactory,
const std::shared_ptr<apache::thrift::transport::TServerTransport> &serverTransport,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&transportFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory> &protocolFactory);
TServerFramework(
const std::shared_ptr<apache::thrift::TProcessor> &processor,
const std::shared_ptr<apache::thrift::transport::TServerTransport> &serverTransport,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&transportFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory> &protocolFactory);
TServerFramework(
const std::shared_ptr<apache::thrift::TProcessorFactory> &processorFactory,
const std::shared_ptr<apache::thrift::transport::TServerTransport> &serverTransport,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&inputTransportFactory,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&outputTransportFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory>
&inputProtocolFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory>
&outputProtocolFactory);
TServerFramework(
const std::shared_ptr<apache::thrift::TProcessor> &processor,
const std::shared_ptr<apache::thrift::transport::TServerTransport> &serverTransport,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&inputTransportFactory,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&outputTransportFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory>
&inputProtocolFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory>
&outputProtocolFactory);
~TServerFramework();
/**
* Accept clients from the TServerTransport and add them for processing.
* Call stop() on another thread to interrupt processing
* and return control to the caller.
* Post-conditions (return guarantees):
* The serverTransport will be closed.
*/
virtual void serve() override;
/**
* Interrupt serve() so that it meets post-conditions and returns.
*/
virtual void stop() override;
/**
* Get the concurrent client limit.
* \returns the concurrent client limit
*/
virtual int64_t getConcurrentClientLimit() const;
/**
* Get the number of currently connected clients.
* \returns the number of currently connected clients
*/
virtual int64_t getConcurrentClientCount() const;
/**
* Get the highest number of concurrent clients.
* \returns the highest number of concurrent clients
*/
virtual int64_t getConcurrentClientCountHWM() const;
/**
* Set the concurrent client limit. This can be changed while
* the server is serving however it will not necessarily be
* enforced until the next client is accepted and added. If the
* limit is lowered below the number of connected clients, no
* action is taken to disconnect the clients.
* The default value used if this is not called is INT64_MAX.
* \param[in] newLimit the new limit of concurrent clients
* \throws std::invalid_argument if newLimit is less than 1
*/
virtual void setConcurrentClientLimit(int64_t newLimit);
protected:
/**
* A client has connected. The implementation is responsible for managing the
* lifetime of the client object. This is called during the serve() thread,
* therefore a failure to return quickly will result in new client connection
* delays.
*
* \param[in] pClient the newly connected client
*/
virtual void onClientConnected(const std::shared_ptr<TConnectedClient> &pClient) = 0;
/**
* A client has disconnected.
* When called:
* The server no longer tracks the client.
* The client TTransport has already been closed.
* The implementation must not delete the pointer.
*
* \param[in] pClient the disconnected client
*/
virtual void onClientDisconnected(TConnectedClient *pClient) = 0;
private:
/**
* Common handling for new connected clients. Implements concurrent
* client rate limiting after onClientConnected returns by blocking the
* serve() thread if the limit has been reached.
*/
void newlyConnectedClient(const std::shared_ptr<TConnectedClient> &pClient);
/**
* Smart pointer client deletion.
* Calls onClientDisconnected and then deletes pClient.
*/
void disposeConnectedClient(TConnectedClient *pClient);
/**
* The number of concurrent clients.
*/
int64_t clients_;
/**
* The high water mark of concurrent clients.
*/
int64_t hwm_;
/**
* The limit on the number of concurrent clients.
*/
int64_t limit_;
};
} // namespace server
} // namespace thrift
} // namespace apache
#endif // #ifndef _THRIFT_SERVER_TSERVERFRAMEWORK_H_

View file

@ -0,0 +1,97 @@
/*
* Copyright (c) 2006- Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_SERVER_TSIMPLESERVER_H_
#define _THRIFT_SERVER_TSIMPLESERVER_H_ 1
#include <thrift/server/TServerFramework.h>
namespace apache
{
namespace thrift
{
namespace server
{
/**
* This is the most basic simple server. It is single-threaded and runs a
* continuous loop of accepting a single connection, processing requests on
* that connection until it closes, and then repeating.
*/
class TSimpleServer : public TServerFramework
{
public:
TSimpleServer(
const std::shared_ptr<apache::thrift::TProcessorFactory> &processorFactory,
const std::shared_ptr<apache::thrift::transport::TServerTransport> &serverTransport,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&transportFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory> &protocolFactory);
TSimpleServer(
const std::shared_ptr<apache::thrift::TProcessor> &processor,
const std::shared_ptr<apache::thrift::transport::TServerTransport> &serverTransport,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&transportFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory> &protocolFactory);
TSimpleServer(
const std::shared_ptr<apache::thrift::TProcessorFactory> &processorFactory,
const std::shared_ptr<apache::thrift::transport::TServerTransport> &serverTransport,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&inputTransportFactory,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&outputTransportFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory>
&inputProtocolFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory>
&outputProtocolFactory);
TSimpleServer(
const std::shared_ptr<apache::thrift::TProcessor> &processor,
const std::shared_ptr<apache::thrift::transport::TServerTransport> &serverTransport,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&inputTransportFactory,
const std::shared_ptr<apache::thrift::transport::TTransportFactory>
&outputTransportFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory>
&inputProtocolFactory,
const std::shared_ptr<apache::thrift::protocol::TProtocolFactory>
&outputProtocolFactory);
~TSimpleServer();
protected:
void onClientConnected(const std::shared_ptr<TConnectedClient> &pClient) override
/* override */;
void onClientDisconnected(TConnectedClient *pClient) override /* override */;
private:
void setConcurrentClientLimit(int64_t newLimit) override; // hide
};
} // namespace server
} // namespace thrift
} // namespace apache
#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_

View file

@ -0,0 +1,244 @@
/*
* Copyright 2006 Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
#include <cstring>
#include <zephyr/net/tls_credentials.h>
#include <zephyr/posix/sys/socket.h>
#include <zephyr/posix/unistd.h>
#include <thrift/thrift_export.h>
#include <thrift/transport/TSSLServerSocket.h>
#include <thrift/transport/TSSLSocket.h>
#include <thrift/transport/ThriftTLScertificateType.h>
#include <thrift/transport/TSocketUtils.h>
template <class T> inline void *cast_sockopt(T *v)
{
return reinterpret_cast<void *>(v);
}
void destroyer_of_fine_sockets(THRIFT_SOCKET *ssock);
namespace apache
{
namespace thrift
{
namespace transport
{
/**
* SSL server socket implementation.
*/
TSSLServerSocket::TSSLServerSocket(int port, std::shared_ptr<TSSLSocketFactory> factory)
: TServerSocket(port), factory_(factory)
{
factory_->server(true);
}
TSSLServerSocket::TSSLServerSocket(const std::string &address, int port,
std::shared_ptr<TSSLSocketFactory> factory)
: TServerSocket(address, port), factory_(factory)
{
factory_->server(true);
}
TSSLServerSocket::TSSLServerSocket(int port, int sendTimeout, int recvTimeout,
std::shared_ptr<TSSLSocketFactory> factory)
: TServerSocket(port, sendTimeout, recvTimeout), factory_(factory)
{
factory_->server(true);
}
std::shared_ptr<TSocket> TSSLServerSocket::createSocket(THRIFT_SOCKET client)
{
if (interruptableChildren_) {
return factory_->createSocket(client, pChildInterruptSockReader_);
} else {
return factory_->createSocket(client);
}
}
void TSSLServerSocket::listen()
{
THRIFT_SOCKET sv[2];
// Create the socket pair used to interrupt
if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) {
GlobalOutput.perror("TServerSocket::listen() socketpair() interrupt",
THRIFT_GET_SOCKET_ERROR);
interruptSockWriter_ = THRIFT_INVALID_SOCKET;
interruptSockReader_ = THRIFT_INVALID_SOCKET;
} else {
interruptSockWriter_ = sv[1];
interruptSockReader_ = sv[0];
}
// Create the socket pair used to interrupt all clients
if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) {
GlobalOutput.perror("TServerSocket::listen() socketpair() childInterrupt",
THRIFT_GET_SOCKET_ERROR);
childInterruptSockWriter_ = THRIFT_INVALID_SOCKET;
pChildInterruptSockReader_.reset();
} else {
childInterruptSockWriter_ = sv[1];
pChildInterruptSockReader_ = std::shared_ptr<THRIFT_SOCKET>(
new THRIFT_SOCKET(sv[0]), destroyer_of_fine_sockets);
}
// Validate port number
if (port_ < 0 || port_ > 0xFFFF) {
throw TTransportException(TTransportException::BAD_ARGS,
"Specified port is invalid");
}
// Resolve host:port strings into an iterable of struct addrinfo*
AddressResolutionHelper resolved_addresses;
try {
resolved_addresses.resolve(address_, std::to_string(port_), SOCK_STREAM,
AI_PASSIVE | AI_V4MAPPED);
} catch (const std::system_error &e) {
GlobalOutput.printf("getaddrinfo() -> %d; %s", e.code().value(), e.what());
close();
throw TTransportException(TTransportException::NOT_OPEN,
"Could not resolve host for server socket.");
}
// we may want to try to bind more than once, since THRIFT_NO_SOCKET_CACHING doesn't
// always seem to work. The client can configure the retry variables.
int retries = 0;
int errno_copy = 0;
// -- TCP socket -- //
auto addr_iter = AddressResolutionHelper::Iter{};
// Via DNS or somehow else, single hostname can resolve into many addresses.
// Results may contain perhaps a mix of IPv4 and IPv6. Here, we iterate
// over what system gave us, picking the first address that works.
do {
if (!addr_iter) {
// init + recycle over many retries
addr_iter = resolved_addresses.iterate();
}
auto trybind = *addr_iter++;
serverSocket_ = socket(trybind->ai_family, trybind->ai_socktype, IPPROTO_TLS_1_2);
if (serverSocket_ == -1) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
continue;
}
_setup_sockopts();
_setup_tcp_sockopts();
static const sec_tag_t sec_tag_list[3] = {
Thrift_TLS_CA_CERT_TAG, Thrift_TLS_SERVER_CERT_TAG, Thrift_TLS_PRIVATE_KEY};
int ret = setsockopt(serverSocket_, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_list,
sizeof(sec_tag_list));
if (ret != 0) {
throw TTransportException(TTransportException::NOT_OPEN,
"set TLS_SEC_TAG_LIST failed");
}
#ifdef IPV6_V6ONLY
if (trybind->ai_family == AF_INET6) {
int zero = 0;
if (-1 == setsockopt(serverSocket_, IPPROTO_IPV6, IPV6_V6ONLY,
cast_sockopt(&zero), sizeof(zero))) {
GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY ",
THRIFT_GET_SOCKET_ERROR);
}
}
#endif // #ifdef IPV6_V6ONLY
if (0 == ::bind(serverSocket_, trybind->ai_addr,
static_cast<int>(trybind->ai_addrlen))) {
break;
}
errno_copy = THRIFT_GET_SOCKET_ERROR;
// use short circuit evaluation here to only sleep if we need to
} while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));
// retrieve bind info
if (port_ == 0 && retries <= retryLimit_) {
struct sockaddr_storage sa;
socklen_t len = sizeof(sa);
std::memset(&sa, 0, len);
if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr *>(&sa), &len) <
0) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
} else {
if (sa.ss_family == AF_INET6) {
const auto *sin =
reinterpret_cast<const struct sockaddr_in6 *>(&sa);
port_ = ntohs(sin->sin6_port);
} else {
const auto *sin = reinterpret_cast<const struct sockaddr_in *>(&sa);
port_ = ntohs(sin->sin_port);
}
}
}
// throw error if socket still wasn't created successfully
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
GlobalOutput.perror("TServerSocket::listen() socket() ", errno_copy);
close();
throw TTransportException(TTransportException::NOT_OPEN,
"Could not create server socket.", errno_copy);
}
// throw an error if we failed to bind properly
if (retries > retryLimit_) {
char errbuf[1024];
THRIFT_SNPRINTF(errbuf, sizeof(errbuf),
"TServerSocket::listen() Could not bind to port %d", port_);
GlobalOutput(errbuf);
close();
throw TTransportException(TTransportException::NOT_OPEN, "Could not bind",
errno_copy);
}
if (listenCallback_) {
listenCallback_(serverSocket_);
}
// Call listen
if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
close();
throw TTransportException(TTransportException::NOT_OPEN, "Could not listen",
errno_copy);
}
// The socket is now listening!
listening_ = true;
}
void TSSLServerSocket::close()
{
rwMutex_.lock();
if (pChildInterruptSockReader_ != nullptr &&
*pChildInterruptSockReader_ != THRIFT_INVALID_SOCKET) {
::THRIFT_CLOSESOCKET(*pChildInterruptSockReader_);
*pChildInterruptSockReader_ = THRIFT_INVALID_SOCKET;
}
rwMutex_.unlock();
TServerSocket::close();
}
} // namespace transport
} // namespace thrift
} // namespace apache

View file

@ -0,0 +1,67 @@
/*
* Copyright 2006 Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_
#define _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_ 1
#include <thrift/transport/TServerSocket.h>
namespace apache
{
namespace thrift
{
namespace transport
{
class TSSLSocketFactory;
/**
* Server socket that accepts SSL connections.
*/
class TSSLServerSocket : public TServerSocket
{
public:
/**
* Constructor. Binds to all interfaces.
*
* @param port Listening port
* @param factory SSL socket factory implementation
*/
TSSLServerSocket(int port, std::shared_ptr<TSSLSocketFactory> factory);
/**
* Constructor. Binds to the specified address.
*
* @param address Address to bind to
* @param port Listening port
* @param factory SSL socket factory implementation
*/
TSSLServerSocket(const std::string &address, int port,
std::shared_ptr<TSSLSocketFactory> factory);
/**
* Constructor. Binds to all interfaces.
*
* @param port Listening port
* @param sendTimeout Socket send timeout
* @param recvTimeout Socket receive timeout
* @param factory SSL socket factory implementation
*/
TSSLServerSocket(int port, int sendTimeout, int recvTimeout,
std::shared_ptr<TSSLSocketFactory> factory);
void listen() override;
void close() override;
protected:
std::shared_ptr<TSocket> createSocket(THRIFT_SOCKET socket) override;
std::shared_ptr<TSSLSocketFactory> factory_;
};
} // namespace transport
} // namespace thrift
} // namespace apache
#endif

View file

@ -0,0 +1,656 @@
/*
* Copyright 2006 Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
#include <thrift/thrift-config.h>
#include <cstring>
#include <errno.h>
#include <memory>
#include <string>
#ifdef HAVE_ARPA_INET_H
#include <zephyr/posix/arpa/inet.h>
#endif
#include <sys/types.h>
#ifdef HAVE_POLL_H
#include <poll.h>
#endif
#include <zephyr/net/tls_credentials.h>
#include <fcntl.h>
#include <thrift/TToString.h>
#include <thrift/concurrency/Mutex.h>
#include <thrift/transport/PlatformSocket.h>
#include <thrift/transport/TSSLSocket.h>
#include <thrift/transport/ThriftTLScertificateType.h>
using namespace apache::thrift::concurrency;
using std::string;
struct CRYPTO_dynlock_value {
Mutex mutex;
};
namespace apache
{
namespace thrift
{
namespace transport
{
static bool matchName(const char *host, const char *pattern, int size);
static char uppercase(char c);
// TSSLSocket implementation
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config)
: TSocket(config), server_(false), ctx_(ctx)
{
init();
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx,
std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config)
: TSocket(config), server_(false), ctx_(ctx)
{
init();
interruptListener_ = interruptListener;
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket,
std::shared_ptr<TConfiguration> config)
: TSocket(socket, config), server_(false), ctx_(ctx)
{
init();
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket,
std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config)
: TSocket(socket, interruptListener, config), server_(false), ctx_(ctx)
{
init();
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port,
std::shared_ptr<TConfiguration> config)
: TSocket(host, port, config), server_(false), ctx_(ctx)
{
init();
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port,
std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config)
: TSocket(host, port, config), server_(false), ctx_(ctx)
{
init();
interruptListener_ = interruptListener;
}
TSSLSocket::~TSSLSocket()
{
close();
}
template <class T> inline void *cast_sockopt(T *v)
{
return reinterpret_cast<void *>(v);
}
void TSSLSocket::authorize()
{
}
void TSSLSocket::openSecConnection(struct addrinfo *res)
{
socket_ = socket(res->ai_family, res->ai_socktype, ctx_->protocol);
if (socket_ == THRIFT_INVALID_SOCKET) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TSocket::open() socket() " + getSocketInfo(), errno_copy);
throw TTransportException(TTransportException::NOT_OPEN, "socket()", errno_copy);
}
static const sec_tag_t sec_tag_list[3] = {
Thrift_TLS_CA_CERT_TAG, Thrift_TLS_SERVER_CERT_TAG, Thrift_TLS_PRIVATE_KEY};
int ret =
setsockopt(socket_, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_list, sizeof(sec_tag_list));
if (ret != 0) {
throw TTransportException(TTransportException::NOT_OPEN,
"set TLS_SEC_TAG_LIST failed");
}
ret = setsockopt(socket_, SOL_TLS, TLS_PEER_VERIFY, &(ctx_->verifyMode),
sizeof(ctx_->verifyMode));
if (ret != 0) {
throw TTransportException(TTransportException::NOT_OPEN,
"set TLS_PEER_VERIFY failed");
}
ret = setsockopt(socket_, SOL_TLS, TLS_HOSTNAME, host_.c_str(), host_.size());
if (ret != 0) {
throw TTransportException(TTransportException::NOT_OPEN, "set TLS_HOSTNAME failed");
}
// Send timeout
if (sendTimeout_ > 0) {
setSendTimeout(sendTimeout_);
}
// Recv timeout
if (recvTimeout_ > 0) {
setRecvTimeout(recvTimeout_);
}
if (keepAlive_) {
setKeepAlive(keepAlive_);
}
// Linger
setLinger(lingerOn_, lingerVal_);
// No delay
setNoDelay(noDelay_);
#ifdef SO_NOSIGPIPE
{
int one = 1;
setsockopt(socket_, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
}
#endif
// Uses a low min RTO if asked to.
#ifdef TCP_LOW_MIN_RTO
if (getUseLowMinRto()) {
int one = 1;
setsockopt(socket_, IPPROTO_TCP, TCP_LOW_MIN_RTO, &one, sizeof(one));
}
#endif
// Set the socket to be non blocking for connect if a timeout exists
int flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0);
if (connTimeout_ > 0) {
if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK)) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TSocket::open() THRIFT_FCNTL() " + getSocketInfo(),
errno_copy);
throw TTransportException(TTransportException::NOT_OPEN,
"THRIFT_FCNTL() failed", errno_copy);
}
} else {
if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags & ~THRIFT_O_NONBLOCK)) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TSocket::open() THRIFT_FCNTL " + getSocketInfo(),
errno_copy);
throw TTransportException(TTransportException::NOT_OPEN,
"THRIFT_FCNTL() failed", errno_copy);
}
}
// Connect the socket
ret = connect(socket_, res->ai_addr, static_cast<int>(res->ai_addrlen));
// success case
if (ret == 0) {
goto done;
}
if ((THRIFT_GET_SOCKET_ERROR != THRIFT_EINPROGRESS) &&
(THRIFT_GET_SOCKET_ERROR != THRIFT_EWOULDBLOCK)) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TSocket::open() connect() " + getSocketInfo(), errno_copy);
throw TTransportException(TTransportException::NOT_OPEN, "connect() failed",
errno_copy);
}
struct THRIFT_POLLFD fds[1];
std::memset(fds, 0, sizeof(fds));
fds[0].fd = socket_;
fds[0].events = THRIFT_POLLOUT;
ret = THRIFT_POLL(fds, 1, connTimeout_);
if (ret > 0) {
// Ensure the socket is connected and that there are no errors set
int val;
socklen_t lon;
lon = sizeof(int);
int ret2 = getsockopt(socket_, SOL_SOCKET, SO_ERROR, cast_sockopt(&val), &lon);
if (ret2 == -1) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TSocket::open() getsockopt() " + getSocketInfo(),
errno_copy);
throw TTransportException(TTransportException::NOT_OPEN, "getsockopt()",
errno_copy);
}
// no errors on socket, go to town
if (val == 0) {
goto done;
}
GlobalOutput.perror("TSocket::open() error on socket (after THRIFT_POLL) " +
getSocketInfo(),
val);
throw TTransportException(TTransportException::NOT_OPEN, "socket open() error",
val);
} else if (ret == 0) {
// socket timed out
string errStr = "TSocket::open() timed out " + getSocketInfo();
GlobalOutput(errStr.c_str());
throw TTransportException(TTransportException::NOT_OPEN, "open() timed out");
} else {
// error on THRIFT_POLL()
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TSocket::open() THRIFT_POLL() " + getSocketInfo(), errno_copy);
throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_POLL() failed",
errno_copy);
}
done:
// Set socket back to normal mode (blocking)
if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags)) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TSocket::open() THRIFT_FCNTL " + getSocketInfo(), errno_copy);
throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_FCNTL() failed",
errno_copy);
}
setCachedAddress(res->ai_addr, static_cast<socklen_t>(res->ai_addrlen));
}
void TSSLSocket::init()
{
handshakeCompleted_ = false;
readRetryCount_ = 0;
eventSafe_ = false;
}
void TSSLSocket::open()
{
if (isOpen() || server()) {
throw TTransportException(TTransportException::BAD_ARGS);
}
// Validate port number
if (port_ < 0 || port_ > 0xFFFF) {
throw TTransportException(TTransportException::BAD_ARGS,
"Specified port is invalid");
}
struct addrinfo hints, *res, *res0;
res = nullptr;
res0 = nullptr;
int error;
char port[sizeof("65535")];
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = PF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
sprintf(port, "%d", port_);
error = getaddrinfo(host_.c_str(), port, &hints, &res0);
if (error == DNS_EAI_NODATA) {
hints.ai_flags &= ~AI_ADDRCONFIG;
error = getaddrinfo(host_.c_str(), port, &hints, &res0);
}
if (error) {
string errStr = "TSocket::open() getaddrinfo() " + getSocketInfo() +
string(THRIFT_GAI_STRERROR(error));
GlobalOutput(errStr.c_str());
close();
throw TTransportException(TTransportException::NOT_OPEN,
"Could not resolve host for client socket.");
}
// Cycle through all the returned addresses until one
// connects or push the exception up.
for (res = res0; res; res = res->ai_next) {
try {
openSecConnection(res);
break;
} catch (TTransportException &) {
if (res->ai_next) {
close();
} else {
close();
freeaddrinfo(res0); // cleanup on failure
throw;
}
}
}
// Free address structure memory
freeaddrinfo(res0);
}
TSSLSocketFactory::TSSLSocketFactory(SSLProtocol protocol)
: ctx_(std::make_shared<SSLContext>()), server_(false)
{
switch (protocol) {
case SSLTLS:
break;
case TLSv1_0:
break;
case TLSv1_1:
ctx_->protocol = IPPROTO_TLS_1_1;
break;
case TLSv1_2:
ctx_->protocol = IPPROTO_TLS_1_2;
break;
default:
throw TTransportException(TTransportException::BAD_ARGS,
"Specified protocol is invalid");
}
}
TSSLSocketFactory::~TSSLSocketFactory()
{
}
std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket()
{
std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_));
setup(ssl);
return ssl;
}
std::shared_ptr<TSSLSocket>
TSSLSocketFactory::createSocket(std::shared_ptr<THRIFT_SOCKET> interruptListener)
{
std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, interruptListener));
setup(ssl);
return ssl;
}
std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket)
{
std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
setup(ssl);
return ssl;
}
std::shared_ptr<TSSLSocket>
TSSLSocketFactory::createSocket(THRIFT_SOCKET socket,
std::shared_ptr<THRIFT_SOCKET> interruptListener)
{
std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket, interruptListener));
setup(ssl);
return ssl;
}
std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string &host, int port)
{
std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
setup(ssl);
return ssl;
}
std::shared_ptr<TSSLSocket>
TSSLSocketFactory::createSocket(const string &host, int port,
std::shared_ptr<THRIFT_SOCKET> interruptListener)
{
std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port, interruptListener));
setup(ssl);
return ssl;
}
static void tlsCredtErrMsg(string &errors, const int status);
void TSSLSocketFactory::setup(std::shared_ptr<TSSLSocket> ssl)
{
ssl->server(server());
if (access_ == nullptr && !server()) {
access_ = std::shared_ptr<AccessManager>(new DefaultClientAccessManager);
}
if (access_ != nullptr) {
ssl->access(access_);
}
}
void TSSLSocketFactory::ciphers(const string &enable)
{
}
void TSSLSocketFactory::authenticate(bool required)
{
if (required) {
ctx_->verifyMode = TLS_PEER_VERIFY_REQUIRED;
} else {
ctx_->verifyMode = TLS_PEER_VERIFY_NONE;
}
}
void TSSLSocketFactory::loadCertificate(const char *path, const char *format)
{
if (path == nullptr || format == nullptr) {
throw TTransportException(
TTransportException::BAD_ARGS,
"loadCertificateChain: either <path> or <format> is nullptr");
}
if (strcmp(format, "PEM") == 0) {
} else {
throw TSSLException("Unsupported certificate format: " + string(format));
}
}
void TSSLSocketFactory::loadCertificateFromBuffer(const char *aCertificate, const char *format)
{
if (aCertificate == nullptr || format == nullptr) {
throw TTransportException(TTransportException::BAD_ARGS,
"loadCertificate: either <path> or <format> is nullptr");
}
if (strcmp(format, "PEM") == 0) {
const int status = tls_credential_add(Thrift_TLS_SERVER_CERT_TAG,
TLS_CREDENTIAL_SERVER_CERTIFICATE,
aCertificate, strlen(aCertificate) + 1);
if (status != 0) {
string errors;
tlsCredtErrMsg(errors, status);
throw TSSLException("tls_credential_add: " + errors);
}
} else {
throw TSSLException("Unsupported certificate format: " + string(format));
}
}
void TSSLSocketFactory::loadPrivateKey(const char *path, const char *format)
{
if (path == nullptr || format == nullptr) {
throw TTransportException(TTransportException::BAD_ARGS,
"loadPrivateKey: either <path> or <format> is nullptr");
}
if (strcmp(format, "PEM") == 0) {
if (0) {
string errors;
// tlsCredtErrMsg(errors, status);
throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
}
}
}
void TSSLSocketFactory::loadPrivateKeyFromBuffer(const char *aPrivateKey, const char *format)
{
if (aPrivateKey == nullptr || format == nullptr) {
throw TTransportException(TTransportException::BAD_ARGS,
"loadPrivateKey: either <path> or <format> is nullptr");
}
if (strcmp(format, "PEM") == 0) {
const int status =
tls_credential_add(Thrift_TLS_PRIVATE_KEY, TLS_CREDENTIAL_PRIVATE_KEY,
aPrivateKey, strlen(aPrivateKey) + 1);
if (status != 0) {
string errors;
tlsCredtErrMsg(errors, status);
throw TSSLException("SSL_CTX_use_PrivateKey: " + errors);
}
} else {
throw TSSLException("Unsupported certificate format: " + string(format));
}
}
void TSSLSocketFactory::loadTrustedCertificates(const char *path, const char *capath)
{
if (path == nullptr) {
throw TTransportException(TTransportException::BAD_ARGS,
"loadTrustedCertificates: <path> is nullptr");
}
if (0) {
string errors;
// tlsCredtErrMsg(errors, status);
throw TSSLException("SSL_CTX_load_verify_locations: " + errors);
}
}
void TSSLSocketFactory::loadTrustedCertificatesFromBuffer(const char *aCertificate,
const char *aChain)
{
if (aCertificate == nullptr) {
throw TTransportException(TTransportException::BAD_ARGS,
"loadTrustedCertificates: aCertificate is empty");
}
const int status = tls_credential_add(Thrift_TLS_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE,
aCertificate, strlen(aCertificate) + 1);
if (status != 0) {
string errors;
tlsCredtErrMsg(errors, status);
throw TSSLException("X509_STORE_add_cert: " + errors);
}
if (aChain) {
}
}
void TSSLSocketFactory::randomize()
{
}
void TSSLSocketFactory::overrideDefaultPasswordCallback()
{
}
void TSSLSocketFactory::server(bool flag)
{
server_ = flag;
ctx_->verifyMode = TLS_PEER_VERIFY_NONE;
}
bool TSSLSocketFactory::server() const
{
return server_;
}
int TSSLSocketFactory::passwordCallback(char *password, int size, int, void *data)
{
auto *factory = (TSSLSocketFactory *)data;
string userPassword;
factory->getPassword(userPassword, size);
int length = static_cast<int>(userPassword.size());
if (length > size) {
length = size;
}
strncpy(password, userPassword.c_str(), length);
userPassword.assign(userPassword.size(), '*');
return length;
}
// extract error messages from error queue
static void tlsCredtErrMsg(string &errors, const int status)
{
if (status == EACCES) {
errors = "Access to the TLS credential subsystem was denied";
} else if (status == ENOMEM) {
errors = "Not enough memory to add new TLS credential";
} else if (status == EEXIST) {
errors = "TLS credential of specific tag and type already exists";
} else {
errors = "Unknown error";
}
}
/**
* Default implementation of AccessManager
*/
Decision DefaultClientAccessManager::verify(const sockaddr_storage &sa) noexcept
{
(void)sa;
return SKIP;
}
Decision DefaultClientAccessManager::verify(const string &host, const char *name, int size) noexcept
{
if (host.empty() || name == nullptr || size <= 0) {
return SKIP;
}
return (matchName(host.c_str(), name, size) ? ALLOW : SKIP);
}
Decision DefaultClientAccessManager::verify(const sockaddr_storage &sa, const char *data,
int size) noexcept
{
bool match = false;
if (sa.ss_family == AF_INET && size == sizeof(in_addr)) {
match = (memcmp(&((sockaddr_in *)&sa)->sin_addr, data, size) == 0);
} else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) {
match = (memcmp(&((sockaddr_in6 *)&sa)->sin6_addr, data, size) == 0);
}
return (match ? ALLOW : SKIP);
}
/**
* Match a name with a pattern. The pattern may include wildcard. A single
* wildcard "*" can match up to one component in the domain name.
*
* @param host Host name, typically the name of the remote host
* @param pattern Name retrieved from certificate
* @param size Size of "pattern"
* @return True, if "host" matches "pattern". False otherwise.
*/
bool matchName(const char *host, const char *pattern, int size)
{
bool match = false;
int i = 0, j = 0;
while (i < size && host[j] != '\0') {
if (uppercase(pattern[i]) == uppercase(host[j])) {
i++;
j++;
continue;
}
if (pattern[i] == '*') {
while (host[j] != '.' && host[j] != '\0') {
j++;
}
i++;
continue;
}
break;
}
if (i == size && host[j] == '\0') {
match = true;
}
return match;
}
// This is to work around the Turkish locale issue, i.e.,
// toupper('i') != toupper('I') if locale is "tr_TR"
char uppercase(char c)
{
if ('a' <= c && c <= 'z') {
return c + ('A' - 'a');
}
return c;
}
} // namespace transport
} // namespace thrift
} // namespace apache

View file

@ -0,0 +1,465 @@
/*
* Copyright 2006 Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
// Put this first to avoid WIN32 build failure
#include <thrift/transport/TSocket.h>
#include <string>
#include <thrift/concurrency/Mutex.h>
#include <zephyr/posix/sys/socket.h>
namespace apache
{
namespace thrift
{
namespace transport
{
class AccessManager;
class SSLContext;
enum SSLProtocol {
SSLTLS = 0, // Supports SSLv2 and SSLv3 handshake but only negotiates at TLSv1_0 or later.
// SSLv2 = 1, // HORRIBLY INSECURE!
SSLv3 = 2, // Supports SSLv3 only - also horribly insecure!
TLSv1_0 = 3, // Supports TLSv1_0 or later.
TLSv1_1 = 4, // Supports TLSv1_1 or later.
TLSv1_2 = 5, // Supports TLSv1_2 or later.
LATEST = TLSv1_2
};
#define TSSL_EINTR 0
#define TSSL_DATA 1
/**
* Initialize OpenSSL library. This function, or some other
* equivalent function to initialize OpenSSL, must be called before
* TSSLSocket is used. If you set TSSLSocketFactory to use manual
* OpenSSL initialization, you should call this function or otherwise
* ensure OpenSSL is initialized yourself.
*/
void initializeOpenSSL();
/**
* Cleanup OpenSSL library. This function should be called to clean
* up OpenSSL after use of OpenSSL functionality is finished. If you
* set TSSLSocketFactory to use manual OpenSSL initialization, you
* should call this function yourself or ensure that whatever
* initialized OpenSSL cleans it up too.
*/
void cleanupOpenSSL();
/**
* OpenSSL implementation for SSL socket interface.
*/
class TSSLSocket : public TSocket
{
public:
~TSSLSocket() override;
/**
* TTransport interface.
*/
void open() override;
/**
* Set whether to use client or server side SSL handshake protocol.
*
* @param flag Use server side handshake protocol if true.
*/
void server(bool flag)
{
server_ = flag;
}
/**
* Determine whether the SSL socket is server or client mode.
*/
bool server() const
{
return server_;
}
/**
* Set AccessManager.
*
* @param manager Instance of AccessManager
*/
virtual void access(std::shared_ptr<AccessManager> manager)
{
access_ = manager;
}
/**
* Set eventSafe flag if libevent is used.
*/
void setLibeventSafe()
{
eventSafe_ = true;
}
/**
* Determines whether SSL Socket is libevent safe or not.
*/
bool isLibeventSafe() const
{
return eventSafe_;
}
void authenticate(bool required);
protected:
/**
* Constructor.
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor with an interrupt signal.
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx,
std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor, create an instance of TSSLSocket given an existing socket.
*
* @param socket An existing socket
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor, create an instance of TSSLSocket given an existing socket that can be
* interrupted.
*
* @param socket An existing socket
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket,
std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor.
*
* @param host Remote host name
* @param port Remote port number
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor with an interrupt signal.
*
* @param host Remote host name
* @param port Remote port number
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port,
std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Authorize peer access after SSL handshake completes.
*/
virtual void authorize();
/**
* Initiate SSL handshake if not already initiated.
*/
void initializeHandshake();
/**
* Initiate SSL handshake params.
*/
void initializeHandshakeParams();
/**
* Check if SSL handshake is completed or not.
*/
bool checkHandshake();
/**
* Waits for an socket or shutdown event.
*
* @throw TTransportException::INTERRUPTED if interrupted is signaled.
*
* @return TSSL_EINTR if EINTR happened on the underlying socket
* TSSL_DATA if data is available on the socket.
*/
unsigned int waitForEvent(bool wantRead);
void openSecConnection(struct addrinfo *res);
bool server_;
std::shared_ptr<SSLContext> ctx_;
std::shared_ptr<AccessManager> access_;
friend class TSSLSocketFactory;
private:
bool handshakeCompleted_;
int readRetryCount_;
bool eventSafe_;
void init();
};
/**
* SSL socket factory. SSL sockets should be created via SSL factory.
* The factory will automatically initialize and cleanup openssl as long as
* there is a TSSLSocketFactory instantiated, and as long as the static
* boolean manualOpenSSLInitialization_ is set to false, the default.
*
* If you would like to initialize and cleanup openssl yourself, set
* manualOpenSSLInitialization_ to true and TSSLSocketFactory will no
* longer be responsible for openssl initialization and teardown.
*
* It is the responsibility of the code using TSSLSocketFactory to
* ensure that the factory lifetime exceeds the lifetime of any sockets
* it might create. If this is not guaranteed, a socket may call into
* openssl after the socket factory has cleaned up openssl! This
* guarantee is unnecessary if manualOpenSSLInitialization_ is true,
* however, since it would be up to the consuming application instead.
*/
class TSSLSocketFactory
{
public:
/**
* Constructor/Destructor
*
* @param protocol The SSL/TLS protocol to use.
*/
TSSLSocketFactory(SSLProtocol protocol = SSLTLS);
virtual ~TSSLSocketFactory();
/**
* Create an instance of TSSLSocket with a fresh new socket.
*/
virtual std::shared_ptr<TSSLSocket> createSocket();
/**
* Create an instance of TSSLSocket with a fresh new socket, which is interruptable.
*/
virtual std::shared_ptr<TSSLSocket>
createSocket(std::shared_ptr<THRIFT_SOCKET> interruptListener);
/**
* Create an instance of TSSLSocket with the given socket.
*
* @param socket An existing socket.
*/
virtual std::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket);
/**
* Create an instance of TSSLSocket with the given socket which is interruptable.
*
* @param socket An existing socket.
*/
virtual std::shared_ptr<TSSLSocket>
createSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
/**
* Create an instance of TSSLSocket.
*
* @param host Remote host to be connected to
* @param port Remote port to be connected to
*/
virtual std::shared_ptr<TSSLSocket> createSocket(const std::string &host, int port);
/**
* Create an instance of TSSLSocket.
*
* @param host Remote host to be connected to
* @param port Remote port to be connected to
*/
virtual std::shared_ptr<TSSLSocket>
createSocket(const std::string &host, int port,
std::shared_ptr<THRIFT_SOCKET> interruptListener);
/**
* Set ciphers to be used in SSL handshake process.
*
* @param ciphers A list of ciphers
*/
virtual void ciphers(const std::string &enable);
/**
* Enable/Disable authentication.
*
* @param required Require peer to present valid certificate if true
*/
virtual void authenticate(bool required);
/**
* Load server certificate.
*
* @param path Path to the certificate file
* @param format Certificate file format
*/
virtual void loadCertificate(const char *path, const char *format = "PEM");
virtual void loadCertificateFromBuffer(const char *aCertificate,
const char *format = "PEM");
/**
* Load private key.
*
* @param path Path to the private key file
* @param format Private key file format
*/
virtual void loadPrivateKey(const char *path, const char *format = "PEM");
virtual void loadPrivateKeyFromBuffer(const char *aPrivateKey, const char *format = "PEM");
/**
* Load trusted certificates from specified file.
*
* @param path Path to trusted certificate file
*/
virtual void loadTrustedCertificates(const char *path, const char *capath = nullptr);
virtual void loadTrustedCertificatesFromBuffer(const char *aCertificate,
const char *aChain = nullptr);
/**
* Default randomize method.
*/
virtual void randomize();
/**
* Override default OpenSSL password callback with getPassword().
*/
void overrideDefaultPasswordCallback();
/**
* Set/Unset server mode.
*
* @param flag Server mode if true
*/
virtual void server(bool flag);
/**
* Determine whether the socket is in server or client mode.
*
* @return true, if server mode, or, false, if client mode
*/
virtual bool server() const;
/**
* Set AccessManager.
*
* @param manager The AccessManager instance
*/
virtual void access(std::shared_ptr<AccessManager> manager)
{
access_ = manager;
}
static void setManualOpenSSLInitialization(bool manualOpenSSLInitialization)
{
manualOpenSSLInitialization_ = manualOpenSSLInitialization;
}
protected:
std::shared_ptr<SSLContext> ctx_;
/**
* Override this method for custom password callback. It may be called
* multiple times at any time during a session as necessary.
*
* @param password Pass collected password to OpenSSL
* @param size Maximum length of password including NULL character
*/
virtual void getPassword(std::string & /* password */, int /* size */)
{
}
private:
bool server_;
std::shared_ptr<AccessManager> access_;
static concurrency::Mutex mutex_;
static uint64_t count_;
THRIFT_EXPORT static bool manualOpenSSLInitialization_;
void setup(std::shared_ptr<TSSLSocket> ssl);
static int passwordCallback(char *password, int size, int, void *data);
};
/**
* SSL exception.
*/
class TSSLException : public TTransportException
{
public:
TSSLException(const std::string &message)
: TTransportException(TTransportException::INTERNAL_ERROR, message)
{
}
const char *what() const noexcept override
{
if (message_.empty()) {
return "TSSLException";
} else {
return message_.c_str();
}
}
};
struct SSLContext {
int verifyMode = TLS_PEER_VERIFY_REQUIRED;
net_ip_protocol_secure protocol = IPPROTO_TLS_1_0;
};
/**
* Callback interface for access control. It's meant to verify the remote host.
* It's constructed when application starts and set to TSSLSocketFactory
* instance. It's passed onto all TSSLSocket instances created by this factory
* object.
*/
class AccessManager
{
public:
enum Decision {
DENY = -1, // deny access
SKIP = 0, // cannot make decision, move on to next (if any)
ALLOW = 1 // allow access
};
/**
* Destructor
*/
virtual ~AccessManager() = default;
/**
* Determine whether the peer should be granted access or not. It's called
* once after the SSL handshake completes successfully, before peer certificate
* is examined.
*
* If a valid decision (ALLOW or DENY) is returned, the peer certificate is
* not to be verified.
*
* @param sa Peer IP address
* @return True if the peer is trusted, false otherwise
*/
virtual Decision verify(const sockaddr_storage & /* sa */) noexcept
{
return DENY;
}
/**
* Determine whether the peer should be granted access or not. It's called
* every time a DNS subjectAltName/common name is extracted from peer's
* certificate.
*
* @param host Client mode: host name returned by TSocket::getHost()
* Server mode: host name returned by TSocket::getPeerHost()
* @param name SubjectAltName or common name extracted from peer certificate
* @param size Length of name
* @return True if the peer is trusted, false otherwise
*
* Note: The "name" parameter may be UTF8 encoded.
*/
virtual Decision verify(const std::string & /* host */, const char * /* name */,
int /* size */) noexcept
{
return DENY;
}
/**
* Determine whether the peer should be granted access or not. It's called
* every time an IP subjectAltName is extracted from peer's certificate.
*
* @param sa Peer IP address retrieved from the underlying socket
* @param data IP address extracted from certificate
* @param size Length of the IP address
* @return True if the peer is trusted, false otherwise
*/
virtual Decision verify(const sockaddr_storage & /* sa */, const char * /* data */,
int /* size */) noexcept
{
return DENY;
}
};
typedef AccessManager::Decision Decision;
class DefaultClientAccessManager : public AccessManager
{
public:
// AccessManager interface
Decision verify(const sockaddr_storage &sa) noexcept override;
Decision verify(const std::string &host, const char *name, int size) noexcept override;
Decision verify(const sockaddr_storage &sa, const char *data, int size) noexcept override;
};
} // namespace transport
} // namespace thrift
} // namespace apache
#endif

View file

@ -0,0 +1,186 @@
/*
* Copyright 2006 Facebook
*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_
#define _THRIFT_TRANSPORT_TSERVERSOCKET_H_ 1
#include <functional>
#include <thrift/concurrency/Mutex.h>
#include <thrift/transport/PlatformSocket.h>
#include <thrift/transport/TServerTransport.h>
#include <sys/types.h>
#ifdef HAVE_SYS_SOCKET_H
#include <zephyr/posix/sys/socket.h>
#endif
#ifdef HAVE_NETDB_H
#include <zephyr/posix/netdb.h>
#endif
namespace apache
{
namespace thrift
{
namespace transport
{
class TSocket;
/**
* Server socket implementation of TServerTransport. Wrapper around a unix
* socket listen and accept calls.
*
*/
class TServerSocket : public TServerTransport
{
public:
typedef std::function<void(THRIFT_SOCKET fd)> socket_func_t;
const static int DEFAULT_BACKLOG = 1024;
/**
* Constructor.
*
* @param port Port number to bind to
*/
TServerSocket(int port);
/**
* Constructor.
*
* @param port Port number to bind to
* @param sendTimeout Socket send timeout
* @param recvTimeout Socket receive timeout
*/
TServerSocket(int port, int sendTimeout, int recvTimeout);
/**
* Constructor.
*
* @param address Address to bind to
* @param port Port number to bind to
*/
TServerSocket(const std::string &address, int port);
/**
* Constructor used for unix sockets.
*
* @param path Pathname for unix socket.
*/
TServerSocket(const std::string &path);
~TServerSocket() override;
bool isOpen() const override;
void setSendTimeout(int sendTimeout);
void setRecvTimeout(int recvTimeout);
void setAcceptTimeout(int accTimeout);
void setAcceptBacklog(int accBacklog);
void setRetryLimit(int retryLimit);
void setRetryDelay(int retryDelay);
void setKeepAlive(bool keepAlive)
{
keepAlive_ = keepAlive;
}
void setTcpSendBuffer(int tcpSendBuffer);
void setTcpRecvBuffer(int tcpRecvBuffer);
// listenCallback gets called just before listen, and after all Thrift
// setsockopt calls have been made. If you have custom setsockopt
// things that need to happen on the listening socket, this is the place to do it.
void setListenCallback(const socket_func_t &listenCallback)
{
listenCallback_ = listenCallback;
}
// acceptCallback gets called after each accept call, on the newly created socket.
// It is called after all Thrift setsockopt calls have been made. If you have
// custom setsockopt things that need to happen on the accepted
// socket, this is the place to do it.
void setAcceptCallback(const socket_func_t &acceptCallback)
{
acceptCallback_ = acceptCallback;
}
// When enabled (the default), new children TSockets will be constructed so
// they can be interrupted by TServerTransport::interruptChildren().
// This is more expensive in terms of system calls (poll + recv) however
// ensures a connected client cannot interfere with TServer::stop().
//
// When disabled, TSocket children do not incur an additional poll() call.
// Server-side reads are more efficient, however a client can interfere with
// the server's ability to shutdown properly by staying connected.
//
// Must be called before listen(); mode cannot be switched after that.
// \throws std::logic_error if listen() has been called
void setInterruptableChildren(bool enable);
THRIFT_SOCKET getSocketFD() override
{
return serverSocket_;
}
int getPort() const;
std::string getPath() const;
bool isUnixDomainSocket() const;
void listen() override;
void interrupt() override;
void interruptChildren() override;
void close() override;
protected:
std::shared_ptr<TTransport> acceptImpl() override;
virtual std::shared_ptr<TSocket> createSocket(THRIFT_SOCKET client);
bool interruptableChildren_;
std::shared_ptr<THRIFT_SOCKET> pChildInterruptSockReader_; // if interruptableChildren_ this
// is shared with child TSockets
void _setup_sockopts();
void _setup_tcp_sockopts();
private:
void notify(THRIFT_SOCKET notifySock);
void _setup_unixdomain_sockopts();
protected:
int port_;
std::string address_;
std::string path_;
THRIFT_SOCKET serverSocket_;
int acceptBacklog_;
int sendTimeout_;
int recvTimeout_;
int accTimeout_;
int retryLimit_;
int retryDelay_;
int tcpSendBuffer_;
int tcpRecvBuffer_;
bool keepAlive_;
bool listening_;
concurrency::Mutex rwMutex_; // thread-safe interrupt
THRIFT_SOCKET interruptSockWriter_; // is notified on interrupt()
THRIFT_SOCKET
interruptSockReader_; // is used in select/poll with serverSocket_ for interruptability
THRIFT_SOCKET childInterruptSockWriter_; // is notified on interruptChildren()
socket_func_t listenCallback_;
socket_func_t acceptCallback_;
};
} // namespace transport
} // namespace thrift
} // namespace apache
#endif // #ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_

View file

@ -0,0 +1,21 @@
/*
* Copyright 2022 Young Mei
*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef ZEPHYR_MODULES_THRIFT_SRC_THRIFT_TRANSPORT_THRIFTTLSCERTIFICATETYPE_H_
#define ZEPHYR_MODULES_THRIFT_SRC_THRIFT_TRANSPORT_THRIFTTLSCERTIFICATETYPE_H_
namespace apache::thrift::transport
{
enum ThriftTLScertificateType {
Thrift_TLS_CA_CERT_TAG,
Thrift_TLS_SERVER_CERT_TAG,
Thrift_TLS_PRIVATE_KEY,
};
} // namespace apache::thrift::transport
#endif /* ZEPHYR_MODULES_THRIFT_SRC_THRIFT_TRANSPORT_THRIFTTLSCERTIFICATETYPE_H_ */