Program Listing for File MessageSpatial2DDevice.cuh

Return to documentation for file (include/flamegpu/runtime/messaging/MessageSpatial2D/MessageSpatial2DDevice.cuh)

#ifndef INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGESPATIAL2D_MESSAGESPATIAL2DDEVICE_CUH_
#define INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGESPATIAL2D_MESSAGESPATIAL2DDEVICE_CUH_

#include "flamegpu/runtime/messaging/MessageSpatial2D.h"
#include "flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceDevice.cuh"

namespace flamegpu {


class MessageSpatial2D::In {
 public:
    /**
     * Iterates the messages stored in the 3x3 block of bins centred on a search location.
     * Bins outside the grid are skipped (no wrapping).
     * Each y strip of (up to) 3 adjacent bins is treated as one contiguous range of message indices.
     */
    class Filter {
     public:
        /**
         * The message currently referenced by a Filter::iterator.
         * It also carries the iteration state: the current strip and the message index range inside it.
         */
        class Message {
            /**
             * The Filter being iterated
             */
            const Filter &_parent;
            /**
             * y offset of the current strip relative to the search cell; strips -1, 0 and 1 are visited
             * A value of 2 or greater marks the end of iteration
             */
            int relative_cell = { -2 };
            /**
             * Exclusive upper bound of the message indices in the current strip
             */
            int cell_index_max = 0;
            /**
             * Index of the current message
             */
            int cell_index = 0;
            /**
             * Advance to the next strip (the index range is updated by operator++)
             */
            __device__ void nextStrip() {
                relative_cell++;
            }

         public:
            /**
             * Construct with explicit iteration state
             * @param parent Filter being iterated
             * @param relative_cell_y Initial strip offset
             * @param _cell_index_max Initial exclusive upper message index
             * @param _cell_index Initial message index
             */
            __device__ Message(const Filter &parent, const int relative_cell_y, const int _cell_index_max, const int _cell_index)
                : _parent(parent)
                , cell_index_max(_cell_index_max)
                , cell_index(_cell_index) {
                relative_cell = relative_cell_y;
            }
            /**
             * Construct with default iteration state (used by end())
             */
            __device__ Message(const Filter &parent)
                : _parent(parent) { }
            /**
             * Compares iteration state
             */
            __device__ bool operator==(const Message& rhs) const {
                return this->relative_cell == rhs.relative_cell
                    && this->cell_index_max == rhs.cell_index_max
                    && this->cell_index == rhs.cell_index;
            }
            /**
             * Only tests whether this message has run past the final strip, rhs is ignored
             */
            __device__ bool operator!=(const Message& rhs) const {
                // The incoming Message& is end(), so we don't care about that
                // We only care that the host object has reached end
                // When the strip number equals 2, it has exceeded the [-1, 1] range
                return !(this->relative_cell >= 2);
            }
            /**
             * Advance to the next message, moving through strips as each is exhausted
             */
            __device__ Message& operator++();
            /**
             * Returns the named variable of the current message
             * @param variable_name Name of the message variable
             */
            template<typename T, unsigned int N>
            __device__ T getVariable(const char(&variable_name)[N]) const;
            /**
             * Returns an element of the named array variable of the current message
             * @param variable_name Name of the message array variable
             * @param index Element index within the array
             */
            template<typename T, flamegpu::size_type N, unsigned int M> __device__
            T getVariable(const char(&variable_name)[M], unsigned int index) const;
        };
        /**
         * Forward iterator over the messages returned by a Filter
         */
        class iterator {
            Message _message;

