Program Listing for File CUDAScanCompaction.cu
↰ Return to documentation for file (src/flamegpu/simulation/detail/CUDAScanCompaction.cu)
#include <cassert>
#include "flamegpu/simulation/detail/CUDAScanCompaction.h"
#include "flamegpu/simulation/detail/CUDAErrorChecking.cuh"
#include "flamegpu/simulation/CUDASimulation.h"
#include "flamegpu/detail/cuda.cuh"
namespace flamegpu {
namespace detail {
/**
 * Grow the scan/compaction buffers of the config selected by (type, streamId) so they can hold
 * at least newCount + 1 elements (see CUDAScanCompactionConfig::resize_scan_flag()).
 * Index bounds are only checked via assert, i.e. not in NDEBUG builds.
 */
void CUDAScanCompaction::resize(const unsigned int newCount, const Type& type, const unsigned int streamId) {
    assert(streamId < MAX_STREAMS);
    assert(type < MAX_TYPES);
    auto &target = configs[type][streamId];
    target.resize_scan_flag(newCount);
}
/**
 * Enqueue zero-fills of the scan/compaction buffers of the config selected by (type, streamId)
 * on the given stream (see CUDAScanCompactionConfig::zero_scan_flag_async()).
 * Index bounds are only checked via assert, i.e. not in NDEBUG builds.
 */
void CUDAScanCompaction::zero_async(const Type& type, cudaStream_t stream, unsigned int streamId) {
    assert(streamId < MAX_STREAMS);
    assert(type < MAX_TYPES);
    auto &target = configs[type][streamId];
    target.zero_scan_flag_async(stream);
}
/**
 * Read-only access to the scan/compaction config selected by (type, streamId).
 * Applies the same debug-only bounds checks as resize()/zero_async(). Without them, an
 * out-of-range index would go unnoticed. The configs table is presumably sized
 * MAX_TYPES x MAX_STREAMS; confirm against the header.
 */
const CUDAScanCompactionConfig &CUDAScanCompaction::getConfig(const Type& type, const unsigned int streamId) {
    assert(streamId < MAX_STREAMS);
    assert(type < MAX_TYPES);
    return configs[type][streamId];
}
/**
 * Mutable access to the scan/compaction config selected by (type, streamId).
 * Applies the same debug-only bounds checks as resize()/zero_async(). Without them, an
 * out-of-range index would go unnoticed. The configs table is presumably sized
 * MAX_TYPES x MAX_STREAMS; confirm against the header.
 */
CUDAScanCompactionConfig &CUDAScanCompaction::Config(const Type& type, const unsigned int streamId) {
    assert(streamId < MAX_STREAMS);
    assert(type < MAX_TYPES);
    return configs[type][streamId];
}
/**
 * Destructor: releases whichever of the scan_flag / position device buffers are still allocated.
 */
CUDAScanCompactionConfig::~CUDAScanCompactionConfig() {
    this->free_scan_flag();
}
/**
 * Release the scan_flag and position device buffers (each only if non-null) and null the
 * pointers so that repeated calls are harmless.
 * The recorded capacity (scan_flag_len) is reset to 0 before anything is released. Otherwise
 * resize_scan_flag() could trust a stale length and skip reallocating. The same applies if
 * gpuErrchk reports a failure part-way through a free or a subsequent allocation.
 */
void CUDAScanCompactionConfig::free_scan_flag() {
    // Treat the buffers as having no capacity from this point on, so a later resize always reallocates.
    scan_flag_len = 0;
    if (d_ptrs.scan_flag) {
        gpuErrchk(flamegpu::detail::cuda::cudaFree(d_ptrs.scan_flag));
        d_ptrs.scan_flag = nullptr;
    }
    if (d_ptrs.position) {
        gpuErrchk(flamegpu::detail::cuda::cudaFree(d_ptrs.position));
        d_ptrs.position = nullptr;
    }
}
/**
 * Enqueue byte-wise zero fills of the position and scan_flag device buffers on `stream`.
 * Each buffer is skipped if it has not been allocated. Both buffers are scan_flag_len
 * unsigned ints long, matching resize_scan_flag(). The call is asynchronous with respect
 * to the host.
 */
void CUDAScanCompactionConfig::zero_scan_flag_async(cudaStream_t stream) {
    const size_t byteCount = scan_flag_len * sizeof(unsigned int);
    if (d_ptrs.position != nullptr)
        gpuErrchk(cudaMemsetAsync(d_ptrs.position, 0, byteCount, stream));
    if (d_ptrs.scan_flag != nullptr)
        gpuErrchk(cudaMemsetAsync(d_ptrs.scan_flag, 0, byteCount, stream));
}
/**
 * Ensure the scan_flag and position device buffers can each hold count + 1 unsigned ints.
 * The extra element lets the total be read from the end of the scan. Buffers are only
 * reallocated when growing. Existing contents are discarded: free_scan_flag() runs first,
 * then both buffers are freshly cudaMalloc'd. They are not zeroed here.
 * count must be < UINT_MAX. Otherwise count + 1 wraps to 0 and the allocation would be
 * silently skipped, so a debug assert catches that case.
 */
void CUDAScanCompactionConfig::resize_scan_flag(const unsigned int count) {
    assert(count + 1u > count);  // overflow guard: count + 1 must not wrap
    const unsigned int required = count + 1;  // +1 so we can get the total from the scan
    if (required > scan_flag_len) {
        free_scan_flag();
        gpuErrchk(cudaMalloc(&d_ptrs.scan_flag, required * sizeof(unsigned int)));
        gpuErrchk(cudaMalloc(&d_ptrs.position, required * sizeof(unsigned int)));
        // Record the new capacity only once both allocations have succeeded.
        scan_flag_len = required;
    }
}
} // namespace detail
} // namespace flamegpu