1 module dparse.rollback_allocator;
2 
3 //version = debug_rollback_allocator;
4 
/**
 * Pointer-bump allocator with rollback functionality.
 *
 * Memory is obtained from `Mallocator` in nodes of at least 8 KiB. The nodes
 * form a singly-linked list whose head (`first`) is the most recently
 * allocated node. Allocations are bumped from the head node and rounded up to
 * a multiple of 8 bytes, so every returned pointer is 8-byte aligned.
 * `setCheckpoint` returns the current bump position encoded as an address.
 * `rollback` frees every node allocated after that checkpoint and resets the
 * bump position. All memory is released when the allocator is destroyed.
 * The allocator cannot be copied.
 */
struct RollbackAllocator
{
public:

    @disable this(this);

    /// Releases every node still owned by the allocator.
    ~this()
    {
        while (first !is null)
            deallocateNode();
    }

    /**
     * Allocates `size` bytes of memory.
     *
     * The returned slice starts on an 8-byte boundary and has exactly `size`
     * elements. Internally the request is rounded up to a multiple of 8 so
     * that the following allocation is aligned as well.
     *
     * Params:
     *     size = number of bytes requested (may be 0).
     * Returns: a slice of length `size` into memory owned by this allocator.
     * Throws: `OutOfMemoryError` if the underlying allocation fails or the
     *     rounded request would overflow `size_t`.
     */
    void[] allocate(const size_t size)
    out (arr)
    {
        assert(arr.length == size);
    }
    body
    {
        import core.exception : onOutOfMemoryError;

        if (size > size_t.max - (ALIGNMENT - 1))
            onOutOfMemoryError();
        immutable size_t aligned = roundUpToAlignment(size);

        // Start a new node when there is none yet or the head node lacks room.
        // The node must be sized for the *rounded* request: `used` advances by
        // `aligned`, and `used <= mem.length` must always hold so that the
        // resulting checkpoint lies inside the node (see Node.contains).
        if (first is null || first.mem.length - first.used < aligned)
            allocateNode(aligned);

        immutable size_t start = first.used;
        void[] m = first.mem[start .. start + size];
        first.used = start + aligned;
        return m;
    }

    /**
     * Rolls back the allocator to the given checkpoint.
     *
     * Frees every node allocated after the checkpoint and resets the bump
     * position of the node containing `point`. Memory handed out after the
     * checkpoint must not be used afterwards.
     *
     * Params:
     *     point = a value previously returned by `setCheckpoint` while the
     *         memory it refers to was still owned by this allocator. The
     *         value 0 (the checkpoint of an empty allocator) releases all
     *         memory.
     */
    void rollback(size_t point)
    {
        if (point == 0)
        {
            while (first !is null)
                deallocateNode();
            return;
        }

        assert(contains(point), "Attempted to roll back to a point not in the allocator.");
        // The list is ordered newest first, so every node ahead of the one
        // containing the checkpoint was allocated after the checkpoint.
        while (!first.contains(point))
            deallocateNode();
        assert(first !is null);

        immutable size_t begin = point - cast(size_t) first.mem.ptr;
        version (debug_rollback_allocator)
            first.mem[begin .. $] = 0;
        first.used = begin;
        assert(cast(size_t) first.mem.ptr + first.used == point);
    }

    /**
     * Get a checkpoint for the allocator.
     *
     * Returns: 0 when the allocator holds no memory, otherwise the address of
     *     the current bump position in the head node. Pass it to `rollback`.
     */
    size_t setCheckpoint() const nothrow @nogc
    {
        // Check for the empty allocator before touching `first`.
        if (first is null)
            return 0;
        assert(first.used <= first.mem.length);
        return cast(size_t) first.mem.ptr + first.used;
    }

    /**
     * Allocates a T and returns a pointer to it.
     *
     * For classes, the returned value is the class reference. At least one
     * byte is always requested so that zero-size types get a distinct address.
     *
     * Params:
     *     args = constructor / initializer arguments forwarded to `emplace`.
     * Returns: the newly constructed object, or null if `allocate` yielded a
     *     null pointer.
     */
    auto make(T, Args...)(auto ref Args args)
    {
        import std.algorithm.comparison : max;
        import std.experimental.allocator : stateSize;
        import std.conv : emplace;

        void[] mem = allocate(max(stateSize!T, 1));
        if (mem.ptr is null)
            return null;
        static if (is(T == class))
            return emplace!T(mem, args);
        else
            return emplace(cast(T*) mem.ptr, args);
    }

private:

    /// Alignment, in bytes, of every pointer returned by `allocate`.
    enum size_t ALIGNMENT = 8;

    /// Rounds `size` up to a multiple of `ALIGNMENT`; caller must rule out overflow.
    static size_t roundUpToAlignment(size_t size) pure nothrow @nogc @safe
    {
        return (size + (ALIGNMENT - 1)) & ~(ALIGNMENT - 1);
    }

    // Used for debugging: true if any node owned by the allocator contains `point`.
    bool contains(size_t point) const
    {
        for (const(Node)* n = first; n !is null; n = n.next)
            if (n.contains(point))
                return true;
        return false;
    }

    /// Header stored at the start of every malloc'd block; `mem` is the
    /// remainder of the block and `used` is the bump offset into it.
    static struct Node
    {
        Node* next;
        size_t used;
        ubyte[] mem;

        /// True if `p` lies in `mem`, including the one-past-the-end address
        /// (a checkpoint taken when the node is exactly full).
        bool contains(size_t p) const pure nothrow @nogc @safe
        {
            return p >= cast(size_t) mem.ptr && p <= cast(size_t) mem.ptr + mem.length;
        }
    }

    /**
     * Pushes a new node with room for at least `size` bytes onto the head of
     * the node list.
     *
     * Throws: `OutOfMemoryError` if `Mallocator` fails or the block size
     *     would overflow `size_t`.
     */
    void allocateNode(size_t size)
    {
        import std.algorithm : max;
        import std.experimental.allocator.mallocator : Mallocator;
        import std.conv : emplace;
        import core.exception : onOutOfMemoryError;

        enum ALLOC_SIZE = 1024 * 8;
        // `mem` begins right after the header, so it is only aligned if the
        // header size is itself a multiple of ALIGNMENT.
        static assert(Node.sizeof % ALIGNMENT == 0);

        if (size > size_t.max - Node.sizeof)
            onOutOfMemoryError();
        ubyte[] m = cast(ubyte[]) Mallocator.instance.allocate(max(size + Node.sizeof, ALLOC_SIZE));
        if (m.ptr is null)
            onOutOfMemoryError();
        version (debug_rollback_allocator)
            m[] = 0;
        Node* n = emplace!Node(cast(Node*) m.ptr, first, 0, m[Node.sizeof .. $]);
        assert((cast(size_t) n.mem.ptr) % 8 == 0, "The memoriez!");
        first = n;
    }

    /// Frees the head node (header and payload in one block) and advances `first`.
    void deallocateNode()
    {
        assert(first !is null);
        import std.experimental.allocator.mallocator : Mallocator;

        Node* next = first.next;
        ubyte[] mem = (cast(ubyte*) first)[0 .. Node.sizeof + first.mem.length];
        version (debug_rollback_allocator)
            mem[] = 0;
        Mallocator.instance.deallocate(mem);
        first = next;
    }

    Node* first;
}