C++ Logo

std-proposals

Advanced search

Re: [std-proposals] Monitor recursion by checking the frame pointer

From: Frederick Virchanza Gotham <cauldwell.thomas_at_[hidden]>
Date: Tue, 18 Apr 2023 10:21:58 +0100
In the other thread entitled "Function Pointer from Lambda with
Captures", we shared code for a few different ways of creating thunks
for lambdas. The two implementations were:
(a) Write the thunk as machine code onto the stack, and execute it from the stack
(b) Have a global pool of instantiations of a template function, and
manage a global array of pointers associated 1:1 with the
instantiations of the template function

The drawback of A is that the stack must be executable.

The drawback of B is that if you want to ensure 36 levels of
recursion, the executable file will contain 36 thunks.

Today I've come up with a third solution, which only has one thunk in
the final executable file, and which can handle as many threads and as
much recursion as your heap allows.

What I've done is keep a map from frame pointers to lambda objects
(one thread-local map per lambda type). I got it working:

    https://godbolt.org/z/1WzGq7cj9

And here it is copy-pasted:

#include <cassert> // assert
#include <cstddef> // size_t
#include <cstdlib> // abort (if mutex fails to lock)
#include <functional> // greater
#include <iterator> // prev, next
#include <map> // map
#include <mutex> // mutex, lock_guard
#include <utility> // index_sequence, make_index_sequence, forward, declval

// Prototype of the proposed 'std::frame_pointer()': returns the frame pointer
// in effect at the point of the call. NOTE(review): user code may not add
// declarations to namespace 'std' ([namespace.std]); it is done here only to
// prototype the proposed library function.
namespace std {
    extern void *frame_pointer(void);
}

// The next 5 lines are an x86_64 assembler implementation of std::frame_pointer.
// The routine has no prologue (it never pushes or sets up %rbp itself), so it
// returns in %rax whatever %rbp holds at the call site, i.e. the frame base of
// the calling function. That is only meaningful when the caller maintains a
// frame pointer -- NOTE(review): presumably the build uses
// -fno-omit-frame-pointer; confirm. The label is the unprefixed (ELF-style)
// Itanium-mangled name and no '.globl' is emitted, so it is presumably local
// to the object file it is assembled into -- TODO confirm on other targets.
__asm__(
"_ZSt13frame_pointerv:\n" // mangled name of std::frame_pointer
" mov %rbp, %rax\n"
" ret\n"
);

// The stack grows negatively on the x86_64, so we use the following
// function to compare two frame pointers.
//
// Returns true if 'p' is further from the top of the stack than 'q'
// (equivalently: 'q' is closer to the top of the stack).
//
// std::greater is used instead of the built-in '>' because a relational
// comparison between pointers into unrelated objects (such as two different
// stack frames) yields an unspecified result, whereas std::greater on pointer
// types is guaranteed to follow the implementation-defined strict total order
// over pointers -- the same order std::map<void*, ...> uses via std::less.
bool IsFurtherFromTop(void const *const p, void const *const q)
{
    return std::greater<void const *>{}(p, q);
}

// The next three function templates: 'ToFuncPtr', 'ToReturnType', 'IsNoExcept'
// (each overloaded for const and non-const member functions) are just helpers
// to make this all possible. You can scroll past them to the 'thunk' class.

namespace detail {

// The following template turns a member function pointer into a normal
// function pointer with the same return type, parameters and 'noexcept'
// specification. It is only ever used inside 'decltype', so its body is
// never evaluated.
template<typename ReturnType, class ClassType, bool b_noexcept,
         typename... Params>
ReturnType (*ToFuncPtr(ReturnType (ClassType::*)(Params...) const
                       noexcept(b_noexcept)))(Params...) noexcept(b_noexcept)
{
    return nullptr; // suppress compiler warning
}
// ... and also the non-const version for non-const member functions (or
// mutable lambdas):
template<typename ReturnType, class ClassType, bool b_noexcept,
         typename... Params>
ReturnType (*ToFuncPtr(ReturnType (ClassType::*)(Params...)
                       noexcept(b_noexcept)))(Params...) noexcept(b_noexcept)
{
    return nullptr; // suppress compiler warning
}

// The following template isolates the return type of a member function
// pointer; it too is only used inside 'decltype' (std::declval may not be
// evaluated, so the body must never be instantiated). I tried using
// 'std::result_of_t' instead but got a compiler error for the lambda being
// an incomplete type.
template <typename ReturnType, class ClassType, typename... Params>
ReturnType ToReturnType(ReturnType (ClassType::*)(Params...) const)
{
    return std::declval<ReturnType>(); // suppress compiler warning
}
// ... and also the non-const version for non-const member functions (or
// mutable lambdas):
template <typename ReturnType, class ClassType, typename... Params>
ReturnType ToReturnType(ReturnType (ClassType::*)(Params...))
{
    return std::declval<ReturnType>(); // suppress compiler warning
}

// The following template determines whether a non-static member function
// is noexcept. 'constexpr' is sufficient: it is only called inside a
// 'noexcept(...)' specifier, which is evaluated at compile time anyway, and
// it keeps these helpers usable in C++17.
template<typename ReturnType, class ClassType, bool b_noexcept,
         typename... Params>
constexpr bool IsNoExcept(ReturnType (ClassType::*)(Params...) const
                          noexcept(b_noexcept))
{
    return b_noexcept;
}
// ... and also the non-const version for non-const member functions (or
// mutable lambdas):
template<typename ReturnType, class ClassType, bool b_noexcept,
         typename... Params>
constexpr bool IsNoExcept(ReturnType (ClassType::*)(Params...)
                          noexcept(b_noexcept))
{
    return b_noexcept;
}

} // close namespace 'detail'