         public:
            /**
             * Construct with explicit state, then advance to the first message
             */
            __device__ iterator(const Filter &parent, const int relative_cell_y, const int _cell_index_max, const int _cell_index)
                : _message(parent, relative_cell_y, _cell_index_max, _cell_index) {
                // Increment to find first message
                ++_message;
            }
            /**
             * Construct the end() iterator
             */
            __device__ iterator(const Filter &parent)
                : _message(parent) { }
            __device__ iterator& operator++() { ++_message;  return *this; }
            __device__ iterator operator++(int) {
                iterator temp = *this;
                ++*this;
                return temp;
            }
            __device__ bool operator==(const iterator& rhs) const { return  _message == rhs._message; }
            __device__ bool operator!=(const iterator& rhs) const { return  _message != rhs._message; }
            __device__ Message& operator*() { return _message; }
            __device__ Message* operator->() { return &_message; }
        };
        /**
         * Construct a filter around the location (x, y)
         * @param _metadata Spatial messaging metadata
         * @param x Search location x
         * @param y Search location y
         */
        __device__ Filter(const MetaData *_metadata, float x, float y);
        inline __device__ iterator begin(void) const {
            // Bin before initial bin, as the constructor calls increment operator
            return iterator(*this, -2, 1, 0);
        }
        inline __device__ iterator end(void) const {
            // Empty init, because this object is never used
            // iterator equality doesn't actually check the end object
            return iterator(*this);
        }

     private:
        /**
         * Search location
         */
        float loc[2];
        /**
         * Grid cell containing the search location
         */
        GridPos2D cell;
        /**
         * Spatial messaging metadata
         */
        const MetaData *metadata;
    };
    /**
     * Iterates the messages stored in the 3x3 block of bins centred on a search location,
     * wrapping bins across the grid boundaries (periodic environment).
     */
    class WrapFilter {
     public:
        /**
         * The message currently referenced by a WrapFilter::iterator.
         * It also carries the iteration state: the current bin offset and the message index range inside it.
         */
        class Message {
            /**
             * The WrapFilter being iterated
             */
            const WrapFilter& _parent;
            /**
             * (x, y) offset of the current bin relative to the search cell, each axis covers [-1, 1]
             * An x offset of 2 or greater marks the end of iteration
             */
            int relative_cell[2] = { -2, 1 };
            /**
             * Exclusive upper bound of the message indices in the current bin
             */
            int cell_index_max = 0;
            /**
             * Index of the current message
             */
            int cell_index = 0;
            /**
             * Advance to the next bin offset, y varies fastest
             */
            __device__ void nextCell() {
                if (relative_cell[1] >= 1) {
                    relative_cell[1] = -1;
                    ++relative_cell[0];
                } else {
                    ++relative_cell[1];
                }
            }

