Program Listing for File MessageBruteForce.cu
↰ Return to documentation for file (src/flamegpu/runtime/messaging/MessageBruteForce.cu)
#include <utility>
#include <string>
#include <memory>
#include "flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceHost.h"
#include "flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceDevice.cuh"
#include "flamegpu/model/AgentDescription.h" // Used by Move-Assign
#include "flamegpu/simulation/detail/CUDAMessage.h"
#include "flamegpu/detail/cuda.cuh"
namespace flamegpu {
/**
 * Prepares the handler for use.
 * Ensures the device metadata buffer exists, resets the host-side message count to 0
 * and copies the host metadata to the device.
 * The CUDAScatter and unsigned int arguments are unnamed and not used here.
 * @param stream CUDA stream used for the host-to-device copy, which is synchronised before returning
 */
void MessageBruteForce::CUDAModelHandler::init(detail::CUDAScatter &, unsigned int, cudaStream_t stream) {
    // d_metadata must point at device memory before anything is copied into it
    this->allocateMetaDataDevicePtr(stream);
    // Start with no messages (this value should already be 0)
    this->hd_metadata.length = 0;
    // Publish the host copy of the metadata to the device
    const void *const host_src = &this->hd_metadata;
    const size_t copy_bytes = sizeof(MetaData);
    gpuErrchk(cudaMemcpyAsync(this->d_metadata, host_src, copy_bytes, cudaMemcpyHostToDevice, stream));
    // Wait for the copy to finish; this could probably be skipped/delayed safely
    gpuErrchk(cudaStreamSynchronize(stream));
}
/**
 * Allocates the device-side MetaData buffer if it has not been allocated yet.
 * Calling this while d_metadata is already non-null does nothing.
 * @param stream Not used by this implementation
 */
void MessageBruteForce::CUDAModelHandler::allocateMetaDataDevicePtr(cudaStream_t stream) {
    // Buffer already exists, nothing to allocate
    if (d_metadata != nullptr) {
        return;
    }
    // Room for exactly one MetaData struct
    gpuErrchk(cudaMalloc(&d_metadata, sizeof(MetaData)));
}
/**
 * Releases the device-side MetaData buffer, if one is held.
 * On return d_metadata is nullptr.
 */
void MessageBruteForce::CUDAModelHandler::freeMetaDataDevicePtr() {
    // Nothing held, pointer is already null
    if (d_metadata == nullptr) {
        return;
    }
    gpuErrchk(flamegpu::detail::cuda::cudaFree(d_metadata));
    // Forget the freed address so a later allocate call creates a fresh buffer
    d_metadata = nullptr;
}
/**
 * Refreshes the device copy of the metadata when the message count has changed.
 * If the count reported by sim_message matches the cached host value, no copy is issued.
 * The CUDAScatter and unsigned int arguments are unnamed and not used here.
 * @param stream CUDA stream used for the host-to-device copy, which is synchronised before returning
 */
void MessageBruteForce::CUDAModelHandler::buildIndex(detail::CUDAScatter &, unsigned int, cudaStream_t stream) {
    const unsigned int currentCount = this->sim_message.getMessageCount();
    // Cached length is still valid, device copy does not need updating
    if (currentCount == hd_metadata.length) {
        return;
    }
    hd_metadata.length = currentCount;
    // Not Pinned
    gpuErrchk(cudaMemcpyAsync(d_metadata, &hd_metadata, sizeof(MetaData), cudaMemcpyHostToDevice, stream));
    // This could probably be skipped/delayed safely if in the right stream
    gpuErrchk(cudaStreamSynchronize(stream));
}
/**
 * Constructs the data for a new, empty message description.
 * The message starts non-persistent, with an optional output count of 0.
 * @param _model Model which the message belongs to
 * @param message_name Name of the message
 */
MessageBruteForce::Data::Data(std::shared_ptr<const ModelData> _model, const std::string &message_name)
    : model(_model),
      name(message_name),
      persistent(false),
      optional_outputs(0) {
    // All state is set by the initialiser list
}
/**
 * Constructs a copy of another message's data, attached to a (possibly different) model.
 * variables, name, persistent and optional_outputs are copied from other;
 * only the model reference is taken from the first argument.
 * @param _model Model which the new copy belongs to
 * @param other Message data to copy from
 */
MessageBruteForce::Data::Data(std::shared_ptr<const ModelData> _model, const MessageBruteForce::Data &other)
    : model(_model),
      variables(other.variables),
      name(other.name),
      persistent(other.persistent),
      optional_outputs(other.optional_outputs) {
    // All state is set by the initialiser list
}
/**
 * Creates a heap-allocated copy of this message data attached to a different model.
 * @param newParent Model which the copy will belong to
 * @return Raw pointer to the new copy, allocated with new
 */
MessageBruteForce::Data *MessageBruteForce::Data::clone(const std::shared_ptr<const ModelData> &newParent) {
    MessageBruteForce::Data *const copy = new MessageBruteForce::Data(newParent, *this);
    return copy;
}
/**
 * Equality check on message content.
 * Two messages compare equal when they share a name, persistence flag and an identical
 * set of variables (same names, type, type size and element count).
 * The owning model is deliberately excluded from the comparison.
 * @param rhs Message data to compare against
 * @return True if both describe functionally the same message
 */
bool MessageBruteForce::Data::operator==(const MessageBruteForce::Data& rhs) const {
    // An object always equals itself
    if (this == &rhs) {
        return true;
    }
    // Cheap checks first (model is not compared: don't check weak pointers)
    const bool header_matches = name == rhs.name
        && persistent == rhs.persistent
        && variables.size() == rhs.variables.size();
    if (!header_matches) {
        return false;
    }
    // Sizes match, so every variable of ours must have an identical twin in rhs
    for (const auto &entry : variables) {
        const auto match = rhs.variables.find(entry.first);
        if (match == rhs.variables.end()) {
            return false;
        }
        const auto &mine = entry.second;
        const auto &theirs = match->second;
        if (mine.type_size != theirs.type_size
            || mine.type != theirs.type
            || mine.elements != theirs.elements) {
            return false;
        }
    }
    return true;
}
/**
 * Inequality check, defined as the negation of operator==.
 * @param rhs Message data to compare against
 * @return True if the two messages differ
 */
bool MessageBruteForce::Data::operator!=(const MessageBruteForce::Data& rhs) const {
    const bool equal = (*this == rhs);
    return !equal;
}
/**
 * Creates the handler object for this messaging type.
 * @param owner The CUDAMessage passed to the new handler's constructor
 * @return Owning pointer to a newly constructed MessageBruteForce::CUDAModelHandler
 */
std::unique_ptr<MessageSpecialisationHandler> MessageBruteForce::Data::getSpecialisationHander(detail::CUDAMessage &owner) const {
    std::unique_ptr<MessageSpecialisationHandler> handler(new MessageBruteForce::CUDAModelHandler(owner));
    return handler;
}
/**
 * @return The sorting type for this messaging type, always MessageSortingType::none
 */
flamegpu::MessageSortingType flamegpu::MessageBruteForce::Data::getSortingType() const {
    const flamegpu::MessageSortingType sorting = flamegpu::MessageSortingType::none;
    return sorting;
}
/**
 * Used for the MessageBruteForce::Data::getType() type and derived methods
 * @return The type_index of the MessageBruteForce messaging type
 */
std::type_index MessageBruteForce::Data::getType() const {
    const std::type_index tag(typeid(MessageBruteForce));
    return tag;
}
/**
 * Constructs a description wrapping the given message data.
 * @param data Message data to wrap; ownership of this shared_ptr copy is moved into the description
 */
MessageBruteForce::CDescription::CDescription(std::shared_ptr<Data> data)
    : message(std::move(data)) {
    // Nothing else to initialise
}
/**
 * Constructs a description wrapping the given const message data.
 * The const qualifier is cast away so the pointer can be stored in the message member.
 * @param data Const message data to wrap
 */
MessageBruteForce::CDescription::CDescription(std::shared_ptr<const Data> data)
    // const_pointer_cast already yields a temporary, so it initialises the member directly
    : message(std::const_pointer_cast<Data>(data)) {
    // Nothing else to initialise
}
/**
 * Equality check which compares the wrapped message data, not the pointers.
 * @param rhs Description to compare against
 * @return True if the content of both messages is functionally the same
 */
bool MessageBruteForce::CDescription::operator==(const CDescription& rhs) const {
    const auto &mine = *this->message;
    const auto &theirs = *rhs.message;
    return mine == theirs;
}
/**
 * Inequality check, defined as the negation of operator==.
 * @param rhs Description to compare against
 * @return True if the content of the two messages differs
 */
bool MessageBruteForce::CDescription::operator!=(const CDescription& rhs) const {
    const bool equal = this->operator==(rhs);
    return !equal;
}
/**
 * @return A copy of the message's name
 */
std::string MessageBruteForce::CDescription::getName() const {
    const auto &data = *message;
    return data.name;
}
/**
 * @return The value of the message's persistent flag
 */
bool MessageBruteForce::CDescription::getPersistent() const {
    const auto &data = *message;
    return data.persistent;
}
/**
 * Looks up the type of a named message variable.
 * @param variable_name Name of the variable to look up
 * @return Reference to the type_index stored for the variable
 * @throws exception::InvalidMessageVar If the message has no variable with that name
 */
const std::type_index& MessageBruteForce::CDescription::getVariableType(const std::string& variable_name) const {
    const auto found = message->variables.find(variable_name);
    // Unknown variable name: report it (THROW is kept inside its own braces)
    if (found == message->variables.end()) {
        THROW exception::InvalidMessageVar("Message ('%s') does not contain variable '%s', "
            "in MessageDescription::getVariableType().",
            message->name.c_str(), variable_name.c_str());
    }
    return found->second.type;
}
/**
 * Looks up the stored type size of a named message variable.
 * @param variable_name Name of the variable to look up
 * @return The type_size recorded for the variable
 * @throws exception::InvalidMessageVar If the message has no variable with that name
 */
size_t MessageBruteForce::CDescription::getVariableSize(const std::string& variable_name) const {
    const auto found = message->variables.find(variable_name);
    // Unknown variable name: report it (THROW is kept inside its own braces)
    if (found == message->variables.end()) {
        THROW exception::InvalidMessageVar("Message ('%s') does not contain variable '%s', "
            "in MessageDescription::getVariableSize().",
            message->name.c_str(), variable_name.c_str());
    }
    return found->second.type_size;
}
/**
 * Looks up the number of elements of a named message variable.
 * @param variable_name Name of the variable to look up
 * @return The element count recorded for the variable
 * @throws exception::InvalidAgentVar If the message has no variable with that name
 *
 * NOTE(review): getVariableType() and getVariableSize() throw exception::InvalidMessageVar
 * for this same condition, so InvalidAgentVar here looks like a copy/paste slip.
 * The type is left unchanged because callers may be catching it; confirm before switching.
 */
flamegpu::size_type MessageBruteForce::CDescription::getVariableLength(const std::string& variable_name) const {
    const auto found = message->variables.find(variable_name);
    // Unknown variable name: report it (THROW is kept inside its own braces)
    if (found == message->variables.end()) {
        THROW exception::InvalidAgentVar("Message ('%s') does not contain variable '%s', "
            "in MessageBruteForce::getVariableLength().",
            message->name.c_str(), variable_name.c_str());
    }
    return found->second.elements;
}
/**
 * @return The number of variables held by the message
 */
flamegpu::size_type MessageBruteForce::CDescription::getVariablesCount() const {
    const auto variable_total = message->variables.size();
    // Downcast, will never have more than UINT_MAX variables
    return static_cast<flamegpu::size_type>(variable_total);
}
/**
 * @param variable_name Name of the variable to check for
 * @return True if the message contains a variable with that name
 */
bool MessageBruteForce::CDescription::hasVariable(const std::string& variable_name) const {
    const auto found = message->variables.find(variable_name);
    const bool present = found != message->variables.end();
    return present;
}
/**
 * Constructs a description by forwarding the message data to the CDescription base.
 * @param data Message data to wrap; ownership of this shared_ptr copy is moved into the base
 */
MessageBruteForce::Description::Description(std::shared_ptr<Data> data)
    : CDescription(std::move(data)) {
    // Nothing else to initialise
}
} // namespace flamegpu