template<typename LambdaType>
class thunk {
protected:
    using size_t = std::size_t;
    using R = decltype(detail::ToReturnType(&LambdaType::operator()));
    using FuncPtr = decltype(detail::ToFuncPtr
(&LambdaType::operator())); // preserves the 'noexcept'

    // In the map on the next line, we keep track of which frame pointers
    // are associated with which lambda objects
    // map.begin()->first == frame pointer
    // map.begin()->second == address of lambda object
    inline static thread_local std::map<void*,LambdaType*>
frame_pointers_for_lambdas;

public:
    explicit thunk(LambdaType &arg) noexcept {
        try {
            assert( nullptr == frame_pointers_for_lambdas[
std::frame_pointer() ] );
            frame_pointers_for_lambdas[ std::frame_pointer() ] = &arg;
        }
        catch(...) {
            assert(nullptr == "failed to add frame pointer to map");
            std::abort(); // ifdef NDEBUG
        }
    }
    ~thunk() noexcept {
        assert( nullptr != frame_pointers_for_lambdas[ std::frame_pointer() ] );
        frame_pointers_for_lambdas.erase(std::frame_pointer());
    }
    FuncPtr get() const noexcept {
        return &invoke;
    }
    operator FuncPtr() const noexcept { return this->get(); }

protected:

    template<typename... A> static R invoke(A... a)
noexcept(detail::IsNoExcept(&LambdaType::operator()))
    {
        assert( frame_pointers_for_lambdas.size() >= 1u );

        void *const fp = std::frame_pointer();

        assert( false == IsFurtherFromTop(fp,
frame_pointers_for_lambdas.begin()->first) );

        size_t i;
        for ( i = 1u; i < frame_pointers_for_lambdas.size(); ++i )
        {
            if ( IsFurtherFromTop(fp,
std::next(frame_pointers_for_lambdas.begin(),i)->first) )
            {
                break;
            }
        }

        LambdaType &mylambda =
*std::next(frame_pointers_for_lambdas.begin(),i - 1u)->second;
        return mylambda(std::forward<A>(a)...);
    }

    thunk(void) = delete;
    thunk(thunk const & ) = delete;
    thunk(thunk &&) = delete;
    thunk &operator=(thunk const & ) = delete;
    thunk &operator=(thunk &&) = delete;
    thunk const *operator&(void) const = delete; // just to avoid confusion
    thunk *operator&(void) = delete; // just to avoid confusion
};

// ========================================================
// And now the test code.............
// ========================================================

#include <cstring> // strcpy, strcat
#include <cstdio> // sprintf
#include <thread> // jthread
#include <stop_token> // stop_token
#include <iostream> // cout, endl

using std::cout, std::endl;

// Test stand-in for a library API that only accepts a plain function pointer
// (with no user-data argument): it calls 'arg' with 'val' and hands back
// whatever the callback returned.
bool SomeLibraryFunc( bool (*arg)(int), int val )
{
    bool const result = (*arg)(val);
    return result;
}

void RecursiveFunc(int arg)
{
    if ( arg < 0 ) return;

    thread_local char str_buf[128u]; // so we don't need a mutex for cout

    auto f = [arg](int const i) -> bool
      {
        std::sprintf(str_buf, "Hello %i thunk\n",i);
        cout << str_buf;
        return i < 2;
      };

    bool const retval = SomeLibraryFunc( thunk(f), arg );

    std::strcpy(str_buf, "Returned: ");
    std::strcat(str_buf, retval ? "true\n" : "false\n");
    cout << str_buf;

    RecursiveFunc(--arg);
}

// Thread body for std::jthread: the leading stop_token is supplied by
// std::jthread and ignored here; the thread just runs the recursive demo.
void ThreadEntryPoint(std::stop_token /* ignored */, int const arg)
{
    return RecursiveFunc(arg);
}

void Spawn_All_Thread_And_Join(int arg)
{
    std::jthread mythreads[4u];

    for ( auto &t : mythreads )
    {
        t = std::jthread(ThreadEntryPoint,arg += 7);
    }
}

int main(int argc, char **argv)
{
    while ( argc++ < 3 ) Spawn_All_Thread_And_Join(argc);
}

Received on 2023-04-18 09:22:10