         public:
            /**
             * Construct with explicit iteration state
             * @param parent WrapFilter being iterated
             * @param relative_cell_x Initial x bin offset
             * @param relative_cell_y Initial y bin offset
             * @param _cell_index_max Initial exclusive upper message index
             * @param _cell_index Initial message index
             */
            __device__ Message(const WrapFilter& parent, const int relative_cell_x, const int relative_cell_y, const int _cell_index_max, const int _cell_index)
                : _parent(parent)
                , cell_index_max(_cell_index_max)
                , cell_index(_cell_index) {
                relative_cell[0] = relative_cell_x;
                relative_cell[1] = relative_cell_y;
            }
            /**
             * Construct with default iteration state (used by end())
             */
            __device__ Message(const WrapFilter& parent)
                : _parent(parent) { }
            /**
             * Compares iteration state
             */
            __device__ bool operator==(const Message& rhs) const {
                return this->relative_cell[0] == rhs.relative_cell[0]
                    && this->relative_cell[1] == rhs.relative_cell[1]
                    && this->cell_index_max == rhs.cell_index_max
                    && this->cell_index == rhs.cell_index;
            }
            /**
             * Only tests whether this message has run past the final bin, rhs is ignored
             */
            __device__ bool operator!=(const Message& rhs) const {
                // The incoming Message& is end(), so we don't care about that
                // We only care that the host object has reached end
                // When the x offset equals 2, it has exceeded the [-1, 1] range
                return !(this->relative_cell[0] >= 2);
            }
            /**
             * Advance to the next message, moving through bins as each is exhausted
             */
            __device__ Message& operator++();
            /**
             * Returns the named variable of the current message
             * @param variable_name Name of the message variable
             */
            template<typename T, unsigned int N>
            __device__ T getVariable(const char(&variable_name)[N]) const;
            /**
             * Returns an element of the named array variable of the current message
             * @param variable_name Name of the message array variable
             * @param index Element index within the array
             */
            template<typename T, flamegpu::size_type N, unsigned int M>
            __device__ T getVariable(const char(&variable_name)[M], unsigned int index) const;
            /**
             * Returns the message's x coordinate, shifted to its periodic image nearest the search location
             */
            __device__ float getVirtualX() const {
                return getVirtualX(_parent.loc[0]);
            }
            /**
             * Returns the message's y coordinate, shifted to its periodic image nearest the search location
             */
            __device__ float getVirtualY() const {
                return getVirtualY(_parent.loc[1]);
            }
            /**
             * Returns the message's x coordinate, shifted by one environment width towards x1
             * when it lies more than half an environment width from x1
             * @param x1 Reference x coordinate
             */
            __device__ float getVirtualX(const float x1) const {
                const float x2 = getVariable<float>("x");
                const float x21 = x2 - x1;
                // fabsf (rather than unqualified abs) guarantees the float overload is used
                return fabsf(x21) > _parent.metadata->environmentWidth[0] / 2.0f ? x2 - (x21 / fabsf(x21) * _parent.metadata->environmentWidth[0]) : x2;
            }
            /**
             * Returns the message's y coordinate, shifted by one environment height towards y1
             * when it lies more than half an environment height from y1
             * @param y1 Reference y coordinate
             */
            __device__ float getVirtualY(const float y1) const {
                const float y2 = getVariable<float>("y");
                const float y21 = y2 - y1;
                // fabsf (rather than unqualified abs) guarantees the float overload is used
                return fabsf(y21) > _parent.metadata->environmentWidth[1] / 2.0f ? y2 - (y21 / fabsf(y21) * _parent.metadata->environmentWidth[1]) : y2;
            }
        };
        /**
         * Forward iterator over the messages returned by a WrapFilter
         */
        class iterator {
            Message _message;

         public:
            /**
             * Construct with explicit state, then advance to the first message
             */
            __device__ iterator(const WrapFilter& parent, const int relative_cell_x, const int relative_cell_y, const int _cell_index_max, const int _cell_index)
                : _message(parent, relative_cell_x, relative_cell_y, _cell_index_max, _cell_index) {
                // Increment to find first message
                ++_message;
            }
            /**
             * Construct the end() iterator
             */
            __device__ iterator(const WrapFilter& parent)
                : _message(parent) { }
            __device__ iterator& operator++() { ++_message;  return *this; }
            __device__ iterator operator++(int) {
                iterator temp = *this;
                ++*this;
                return temp;
            }
            __device__ bool operator==(const iterator& rhs) const { return  _message == rhs._message; }
            __device__ bool operator!=(const iterator& rhs) const { return  _message != rhs._message; }
            __device__ Message& operator*() { return _message; }
            __device__ Message* operator->() { return &_message; }
        };
        /**
         * Construct a wrapped filter around the location (x, y)
         * @param _metadata Spatial messaging metadata
         * @param x Search location x
         * @param y Search location y
         */
        __device__ WrapFilter(const MetaData* _metadata, float x, float y);
        inline __device__ iterator begin(void) const {
            // Bin before initial bin, as the constructor calls increment operator
            return iterator(*this, -2, 1, 1, 0);
        }
        inline __device__ iterator end(void) const {
            // Empty init, because this object is never used
            // iterator equality doesn't actually check the end object
            return iterator(*this);
        }

