source: CLRX/CLRadeonExtender/trunk/amdasm/AsmRegAlloc.cpp @ 4172

Last change on this file since 4172 was 4172, checked in by matszpk, 11 months ago

CLRadeonExtender: AsmRegAlloc?: Unfinished integration LinearDepHandler? with AsmRegAllocator?.

File size: 32.0 KB
Line 
1/*
2 *  CLRadeonExtender - Unofficial OpenCL Radeon Extensions Library
3 *  Copyright (C) 2014-2018 Mateusz Szpakowski
4 *
5 *  This library is free software; you can redistribute it and/or
6 *  modify it under the terms of the GNU Lesser General Public
7 *  License as published by the Free Software Foundation; either
8 *  version 2.1 of the License, or (at your option) any later version.
9 *
10 *  This library is distributed in the hope that it will be useful,
11 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 *  Lesser General Public License for more details.
14 *
15 *  You should have received a copy of the GNU Lesser General Public
16 *  License along with this library; if not, write to the Free Software
17 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
18 */
19
20#include <CLRX/Config.h>
21#include <assert.h>
22#include <iostream>
23#include <stack>
24#include <deque>
25#include <vector>
26#include <utility>
27#include <unordered_set>
28#include <map>
29#include <set>
30#include <unordered_map>
31#include <algorithm>
32#include <CLRX/utils/Utilities.h>
33#include <CLRX/utils/Containers.h>
34#include <CLRX/amdasm/Assembler.h>
35#include "AsmInternals.h"
36#include "AsmRegAlloc.h"
37
38using namespace CLRX;
39
40#if ASMREGALLOC_DEBUGDUMP
41std::ostream& operator<<(std::ostream& os, const CLRX::BlockIndex& v)
42{
43    if (v.pass==0)
44        return os << v.index;
45    else
46        return os << v.index << "#" << v.pass;
47}
48#endif
49
50ISAUsageHandler::ISAUsageHandler(const std::vector<cxbyte>& _content) :
51            content(_content), lastOffset(0), readOffset(0), instrStructPos(0),
52            regUsagesPos(0), regUsages2Pos(0), regVarUsagesPos(0),
53            pushedArgs(0), argPos(0), argFlags(0), isNext(false), useRegMode(false)
54{ }
55
56ISAUsageHandler::~ISAUsageHandler()
57{ }
58
59void ISAUsageHandler::rewind()
60{
61    readOffset = instrStructPos = 0;
62    regUsagesPos = regUsages2Pos = regVarUsagesPos = 0;
63    useRegMode = false;
64    pushedArgs = 0;
65    skipBytesInInstrStruct();
66}
67
68void ISAUsageHandler::skipBytesInInstrStruct()
69{
70    // do not add instruction size if usereg (usereg immediately before instr regusages)
71    if ((instrStructPos != 0 || argPos != 0) && !useRegMode)
72        readOffset += defaultInstrSize;
73    argPos = 0;
74    for (;instrStructPos < instrStruct.size() &&
75        instrStruct[instrStructPos] > 0x80; instrStructPos++)
76        readOffset += (instrStruct[instrStructPos] & 0x7f);
77    isNext = (instrStructPos < instrStruct.size());
78}
79
80void ISAUsageHandler::putSpace(size_t offset)
81{
82    if (lastOffset != offset)
83    {
84        flush(); // flush before new instruction
85        // useReg immediately before instruction regusages
86        size_t defaultInstrSize = (!useRegMode ? this->defaultInstrSize : 0);
87        if (lastOffset > offset)
88            throw AsmException("Offset before previous instruction");
89        if (!instrStruct.empty() && offset - lastOffset < defaultInstrSize)
90            throw AsmException("Offset between previous instruction");
91        size_t toSkip = !instrStruct.empty() ? 
92                offset - lastOffset - defaultInstrSize : offset;
93        while (toSkip > 0)
94        {
95            size_t skipped = std::min(toSkip, size_t(0x7f));
96            instrStruct.push_back(skipped | 0x80);
97            toSkip -= skipped;
98        }
99        lastOffset = offset;
100        argFlags = 0;
101        pushedArgs = 0;
102    } 
103}
104
105void ISAUsageHandler::pushUsage(const AsmRegVarUsage& rvu)
106{
107    if (lastOffset == rvu.offset && useRegMode)
108        flush(); // only flush if useRegMode and no change in offset
109    else // otherwise
110        putSpace(rvu.offset);
111    useRegMode = false;
112    if (rvu.regVar != nullptr)
113    {
114        argFlags |= (1U<<pushedArgs);
115        regVarUsages.push_back({ rvu.regVar, rvu.rstart, rvu.rend, rvu.regField,
116            rvu.rwFlags, rvu.align });
117    }
118    else // reg usages
119        regUsages.push_back({ rvu.regField,cxbyte(rvu.rwFlags |
120                    getRwFlags(rvu.regField, rvu.rstart, rvu.rend)) });
121    pushedArgs++;
122}
123
124void ISAUsageHandler::pushUseRegUsage(const AsmRegVarUsage& rvu)
125{
126    if (lastOffset == rvu.offset && !useRegMode)
127        flush(); // only flush if useRegMode and no change in offset
128    else // otherwise
129        putSpace(rvu.offset);
130    useRegMode = true;
131    if (pushedArgs == 0 || pushedArgs == 256)
132    {
133        argFlags = 0;
134        pushedArgs = 0;
135        instrStruct.push_back(0x80); // sign of regvarusage from usereg
136        instrStruct.push_back(0);
137    }
138    if (rvu.regVar != nullptr)
139    {
140        argFlags |= (1U<<(pushedArgs & 7));
141        regVarUsages.push_back({ rvu.regVar, rvu.rstart, rvu.rend, rvu.regField,
142            rvu.rwFlags, rvu.align });
143    }
144    else // reg usages
145        regUsages2.push_back({ rvu.rstart, rvu.rend, rvu.rwFlags });
146    pushedArgs++;
147    if ((pushedArgs & 7) == 0) // just flush per 8 bit
148    {
149        instrStruct.push_back(argFlags);
150        instrStruct[instrStruct.size() - ((pushedArgs+7) >> 3) - 1] = pushedArgs;
151        argFlags = 0;
152    }
153}
154
155void ISAUsageHandler::flush()
156{
157    if (pushedArgs != 0)
158    {
159        if (!useRegMode)
160        {
161            // normal regvarusages
162            instrStruct.push_back(argFlags);
163            if ((argFlags & (1U<<(pushedArgs-1))) != 0)
164                regVarUsages.back().rwFlags |= 0x80;
165            else // reg usages
166                regUsages.back().rwFlags |= 0x80;
167        }
168        else
169        {
170            // use reg regvarusages
171            if ((pushedArgs & 7) != 0) //if only not pushed args remains
172                instrStruct.push_back(argFlags);
173            instrStruct[instrStruct.size() - ((pushedArgs+7) >> 3) - 1] = pushedArgs;
174        }
175    }
176}
177
178AsmRegVarUsage ISAUsageHandler::nextUsage()
179{
180    if (!isNext)
181        throw AsmException("No reg usage in this code");
182    AsmRegVarUsage rvu;
183    // get regvarusage
184    bool lastRegUsage = false;
185    rvu.offset = readOffset;
186    if (!useRegMode && instrStruct[instrStructPos] == 0x80)
187    {
188        // useRegMode (begin fetching useregs)
189        useRegMode = true;
190        argPos = 0;
191        instrStructPos++;
192        // pushedArgs - numer of useregs, 0 - 256 useregs
193        pushedArgs = instrStruct[instrStructPos++];
194        argFlags = instrStruct[instrStructPos];
195    }
196    rvu.useRegMode = useRegMode; // no ArgPos
197   
198    if ((instrStruct[instrStructPos] & (1U << (argPos&7))) != 0)
199    {
200        // regvar usage
201        const AsmRegVarUsageInt& inRVU = regVarUsages[regVarUsagesPos++];
202        rvu.regVar = inRVU.regVar;
203        rvu.rstart = inRVU.rstart;
204        rvu.rend = inRVU.rend;
205        rvu.regField = inRVU.regField;
206        rvu.rwFlags = inRVU.rwFlags & ASMRVU_ACCESS_MASK;
207        rvu.align = inRVU.align;
208        if (!useRegMode)
209            lastRegUsage = ((inRVU.rwFlags&0x80) != 0);
210    }
211    else if (!useRegMode)
212    {
213        // simple reg usage
214        const AsmRegUsageInt& inRU = regUsages[regUsagesPos++];
215        rvu.regVar = nullptr;
216        const std::pair<uint16_t, uint16_t> regPair =
217                    getRegPair(inRU.regField, inRU.rwFlags);
218        rvu.rstart = regPair.first;
219        rvu.rend = regPair.second;
220        rvu.rwFlags = (inRU.rwFlags & ASMRVU_ACCESS_MASK);
221        rvu.regField = inRU.regField;
222        rvu.align = 0;
223        lastRegUsage = ((inRU.rwFlags&0x80) != 0);
224    }
225    else
226    {
227        // use reg (simple reg usage, second structure)
228        const AsmRegUsage2Int& inRU = regUsages2[regUsages2Pos++];
229        rvu.regVar = nullptr;
230        rvu.rstart = inRU.rstart;
231        rvu.rend = inRU.rend;
232        rvu.rwFlags = inRU.rwFlags;
233        rvu.regField = ASMFIELD_NONE;
234        rvu.align = 0;
235    }
236    argPos++;
237    if (useRegMode)
238    {
239        // if inside useregs
240        if (argPos == (pushedArgs&0xff))
241        {
242            instrStructPos++; // end
243            skipBytesInInstrStruct();
244            useRegMode = false;
245        }
246        else if ((argPos & 7) == 0) // fetch new flag
247        {
248            instrStructPos++;
249            argFlags = instrStruct[instrStructPos];
250        }
251    }
252    // after instr
253    if (lastRegUsage)
254    {
255        instrStructPos++;
256        skipBytesInInstrStruct();
257    }
258    return rvu;
259}
260
261
262ISALinearDepHandler::ISALinearDepHandler() : regVarLinDepsPos(0)
263{ }
264
265void ISALinearDepHandler::pushLinearDep(const AsmRegVarLinearDep& linearDep)
266{
267    regVarLinDeps.push_back(linearDep);
268}
269
270void ISALinearDepHandler::rewind()
271{
272    regVarLinDepsPos = 0;
273}
274
275AsmRegVarLinearDep ISALinearDepHandler::nextLinearDep()
276{
277    if (regVarLinDepsPos >= regVarLinDeps.size())
278        throw AsmException("No regvar linear deps in this code");
279    return regVarLinDeps[regVarLinDepsPos++];
280}
281
282/*
283 * Asm register allocator stuff
284 */
285
286AsmRegAllocator::AsmRegAllocator(Assembler& _assembler) : assembler(_assembler)
287{ }
288
289AsmRegAllocator::AsmRegAllocator(Assembler& _assembler,
290        const std::vector<CodeBlock>& _codeBlocks, const SSAReplacesMap& _ssaReplacesMap)
291        : assembler(_assembler), codeBlocks(_codeBlocks), ssaReplacesMap(_ssaReplacesMap)
292{ }
293
294static inline bool codeBlockStartLess(const AsmRegAllocator::CodeBlock& c1,
295                  const AsmRegAllocator::CodeBlock& c2)
296{ return c1.start < c2.start; }
297
298static inline bool codeBlockEndLess(const AsmRegAllocator::CodeBlock& c1,
299                  const AsmRegAllocator::CodeBlock& c2)
300{ return c1.end < c2.end; }
301
302void AsmRegAllocator::createCodeStructure(const std::vector<AsmCodeFlowEntry>& codeFlow,
303             size_t codeSize, const cxbyte* code)
304{
305    ISAAssembler* isaAsm = assembler.isaAssembler;
306    if (codeSize == 0)
307        return;
308    std::vector<size_t> splits;
309    std::vector<size_t> codeStarts;
310    std::vector<size_t> codeEnds;
311    codeStarts.push_back(0);
312    codeEnds.push_back(codeSize);
313    for (const AsmCodeFlowEntry& entry: codeFlow)
314    {
315        size_t instrAfter = 0;
316        if (entry.type == AsmCodeFlowType::JUMP || entry.type == AsmCodeFlowType::CJUMP ||
317            entry.type == AsmCodeFlowType::CALL || entry.type == AsmCodeFlowType::RETURN)
318            instrAfter = entry.offset + isaAsm->getInstructionSize(
319                        codeSize - entry.offset, code + entry.offset);
320       
321        switch(entry.type)
322        {
323            case AsmCodeFlowType::START:
324                codeStarts.push_back(entry.offset);
325                break;
326            case AsmCodeFlowType::END:
327                codeEnds.push_back(entry.offset);
328                break;
329            case AsmCodeFlowType::JUMP:
330                splits.push_back(entry.target);
331                codeEnds.push_back(instrAfter);
332                break;
333            case AsmCodeFlowType::CJUMP:
334                splits.push_back(entry.target);
335                splits.push_back(instrAfter);
336                break;
337            case AsmCodeFlowType::CALL:
338                splits.push_back(entry.target);
339                splits.push_back(instrAfter);
340                break;
341            case AsmCodeFlowType::RETURN:
342                codeEnds.push_back(instrAfter);
343                break;
344            default:
345                break;
346        }
347    }
348    std::sort(splits.begin(), splits.end());
349    splits.resize(std::unique(splits.begin(), splits.end()) - splits.begin());
350    std::sort(codeEnds.begin(), codeEnds.end());
351    codeEnds.resize(std::unique(codeEnds.begin(), codeEnds.end()) - codeEnds.begin());
352    // remove codeStarts between codeStart and codeEnd
353    size_t i = 0;
354    size_t ii = 0;
355    size_t ei = 0; // codeEnd i
356    while (i < codeStarts.size())
357    {
358        size_t end = (ei < codeEnds.size() ? codeEnds[ei] : SIZE_MAX);
359        if (ei < codeEnds.size())
360            ei++;
361        codeStarts[ii++] = codeStarts[i];
362        // skip codeStart to end
363        for (i++ ;i < codeStarts.size() && codeStarts[i] < end; i++);
364    }
365    codeStarts.resize(ii);
366    // add next codeStarts
367    auto splitIt = splits.begin();
368    for (size_t codeEnd: codeEnds)
369    {
370        auto it = std::lower_bound(splitIt, splits.end(), codeEnd);
371        if (it != splits.end())
372        {
373            codeStarts.push_back(*it);
374            splitIt = it;
375        }
376        else // if end
377            break;
378    }
379   
380    std::sort(codeStarts.begin(), codeStarts.end());
381    codeStarts.resize(std::unique(codeStarts.begin(), codeStarts.end()) -
382                codeStarts.begin());
383    // divide to blocks
384    splitIt = splits.begin();
385    for (size_t codeStart: codeStarts)
386    {
387        size_t codeEnd = *std::upper_bound(codeEnds.begin(), codeEnds.end(), codeStart);
388        splitIt = std::lower_bound(splitIt, splits.end(), codeStart);
389       
390        if (splitIt != splits.end() && *splitIt==codeStart)
391            ++splitIt; // skip split in codeStart
392       
393        for (size_t start = codeStart; start < codeEnd; )
394        {
395            size_t end = codeEnd;
396            if (splitIt != splits.end())
397            {
398                end = std::min(end, *splitIt);
399                ++splitIt;
400            }
401            codeBlocks.push_back({ start, end, { }, false, false, false });
402            start = end;
403        }
404    }
405    // force empty block at end if some jumps goes to its
406    if (!codeEnds.empty() && !codeStarts.empty() && !splits.empty() &&
407        codeStarts.back()==codeEnds.back() && codeStarts.back() == splits.back())
408        codeBlocks.push_back({ codeStarts.back(), codeStarts.back(), { },
409                             false, false, false });
410   
411    // construct flow-graph
412    for (const AsmCodeFlowEntry& entry: codeFlow)
413        if (entry.type == AsmCodeFlowType::CALL || entry.type == AsmCodeFlowType::JUMP ||
414            entry.type == AsmCodeFlowType::CJUMP || entry.type == AsmCodeFlowType::RETURN)
415        {
416            std::vector<CodeBlock>::iterator it;
417            size_t instrAfter = entry.offset + isaAsm->getInstructionSize(
418                        codeSize - entry.offset, code + entry.offset);
419           
420            if (entry.type != AsmCodeFlowType::RETURN)
421                it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
422                        CodeBlock{ entry.target }, codeBlockStartLess);
423            else // return
424            {
425                it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
426                        CodeBlock{ 0, instrAfter }, codeBlockEndLess);
427                // if block have return
428                if (it != codeBlocks.end())
429                    it->haveEnd = it->haveReturn = true;
430                continue;
431            }
432           
433            if (it == codeBlocks.end())
434                continue; // error!
435            auto it2 = std::lower_bound(codeBlocks.begin(), codeBlocks.end(),
436                    CodeBlock{ instrAfter }, codeBlockStartLess);
437            auto curIt = it2;
438            --curIt;
439           
440            curIt->nexts.push_back({ size_t(it - codeBlocks.begin()),
441                        entry.type == AsmCodeFlowType::CALL });
442            curIt->haveCalls |= entry.type == AsmCodeFlowType::CALL;
443            if (entry.type == AsmCodeFlowType::CJUMP ||
444                 entry.type == AsmCodeFlowType::CALL)
445            {
446                curIt->haveEnd = false; // revert haveEnd if block have cond jump or call
447                if (it2 != codeBlocks.end() && entry.type == AsmCodeFlowType::CJUMP)
448                    // add next next block (only for cond jump)
449                    curIt->nexts.push_back({ size_t(it2 - codeBlocks.begin()), false });
450            }
451            else if (entry.type == AsmCodeFlowType::JUMP)
452                curIt->haveEnd = true; // set end
453        }
454    // force haveEnd for block with cf_end
455    for (const AsmCodeFlowEntry& entry: codeFlow)
456        if (entry.type == AsmCodeFlowType::END)
457        {
458            auto it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
459                    CodeBlock{ 0, entry.offset }, codeBlockEndLess);
460            if (it != codeBlocks.end())
461                it->haveEnd = true;
462        }
463   
464    if (!codeBlocks.empty()) // always set haveEnd to last block
465        codeBlocks.back().haveEnd = true;
466   
467    // reduce nexts
468    for (CodeBlock& block: codeBlocks)
469    {
470        // first non-call nexts, for correct resolving SSA conflicts
471        std::sort(block.nexts.begin(), block.nexts.end(),
472                  [](const NextBlock& n1, const NextBlock& n2)
473                  { return int(n1.isCall)<int(n2.isCall) ||
474                      (n1.isCall == n2.isCall && n1.block < n2.block); });
475        auto it = std::unique(block.nexts.begin(), block.nexts.end(),
476                  [](const NextBlock& n1, const NextBlock& n2)
477                  { return n1.block == n2.block && n1.isCall == n2.isCall; });
478        block.nexts.resize(it - block.nexts.begin());
479    }
480}
481
482
483void AsmRegAllocator::applySSAReplaces()
484{
485    if (ssaReplacesMap.empty())
486        return; // do nothing
487   
488    /* prepare SSA id replaces */
489    struct MinSSAGraphNode
490    {
491        size_t minSSAId;
492        bool visited;
493        std::unordered_set<size_t> nexts;
494        MinSSAGraphNode() : minSSAId(SIZE_MAX), visited(false)
495        { }
496    };
497   
498    typedef std::map<size_t, MinSSAGraphNode, std::greater<size_t> > SSAGraphNodesMap;
499   
500    struct MinSSAGraphStackEntry
501    {
502        SSAGraphNodesMap::iterator nodeIt;
503        std::unordered_set<size_t>::const_iterator nextIt;
504        size_t minSSAId;
505       
506        MinSSAGraphStackEntry(
507                SSAGraphNodesMap::iterator _nodeIt,
508                std::unordered_set<size_t>::const_iterator _nextIt,
509                size_t _minSSAId = SIZE_MAX)
510                : nodeIt(_nodeIt), nextIt(_nextIt), minSSAId(_minSSAId)
511        { }
512    };
513   
514    for (auto& entry: ssaReplacesMap)
515    {
516        ARDOut << "SSAReplace: " << entry.first.regVar << "." << entry.first.index << "\n";
517        VectorSet<SSAReplace>& replaces = entry.second;
518        std::sort(replaces.begin(), replaces.end(), std::greater<SSAReplace>());
519        replaces.resize(std::unique(replaces.begin(), replaces.end()) - replaces.begin());
520        VectorSet<SSAReplace> newReplaces;
521       
522        SSAGraphNodesMap ssaGraphNodes;
523       
524        auto it = replaces.begin();
525        while (it != replaces.end())
526        {
527            auto itEnd = std::upper_bound(it, replaces.end(),
528                    std::make_pair(it->first, size_t(0)), std::greater<SSAReplace>());
529            {
530                auto itLast = itEnd;
531                --itLast;
532                MinSSAGraphNode& node = ssaGraphNodes[it->first];
533                node.minSSAId = std::min(node.minSSAId, itLast->second);
534                for (auto it2 = it; it2 != itEnd; ++it2)
535                {
536                    node.nexts.insert(it2->second);
537                    ssaGraphNodes.insert({ it2->second, MinSSAGraphNode() });
538                }
539            }
540            it = itEnd;
541        }
542        /*for (const auto& v: ssaGraphNodes)
543            ARDOut << "  SSANode: " << v.first << ":" << &v.second << " minSSAID: " <<
544                            v.second.minSSAId << std::endl;*/
545        // propagate min value
546        std::stack<MinSSAGraphStackEntry> minSSAStack;
547       
548        // initialize parents and new nexts
549        for (auto ssaGraphNodeIt = ssaGraphNodes.begin();
550                 ssaGraphNodeIt!=ssaGraphNodes.end(); )
551        {
552            ARDOut << "  Start in " << ssaGraphNodeIt->first << "." << "\n";
553            minSSAStack.push({ ssaGraphNodeIt, ssaGraphNodeIt->second.nexts.begin() });
554            // traverse with minimalize SSA id
555            while (!minSSAStack.empty())
556            {
557                MinSSAGraphStackEntry& entry = minSSAStack.top();
558                MinSSAGraphNode& node = entry.nodeIt->second;
559                bool toPop = false;
560                if (entry.nextIt == node.nexts.begin())
561                {
562                    toPop = node.visited;
563                    node.visited = true;
564                }
565                if (!toPop && entry.nextIt != node.nexts.end())
566                {
567                    auto nodeIt = ssaGraphNodes.find(*entry.nextIt);
568                    if (nodeIt != ssaGraphNodes.end())
569                        minSSAStack.push({ nodeIt, nodeIt->second.nexts.begin(),
570                                    size_t(0) });
571                    ++entry.nextIt;
572                }
573                else
574                {
575                    minSSAStack.pop();
576                    if (!minSSAStack.empty())
577                        node.nexts.insert(minSSAStack.top().nodeIt->first);
578                }
579            }
580           
581            // skip visited nodes
582            for(; ssaGraphNodeIt != ssaGraphNodes.end(); ++ssaGraphNodeIt)
583                if (!ssaGraphNodeIt->second.visited)
584                    break;
585        }
586       
587        /*for (const auto& v: ssaGraphNodes)
588        {
589            ARDOut << "  Nexts: " << v.first << ":" << &v.second << " nexts:";
590            for (size_t p: v.second.nexts)
591                ARDOut << " " << p;
592            ARDOut << "\n";
593        }*/
594       
595        for (auto& entry: ssaGraphNodes)
596            entry.second.visited = false;
597       
598        std::vector<MinSSAGraphNode*> toClear; // nodes to clear
599       
600        for (auto ssaGraphNodeIt = ssaGraphNodes.begin();
601                 ssaGraphNodeIt!=ssaGraphNodes.end(); )
602        {
603            ARDOut << "  Start in " << ssaGraphNodeIt->first << "." << "\n";
604            minSSAStack.push({ ssaGraphNodeIt, ssaGraphNodeIt->second.nexts.begin() });
605            // traverse with minimalize SSA id
606            while (!minSSAStack.empty())
607            {
608                MinSSAGraphStackEntry& entry = minSSAStack.top();
609                MinSSAGraphNode& node = entry.nodeIt->second;
610                bool toPop = false;
611                if (entry.nextIt == node.nexts.begin())
612                {
613                    toPop = node.visited;
614                    if (!node.visited)
615                        // this flag visited for this node will be clear after this pass
616                        toClear.push_back(&node);
617                    node.visited = true;
618                }
619               
620                // try to children only all parents are visited and if parent has children
621                if (!toPop && entry.nextIt != node.nexts.end())
622                {
623                    auto nodeIt = ssaGraphNodes.find(*entry.nextIt);
624                    if (nodeIt != ssaGraphNodes.end())
625                    {
626                        ARDOut << "  Node: " <<
627                                entry.nodeIt->first << ":" << &node << " minSSAId: " <<
628                                node.minSSAId << " to " <<
629                                nodeIt->first << ":" << &(nodeIt->second) <<
630                                " minSSAId: " << nodeIt->second.minSSAId << "\n";
631                        nodeIt->second.minSSAId =
632                                std::min(nodeIt->second.minSSAId, node.minSSAId);
633                        minSSAStack.push({ nodeIt, nodeIt->second.nexts.begin(),
634                                nodeIt->second.minSSAId });
635                    }
636                    ++entry.nextIt;
637                }
638                else
639                {
640                    node.minSSAId = std::min(node.minSSAId, entry.minSSAId);
641                    ARDOut << "    Node: " <<
642                                entry.nodeIt->first << ":" << &node << " minSSAId: " <<
643                                node.minSSAId << "\n";
644                    minSSAStack.pop();
645                    if (!minSSAStack.empty())
646                    {
647                        MinSSAGraphStackEntry& pentry = minSSAStack.top();
648                        pentry.minSSAId = std::min(pentry.minSSAId, node.minSSAId);
649                    }
650                }
651            }
652           
653            const size_t minSSAId = ssaGraphNodeIt->second.minSSAId;
654           
655            // skip visited nodes
656            for(; ssaGraphNodeIt != ssaGraphNodes.end(); ++ssaGraphNodeIt)
657                if (!ssaGraphNodeIt->second.visited)
658                    break;
659            // zeroing visited
660            for (MinSSAGraphNode* node: toClear)
661            {
662                node->minSSAId = minSSAId; // fill up by minSSAId
663                node->visited = false;
664            }
665            toClear.clear();
666        }
667       
668        for (const auto& entry: ssaGraphNodes)
669            newReplaces.push_back({ entry.first, entry.second.minSSAId });
670       
671        std::sort(newReplaces.begin(), newReplaces.end());
672        entry.second = newReplaces;
673    }
674   
675    /* apply SSA id replaces */
676    for (CodeBlock& cblock: codeBlocks)
677        for (auto& ssaEntry: cblock.ssaInfoMap)
678        {
679            auto it = ssaReplacesMap.find(ssaEntry.first);
680            if (it == ssaReplacesMap.end())
681                continue;
682            SSAInfo& sinfo = ssaEntry.second;
683            VectorSet<SSAReplace>& replaces = it->second;
684            if (sinfo.readBeforeWrite)
685            {
686                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
687                                 ssaEntry.second.ssaIdBefore);
688                if (rit != replaces.end())
689                    sinfo.ssaIdBefore = rit->second; // replace
690            }
691            if (sinfo.ssaIdFirst != SIZE_MAX)
692            {
693                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
694                                 ssaEntry.second.ssaIdFirst);
695                if (rit != replaces.end())
696                    sinfo.ssaIdFirst = rit->second; // replace
697            }
698            if (sinfo.ssaIdLast != SIZE_MAX)
699            {
700                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
701                                 ssaEntry.second.ssaIdLast);
702                if (rit != replaces.end())
703                    sinfo.ssaIdLast = rit->second; // replace
704            }
705        }
706   
707    // clear ssa replaces
708    ssaReplacesMap.clear();
709}
710
711void AsmRegAllocator::createInterferenceGraph()
712{
713    /// construct liveBlockMaps
714    std::set<LiveBlock> liveBlockMaps[MAX_REGTYPES_NUM];
715    for (size_t regType = 0; regType < regTypesNum; regType++)
716    {
717        std::set<LiveBlock>& liveBlockMap = liveBlockMaps[regType];
718        Array<OutLiveness>& liveness = outLivenesses[regType];
719        for (size_t li = 0; li < liveness.size(); li++)
720        {
721            OutLiveness& lv = liveness[li];
722            for (const std::pair<size_t, size_t>& blk: lv)
723                if (blk.first != blk.second)
724                    liveBlockMap.insert({ blk.first, blk.second, li });
725            lv.clear();
726        }
727        liveness.clear();
728    }
729   
730    // create interference graphs
731    for (size_t regType = 0; regType < regTypesNum; regType++)
732    {
733        InterGraph& interGraph = interGraphs[regType];
734        interGraph.resize(graphVregsCounts[regType]);
735        std::set<LiveBlock>& liveBlockMap = liveBlockMaps[regType];
736       
737        auto lit = liveBlockMap.begin();
738        size_t rangeStart = 0;
739        if (lit != liveBlockMap.end())
740            rangeStart = lit->start;
741        while (lit != liveBlockMap.end())
742        {
743            const size_t blkStart = lit->start;
744            const size_t blkEnd = lit->end;
745            size_t rangeEnd = blkEnd;
746            auto liStart = liveBlockMap.lower_bound({ rangeStart, 0, 0 });
747            auto liEnd = liveBlockMap.lower_bound({ rangeEnd, 0, 0 });
748            // collect from this range, variable indices
749            std::set<size_t> varIndices;
750            for (auto lit2 = liStart; lit2 != liEnd; ++lit2)
751                varIndices.insert(lit2->vidx);
752            // push to intergraph as full subgGraph
753            for (auto vit = varIndices.begin(); vit != varIndices.end(); ++vit)
754                for (auto vit2 = varIndices.begin(); vit2 != varIndices.end(); ++vit2)
755                    if (vit != vit2)
756                        interGraph[*vit].insert(*vit2);
757            // go to next live blocks
758            rangeStart = rangeEnd;
759            for (; lit != liveBlockMap.end(); ++lit)
760                if (lit->start != blkStart && lit->end != blkEnd)
761                    break;
762            if (lit == liveBlockMap.end())
763                break; //
764            rangeStart = std::max(rangeStart, lit->start);
765        }
766    }
767}
768
769/* algorithm to allocate regranges:
770 * from smallest regranges to greatest regranges:
771 *   choosing free register: from smallest free regranges
772 *      to greatest regranges:
773 *         in this same regrange:
774 *               try to find free regs in regranges
775 *               try to link free ends of two distinct regranges
776 */
777
778void AsmRegAllocator::colorInterferenceGraph()
779{
780    const GPUArchitecture arch = getGPUArchitectureFromDeviceType(
781                    assembler.deviceType);
782   
783    for (size_t regType = 0; regType < regTypesNum; regType++)
784    {
785        const size_t maxColorsNum = getGPUMaxRegistersNum(arch, regType);
786        InterGraph& interGraph = interGraphs[regType];
787        const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
788        Array<cxuint>& gcMap = graphColorMaps[regType];
789       
790        const size_t nodesNum = interGraph.size();
791        gcMap.resize(nodesNum);
792        std::fill(gcMap.begin(), gcMap.end(), cxuint(UINT_MAX));
793        Array<size_t> sdoCounts(nodesNum);
794        std::fill(sdoCounts.begin(), sdoCounts.end(), 0);
795       
796        SDOLDOCompare compare(interGraph, sdoCounts);
797        std::set<size_t, SDOLDOCompare> nodeSet(compare);
798        for (size_t i = 0; i < nodesNum; i++)
799            nodeSet.insert(i);
800       
801        cxuint colorsNum = 0;
802        // firstly, allocate real registers
803        for (const auto& entry: vregIndexMap)
804            if (entry.first.regVar == nullptr)
805                gcMap[entry.second[0]] = colorsNum++;
806       
807        for (size_t colored = 0; colored < nodesNum; colored++)
808        {
809            size_t node = *nodeSet.begin();
810            if (gcMap[node] != UINT_MAX)
811                continue; // already colored
812            size_t color = 0;
813           
814            for (color = 0; color <= colorsNum; color++)
815            {
816                // find first usable color
817                bool thisSame = false;
818                for (size_t nb: interGraph[node])
819                    if (gcMap[nb] == color)
820                    {
821                        thisSame = true;
822                        break;
823                    }
824                if (!thisSame)
825                    break;
826            }
827            if (color==colorsNum) // add new color if needed
828            {
829                if (colorsNum >= maxColorsNum)
830                    throw AsmException("Too many register is needed");
831                colorsNum++;
832            }
833           
834            gcMap[node] = color;
835            // update SDO for node
836            bool colorExists = false;
837            for (size_t nb: interGraph[node])
838                if (gcMap[nb] == color)
839                {
840                    colorExists = true;
841                    break;
842                }
843            if (!colorExists)
844                sdoCounts[node]++;
845            // update SDO for neighbors
846            for (size_t nb: interGraph[node])
847            {
848                colorExists = false;
849                for (size_t nb2: interGraph[nb])
850                    if (gcMap[nb2] == color)
851                    {
852                        colorExists = true;
853                        break;
854                    }
855                if (!colorExists)
856                {
857                    if (gcMap[nb] == UINT_MAX)
858                        nodeSet.erase(nb);  // before update we erase from nodeSet
859                    sdoCounts[nb]++;
860                    if (gcMap[nb] == UINT_MAX)
861                        nodeSet.insert(nb); // after update, insert again
862                }
863            }
864           
865            gcMap[node] = color;
866        }
867    }
868}
869
870void AsmRegAllocator::allocateRegisters(cxuint sectionId)
871{
872    // before any operation, clear all
873    codeBlocks.clear();
874    for (size_t i = 0; i < MAX_REGTYPES_NUM; i++)
875    {
876        graphVregsCounts[i] = 0;
877        vregIndexMaps[i].clear();
878        interGraphs[i].clear();
879        linearDepMaps[i].clear();
880        graphColorMaps[i].clear();
881    }
882    ssaReplacesMap.clear();
883    cxuint maxRegs[MAX_REGTYPES_NUM];
884    assembler.isaAssembler->getMaxRegistersNum(regTypesNum, maxRegs);
885   
886    // set up
887    const AsmSection& section = assembler.sections[sectionId];
888    createCodeStructure(section.codeFlow, section.content.size(), section.content.data());
889    createSSAData(*section.usageHandler, *section.linearDepHandler);
890    applySSAReplaces();
891    createLivenesses(*section.usageHandler, *section.linearDepHandler);
892    createInterferenceGraph();
893    colorInterferenceGraph();
894}
Note: See TracBrowser for help on using the repository browser.