     private:
        /**
         * Search location
         */
        float loc[2];
        /**
         * Grid cell containing the search location
         */
        GridPos2D cell;
        /**
         * Spatial messaging metadata
         */
        const MetaData* metadata;
    };
    /**
     * Construct the device-side message input interface
     * @param _metadata Pointer to this message list's MetaData
     */
    __device__ In(const void *_metadata)
        : metadata(static_cast<const MetaData*>(_metadata))
    { }
    /**
     * Returns a (non-wrapping) filter over the messages surrounding (x, y)
     */
    inline __device__ Filter operator() (const float x, const float y) const {
        return Filter(metadata, x, y);
    }
    /**
     * Returns a wrapping filter over the messages surrounding (x, y)
     * With seatbelts enabled, reports an error if (x, y) lies outside the environment bounds,
     * or if the radius does not evenly divide the environment dimensions.
     */
    inline __device__ WrapFilter wrap(const float x, const float y) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        if (x > metadata->max[0] ||
            y > metadata->max[1] ||
            x < metadata->min[0] ||
            y < metadata->min[1]) {
            DTHROW("Location (%f, %f) exceeds environment bounds (%g, %g):(%g, %g),"
                " this is unsupported for the wrapped iterator, MessageSpatial2D::In::wrap().\n", x, y,
                metadata->min[0], metadata->min[1],
                metadata->max[0], metadata->max[1]);
            // Return iterator at min corner of env, this should be safe
            return WrapFilter(metadata, metadata->min[0], metadata->min[1]);
        }
        if (notMultipleOfRadius(metadata->max[0] - metadata->min[0], metadata->radius) ||
            notMultipleOfRadius(metadata->max[1] - metadata->min[1], metadata->radius)) {
            DTHROW("Spatial messaging radius (%g) is not a factor of environment dimensions (%g, %g),"
                " this is unsupported for the wrapped iterator, MessageSpatial2D::In::wrap().\n", metadata->radius,
                metadata->max[0] - metadata->min[0],
                metadata->max[1] - metadata->min[1]);
        }
#endif
        return WrapFilter(metadata, x, y);
    }

    /**
     * Returns the search radius (bin width) of this message list
     */
    __forceinline__ __device__ float radius() const {
        return metadata->radius;
    }

 private:
    /**
     * Returns true if extent is not an integer multiple of radius (within an absolute tolerance of 1e-5).
     * fmodf() of an exact multiple can return a value just below radius rather than near 0, because
     * radius is often not exactly representable (e.g. fmodf(1.0f, 0.1f) is ~0.0999999f), so a remainder
     * close to either end of [0, radius) is treated as a multiple.
     * A NaN remainder (e.g. radius == 0) is not reported, matching the previous check.
     */
    static __device__ bool notMultipleOfRadius(const float extent, const float radius) {
        const float remainder = fmodf(extent, radius);
        return remainder > 0.00001f && radius - remainder > 0.00001f;
    }
    /**
     * Spatial messaging metadata
     */
    const MetaData *metadata;
};

class MessageSpatial2D::Out : public MessageBruteForce::Out {
 public:
    /**
     * Device-side output interface for spatial 2D messages.
     * @param (unnamed) Metadata pointer; unused, the brute force base receives nullptr instead
     * @param scan_flag_messageOutput Scan flag buffer, passed straight through to MessageBruteForce::Out
     */
    __device__ Out(const void * /*metadata*/, unsigned int *scan_flag_messageOutput)
        : MessageBruteForce::Out(nullptr, scan_flag_messageOutput) {}
    /**
     * Writes the (x, y) location of the calling thread's output message.
     */
    __device__ void setLocation(const float x, const float y) const;
};

template<typename T, unsigned int N>
__device__ T MessageSpatial2D::In::Filter::Message::getVariable(const char(&variable_name)[N]) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // A strip offset of 2 or more means iteration has passed end(), there is no message to read
    const bool past_end = relative_cell >= 2;
    if (past_end) {
        DTHROW("MessageSpatial2D in invalid bin, unable to get variable '%s'.\n", variable_name);
        return static_cast<T>(0);
    }
#endif
    // Read the variable of the current message (cell_index) via curve
    return detail::curve::DeviceCurve::getMessageVariable<T>(variable_name, cell_index);
}
template<typename T, flamegpu::size_type N, unsigned int M> __device__
T MessageSpatial2D::In::Filter::Message::getVariable(const char(&variable_name)[M], const unsigned int array_index) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // A strip offset of 2 or more means iteration has passed end(), there is no message to read
    const bool past_end = relative_cell >= 2;
    if (past_end) {
        DTHROW("MessageSpatial2D in invalid bin, unable to get variable '%s'.\n", variable_name);
        return {};
    }
#endif
    // Read element array_index of the current message's (cell_index) array variable via curve
    return detail::curve::DeviceCurve::getMessageArrayVariable<T, N>(variable_name, cell_index, array_index);
}
template<typename T, unsigned int N>
__device__ T MessageSpatial2D::In::WrapFilter::Message::getVariable(const char(&variable_name)[N]) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // An x bin offset of 2 or more means iteration has passed end(), there is no message to read
    const bool past_end = relative_cell[0] >= 2;
    if (past_end) {
        DTHROW("MessageSpatial2D in invalid bin, unable to get variable '%s'.\n", variable_name);
        return static_cast<T>(0);
    }
#endif
    // Read the variable of the current message (cell_index) via curve
    return detail::curve::DeviceCurve::getMessageVariable<T>(variable_name, cell_index);
}
template<typename T, flamegpu::size_type N, unsigned int M> __device__
T MessageSpatial2D::In::WrapFilter::Message::getVariable(const char(&variable_name)[M], const unsigned int array_index) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // An x bin offset of 2 or more means iteration has passed end(), there is no message to read
    const bool past_end = relative_cell[0] >= 2;
    if (past_end) {
        DTHROW("MessageSpatial2D in invalid bin, unable to get variable '%s'.\n", variable_name);
        return {};
    }
#endif
    // Read element array_index of the current message's (cell_index) array variable via curve
    return detail::curve::DeviceCurve::getMessageArrayVariable<T, N>(variable_name, cell_index, array_index);
}


/**
 * Converts a location to the grid bin containing it.
 * Bin size is md->radius, offset from md->min; each coordinate is clamped into [0, gridDim - 1].
 */
__device__ __forceinline__ MessageSpatial2D::GridPos2D getGridPosition2D(const MessageSpatial2D::MetaData *md, float x, float y) {
    // Unclamped bin coordinate along each axis
    const int bin_x = static_cast<int>(floorf((x - md->min[0]) / md->radius));
    const int bin_y = static_cast<int>(floorf((y - md->min[1]) / md->radius));
    // Highest valid bin coordinate along each axis
    const int last_x = static_cast<int>(md->gridDim[0]) - 1;
    const int last_y = static_cast<int>(md->gridDim[1]) - 1;
    MessageSpatial2D::GridPos2D rtn = {
        bin_x < 0 ? 0 : (bin_x > last_x ? last_x : bin_x),
        bin_y < 0 ? 0 : (bin_y > last_y ? last_y : bin_y)
    };
    return rtn;
}
/**
 * Returns the linear (row-major) index of a grid bin: y * gridDim[0] + x.
 * Only the x coordinate is clamped into [0, gridDim[0] - 1]; callers must pass an in-range y.
 */
__device__ __forceinline__ unsigned int getHash2D(const MessageSpatial2D::MetaData *md, const MessageSpatial2D::GridPos2D &xyz) {
    const int last_x = static_cast<int>(md->gridDim[0]) - 1;
    // Only x should ever be out of bounds here
    int bin_x = xyz.x;
    if (bin_x < 0) {
        bin_x = 0;
    } else if (bin_x >= last_x) {
        bin_x = last_x;
    }
    const unsigned int col = static_cast<unsigned int>(bin_x);
    const unsigned int row = static_cast<unsigned int>(xyz.y);
    return static_cast<unsigned int>(row * md->gridDim[0] + col);
}

/**
 * Stores (x, y) as the "x" and "y" variables of this thread's output message
 * and raises its scan flag so the message is kept.
 * The message slot is derived from the x launch dimensions only (presumably a 1D launch).
 */
__device__ inline void MessageSpatial2D::Out::setLocation(const float x, const float y) const {
    const unsigned int index = threadIdx.x + blockIdx.x * blockDim.x;  // + d_message_count;

    detail::curve::DeviceCurve::setMessageVariable<float>("x", x, index);
    detail::curve::DeviceCurve::setMessageVariable<float>("y", y, index);

    // Flag the message as output, in case messages are optional
    this->scan_flag[index] = 1;
}

/**
 * Stores the search location and the (clamped) grid bin that contains it.
 */
__device__ inline MessageSpatial2D::In::Filter::Filter(const MetaData* _metadata, const float x, const float y)
    : loc{ x, y }
    , cell(getGridPosition2D(_metadata, x, y))
    , metadata(_metadata) { }
/**
 * Advances to the next message.
 * When the current strip's index range is exhausted, moves to the next y strip, skipping strips
 * outside the grid or holding no messages. Each strip spans bins cell.x-1 to cell.x+1 (clamped),
 * read as one contiguous PBM range. After the last strip, iteration state is left at
 * relative_cell >= 2 with the index range [0, 1).
 */
__device__ inline MessageSpatial2D::In::Filter::Message& MessageSpatial2D::In::Filter::Message::operator++() {
    ++cell_index;
    while (cell_index >= cell_index_max) {
        nextStrip();
        if (relative_cell >= 2) {
            // Passed the final strip, leave a non-empty dummy range so iteration stops at end()
            cell_index = 0;
            cell_index_max = 1;
            break;
        }
        const int strip_y = _parent.cell.y + relative_cell;
        if (strip_y < 0 || strip_y >= static_cast<int>(_parent.metadata->gridDim[1])) {
            // Strip lies wholly outside the grid, mark it empty so the loop moves on
            cell_index = 0;
            cell_index_max = 0;
            continue;
        }
        // First and last bins of the strip, clamped in x by getHash2D
        const unsigned int first_bin = getHash2D(_parent.metadata, { _parent.cell.x - 1, strip_y });
        const unsigned int last_bin = getHash2D(_parent.metadata, { _parent.cell.x + 1, strip_y });
        // Message index range of the strip, from the PBM
        cell_index = _parent.metadata->PBM[first_bin];
        cell_index_max = _parent.metadata->PBM[last_bin + 1];
    }
    return *this;
}
/**
 * Stores the search location and the (clamped) grid bin that contains it.
 */
__device__ inline MessageSpatial2D::In::WrapFilter::WrapFilter(const MetaData* _metadata, const float x, const float y)
    : loc{ x, y }
    , cell(getGridPosition2D(_metadata, x, y))
    , metadata(_metadata) { }
/**
 * Advances to the next message.
 * When the current bin's index range is exhausted, moves to the next of the 9 neighbouring bin
 * offsets, wrapping bin coordinates modulo the grid dimensions, and skips empty bins.
 * After the last bin, iteration state is left at relative_cell[0] >= 2 with the index range [0, 1).
 */
__device__ inline MessageSpatial2D::In::WrapFilter::Message& MessageSpatial2D::In::WrapFilter::Message::operator++() {
    ++cell_index;
    while (cell_index >= cell_index_max) {
        nextCell();
        // Non-empty dummy range; if the neighbourhood is exhausted this is what ends the loop
        cell_index = 0;
        cell_index_max = 1;
        if (relative_cell[0] >= 2) {
            break;
        }
        const auto grid_w = _parent.metadata->gridDim[0];
        const auto grid_h = _parent.metadata->gridDim[1];
        // Adding the grid size before the remainder keeps the operand non-negative for offsets >= -grid size
        const int wrapped_x = (_parent.cell.x + relative_cell[0] + static_cast<int>(grid_w)) % grid_w;
        const int wrapped_y = (_parent.cell.y + relative_cell[1] + static_cast<int>(grid_h)) % grid_h;
        const unsigned int bin = getHash2D(_parent.metadata, { wrapped_x, wrapped_y });
        // Message index range of the bin, from the PBM
        cell_index = _parent.metadata->PBM[bin];
        cell_index_max = _parent.metadata->PBM[bin + 1];
    }
    return *this;
}


}  // namespace flamegpu


#endif  // INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGESPATIAL2D_MESSAGESPATIAL2DDEVICE_CUH_