# L14 - Calculating Data Movement

This Jupyter notebook illustrates a process to calculate the data movement required to implement a specified dataflow using the [Python wrapper](https://documen.tician.de/islpy/index.html) around Sven Verdoolaege's **Integer Set Library** ([ISL](https://libisl.sourceforge.io/)).

This notebook was used for creating the lecture slides and should work, but is rather crude and still a work-in-progress.

## Install/import some required libraries

In [None]:
# Installation: pip3 install islpy

# Documentation: https://documen.tician.de/islpy/

%pip install islpy

import sys
import islpy as isl

## Define some useful utility functions

In [None]:
def print_point(point):
    """Print one ISL point on its own line.

    Intended as the callback handed to ``foreach_point``.
    """
    print(point, end="\n")

def scan(set):
    """Print every point of a bounded ISL set, one point per line."""
    visit = print_point
    set.foreach_point(visit)

def val_to_int(val):
    """Return an integral ISL ``Val`` as a Python ``int``.

    The value must be an integer (checked with ``assert``); the conversion
    goes through the value's string form.
    """
    assert val.is_int()
    as_text = val.to_str()
    return int(as_text)

def point_to_int(point):
    """Return the coordinate of a 1-D ISL point as a Python ``int``.

    The coordinate is read from the set dimensions (``dim_type.set``) rather
    than ``dim_type.all``: with ``all``, index 0 would refer to the first
    parameter whenever the point's space has parameters.

    Raises AssertionError (via ``assert``, like the other helpers here) if the
    point does not have exactly one set dimension.
    """
    ndims = point.get_space().dim(isl.dim_type.set)
    assert ndims == 1, f"point_to_int expects a 1-D point, got {ndims} dimensions"
    pointval = point.get_coordinate_val(isl.dim_type.set, 0)
    return val_to_int(pointval)

def set_to_int(set):
    """Return the integer coordinate of a singleton 1-D ISL set.

    The set must contain exactly one point (checked with ``assert``).
    """
    assert set.is_singleton()
    return point_to_int(set.sample_point())

def scan_map_at_domain_point(map):
    """Build a ``foreach_point`` callback for the domain of *map*.

    For each visited domain point, the callback prints every pair of *map*
    whose source is that point (as wrapped ``[in -> out]`` points), then a
    blank line.
    """
    def print_slice(point):
        restricted = map.intersect_domain(isl.Set.from_point(point))
        scan(restricted.wrap())
        print()
    return print_slice

def group_print_map(map):
    """Print *map* grouped by domain point.

    For each domain point, print all pairs with that source point, followed
    by a blank line.
    """
    map.domain().foreach_point(scan_map_at_domain_point(map))

def card_at_domain_point(map, ret_dict):
    """Build a ``foreach_point`` callback that accumulates per-point counts.

    For each visited domain point p, the callback counts the pairs of *map*
    with source p. It then unions ``{ p -> [count] }`` into
    ``ret_dict['map']``. The dict is used as a mutable cell so the closure
    can rebind the accumulated map.
    """
    def record_card(point):
        source = isl.Set.from_point(point)
        n_targets = map.intersect_domain(source).wrap().count_val()
        count_set = isl.Set(f'{{ [{n_targets}] }}')
        entry = isl.Map.from_domain_and_range(source, count_set)
        ret_dict['map'] = ret_dict['map'].union(entry)
    return record_card

def group_card(map):
    """Return an ISL map from each domain point of *map* to ``[n]``.

    Here n is the number of pairs of *map* that have that domain point as
    their source.
    """
    dom = map.domain()
    # Seed accumulator: a map in the right space whose range '{ [0] : false }'
    # is empty, so the seed itself contains no pairs.
    seed = isl.Map.from_domain_and_range(dom.sample(), isl.Set('{ [0] : false }'))
    acc = {'map': seed}
    dom.foreach_point(card_at_domain_point(map, acc))
    return acc['map']

def count_at_domain_point(map, ret_dict):
    """Build a ``foreach_point`` callback that records per-point pair counts.

    For each visited domain point, the callback stores
    ``ret_dict[first_coordinate] = count`` (both ``int``). Here count is the
    number of pairs of *map* with that source point.

    The key is read from the set dimensions (``dim_type.set``), matching
    ``point_to_int``. With ``dim_type.all``, index 0 would be a parameter if
    the space had any.

    NOTE: only the first coordinate is used as the key. For multi-dimensional
    domains, points sharing a first coordinate overwrite each other's count.
    """
    def curried(point):
        singleton = isl.Set.from_point(point)
        submap = map.intersect_domain(singleton)
        card = submap.wrap().count_val()
        pointval = point.get_coordinate_val(isl.dim_type.set, 0)
        ret_dict[val_to_int(pointval)] = val_to_int(card)
    return curried

def group_count(map):
    """Return ``{first_domain_coordinate: pair_count}`` for every domain point of *map*."""
    counts = {}
    map.domain().foreach_point(count_at_domain_point(map, counts))
    return counts

## Define some utility functions to print ISL objects

In [None]:
def printSet(title, set):
    """Print ``title = <set>``, a blank line, then each point of *set* on its own line."""
    print(f"{title} = {set}", end="\n\n")
    scan(set)

def compareSets(setA, setB):
    """Print whether *setA* and *setB* are equal according to ISL's ``is_equal``."""
    verdict = "equal" if setA.is_equal(setB) else "NOT equal"
    print(f"The sets are {verdict}")

def printMap(title, map):
    """Print ``title = <map>``, a blank line, then every pair of *map* in lexicographic order.

    Each pair is printed as a singleton set of the wrapped ``[in -> out]`` point.

    A pair's rank is the number of pairs lexicographically smaller than it.
    Printing in increasing rank therefore yields lexicographic order. This is
    quadratic in the number of pairs, which is fine for the small examples here.
    """
    print(f"{title} = {map}", end="\n\n")
    pairs = map.wrap()
    # Relates each pair to every pair that is lexicographically smaller.
    smaller_than = pairs.lex_gt_set(pairs)

    by_rank = {}

    def record_rank(point):
        pair = isl.Set.from_point(point)
        by_rank[pair.apply(smaller_than).count_val()] = pair

    pairs.foreach_point(record_rank)

    for rank in sorted(by_rank):
        print(by_rank[rank])


    

## Define some utility functions to plot some ISL objects

In [None]:
import matplotlib.pyplot as plt
import re

def plotMap(map):
    """Plot the domain of *map*, then its range, each in its own figure.

    The plots do not draw lines connecting related points.
    """
    for get_part in (map.domain, map.range):
        plotSet(get_part())
    
def plotSet(*sets):
    """
    Scatter-plot one or more bounded 1-D or 2-D ISL sets in a single figure.

    Each set is drawn in its own color, in argument order (blue, red, green,
    yellow, after which the palette repeats). Points of a 1-D set are drawn
    at y = 0. The axis labels and the title come from the space of the
    first set.

    Parameters:
        *sets (isl.Set): Bounded sets whose points can be enumerated with
            ``foreach_point``. Only the first two set dimensions are plotted.
            Empty sets are skipped. Calling with no sets does nothing.
    """
    if not sets:
        return

    colors = ["blue", "red", "green", "yellow"]

    for n, s in enumerate(sets):
        points = _set_points_2d(s)
        # zip(*[]) cannot be unpacked into two sequences, so skip empty sets.
        if points:
            x_coords, y_coords = zip(*points)
            plt.scatter(x_coords, y_coords, color=colors[n % len(colors)])

    spacename = sets[0].get_space().to_str()
    name1, name2 = _axis_names(spacename)

    plt.xlabel(name1)
    plt.ylabel(name2)
    plt.title(spacename)
    plt.grid(True)
    plt.show()


def _set_points_2d(s):
    """Return the points of a bounded ISL set as (x, y) int tuples; y is 0 for 1-D sets."""
    ndims = s.get_space().dim(isl.dim_type.set)
    points = []

    def collect(point):
        x = int(point.get_coordinate_val(isl.dim_type.set, 0).to_python())
        if ndims > 1:
            y = int(point.get_coordinate_val(isl.dim_type.set, 1).to_python())
        else:
            y = 0
        points.append((x, y))

    s.foreach_point(collect)
    return points


def _axis_names(spacename):
    """Return (x-label, y-label) parsed from a printed space like '{ S[x, y] }'; unknown names become ''."""
    match = re.search(r'\[(.*?),(.*?)\]', spacename)
    if match:
        return match.group(1).strip(), match.group(2).strip()
    match = re.search(r'\[(.*?)\]', spacename)
    if match:
        return match.group(1).strip(), ""
    return "", ""




## Output Stationary Iteration Space

The following sequence of cells builds up an analysis of an output-stationary dataflow.

## Basic ISL Background

### An ISL Set

Create an ISL set in the space SetXY.

As in many cells in this notebook, the first lines perform the actual activity, followed by some lines that display the results of that activity.

In [None]:
# Bounds of the example 2-D set: x in [0, X-1], y in [0, Y-1].
X = 5
Y = 3

# Doubled braces in the f-string produce the literal { } of the ISL syntax.
set1 = isl.Set(f'{{ SetXY[x, y] : 0 <= x <= {X-1} and 0 <= y <= {Y-1} }}')

# Print the set's definition and points, then scatter-plot it.
printSet("SetXY", set1)
plotSet(set1)

### Space names

In ISL the points in a set have a spacename. The prior spacename was SetXY; here we create a set with a different spacename (SetWX). Note that sets with the same points but different spacenames are different.

In [None]:
# Same constraints as set1, but the points live in the space SetWX.
set2 = isl.Set(f'{{ SetWX[x, y] : 0 <= x <= {X-1} and 0 <= y <= {Y-1} }}')
compareSets(set1, set2)

print()
printSet("Set", set2)
print()
plotSet(set2)

### Tuple Order

This cell looks at the impact of the order of the dimensions (tuple elements) in the specification of a set. The constraints are unchanged, but `x` and `y` swap positions in the tuple, and we see that tuple order matters.

In [None]:
# Same space name and constraints as set1, but y is now the first tuple
# element and x the second (transposed coordinates).
set3 = isl.Set(f'{{ SetXY[y, x] : 0 <= x <= {X-1} and 0 <= y <= {Y-1} }}')
compareSets(set1, set3)

print()
printSet("SetXY - transposed coordinates", set3)
plotSet(set3)

### Different constraints

In this cell we use a different specification of the constraints on the points in the set, and see that so long as the constraints result in the same set of points the sets are the same.

In [None]:
# Strict upper bounds (x < X) instead of x <= X-1. For integer points these
# describe the same set as set1.
set4 = isl.Set(f'{{ SetXY[x, y] : 0 <= x < {X} and 0 <= y < {Y} }}')
compareSets(set1, set4)

print()
printSet("SetXY - different conditions", set4)
plotSet(set4)

### More complex conditions

This cell simply illustrates that we can specify more complex constraint expressions

In [None]:
# Triangular set: y is bounded above by x (0 <= y <= x) rather than by Y,
# so this is a different set of points than set1.
set4 = isl.Set(f'{{ SetXY[x, y] : 0 <= x < {X} and 0 <= y <= x }}')
compareSets(set1, set4)

print()
printSet("SetXY - complex conditions", set4)
plotSet(set4)

## Defining an iteration space

### Iteration Space

This cell illustrates how we can use an ISL set to define the iteration space for a computation.

In [None]:
# Iteration space of a 2-deep loop nest: x in [0, X), y in [0, Y).
X = 5
Y = 3

ispaceXY = isl.Set(f'{{ IterationSpace[x,y] : 0 <= x < {X} and 0 <= y < {Y} }}')

printSet("Iteration Space", ispaceXY)
plotSet(ispaceXY)

### 1-D Convolution Iteration Space

This cell shows how we can use the dummy variables in a set to use the standard variable names for defining the iteration space for a 1-D convolution. And that those names do not affect the equality of the sets.

In [None]:
# 1-D convolution extents: q indexes outputs and s indexes weights (see the
# projections below). Q and S equal X and Y above.
Q=5
S=3

# Only the dimension names differ from ispaceXY (q,s vs x,y), so the sets compare equal.
ispace = isl.Set(f'{{ IterationSpace[q,s] : 0 <= q < {Q} and 0 <= s < {S} }}')
compareSets(ispaceXY, ispace)

print()
printSet("Iteration Space", ispace)
plotSet(ispace)


## Maps

This next section looks at defining ISL maps, i.e., relations between two ISL sets.

## Maps (unbounded)

This cell defines a map between a set in the space `SetXY` and another set in the space `SetXY_range`. Specifically, we create a relation between a point (x,y) and another point one to the right and 10 up, i.e., (x+1, y+10). Note, however, that although we can print the definition of the map, we cannot print the contents of the map because it is infinite.

In [None]:
# Relates SetXY[x,y] to SetXY_range[x+1, y+10]. The domain is unconstrained,
# so the map is infinite and can only be printed symbolically.
unbounded_map = isl.Map(f'{{ SetXY[x,y] -> SetXY_range[x1,y1] : x1=x+1 and y1=y+10}}')

print(f"{unbounded_map = }")

### Maps (bounded)

To allow us to illustrate maps we can bound the domain of a map, as shown in the cell below. In many subsequent cells we will bound maps in this fashion to be able to display something, while for overall computation purposes there is no need to do that, and ISL will simply work on the symbolic expression describing the set.

Also, in the cell below we display both the domain of the map (blue) and the range of the map (red) in the same plot. Currently you have to imagine the lines connecting points in the domain to the range.

In [None]:
# Restrict the domain to the finite set1 so the map's pairs can be enumerated.
bounded_map = unbounded_map.intersect_domain(set1)

print()
printMap("Bounded map", bounded_map)
# plotSet colors sets in argument order: domain in blue, range in red.
plotSet(bounded_map.domain(), bounded_map.range())

## Projections from iteration space to data spaces

The following set of cells show the specification of the projections from the points in the iteration space to the various data spaces of the 1-D convolution.

Note: The `plotMap()` method used in the following cells creates separate plots of the domain and range of the map, but unfortunately does not show lines connecting the points in the domain to points in the range. Also, the commented-out statements in some of the cells below are plots that I used for the lecture.

### Weight Projection

Note how multiple points in the iteration space map to the same point in the weight space.

In [None]:
# Iteration (q,s) accesses weight s, independent of q.
is2weight = isl.Map("{ IterationSpace[q,s] -> Weight[s] }")

# Bound to the convolution iteration space so the pairs can be listed.
is2weight_bounded = is2weight.intersect_domain(ispace)

printMap('Weight projections', is2weight_bounded)
plotMap(is2weight_bounded)
#plotSet(is2weight_bounded.domain())

### Output Projection

Again, multiple points in the iteration space map to a single point in the output space.

In [None]:
# Iteration (q,s) accesses output q, independent of s.
is2output = isl.Map("{ IterationSpace[q,s] -> Output[q] }")

is2output_bounded = is2output.intersect_domain(ispace)

printMap('Output projections', is2output_bounded)
plotMap(is2output_bounded)

### Input Projection

In [None]:
# Iteration (q,s) accesses input w = q + s (the 1-D convolution window).
is2input = isl.Map("{ IterationSpace[q,s] -> Input[w] : w=q+s }")

is2input_bounded = is2input.intersect_domain(ispace)

printMap('Input projections', is2input_bounded)
plotMap(is2input_bounded)

### Compute Projection

In [None]:
# Each iteration (q,s) corresponds to one multiply-accumulate MACC[q,s].
is2compute = isl.Map("{ IterationSpace[q,s] -> MACC[q,s] }")

is2compute_bounded = is2compute.intersect_domain(ispace)

printMap('Compute projections', is2compute_bounded)
plotMap(is2compute_bounded)

## Iteration Space Traversal

In the next cells we show how to specify a mapping from an iteration space to a sequence of execution timestamps, which corresponds to a specific traversal through the iteration space.

### Defining a dataflow (schedule)

This cell shows the specification of a schedule that relates each point in the iteration space to a timestamp. Timestamps are specified as a tuple where the order is determined by assuming that the lower (rightmost) digit moves most quickly (like the ones place in a number).

Note: Because this is an unbounded map we cannot print its contents

In [None]:
# Output-stationary schedule: iteration (q,s) runs at Timestamp[q, s], so the
# slower timestamp dimension t1 walks outputs and the faster t0 walks weights.
os_schedule_unbounded = isl.Map(f'{{ IterationSpace[q,s] -> Timestamp[t1,t0] : t1=q and t0=s }}')

# This is an alternative way of specifying a timestamp, i.e., by flattening it.
# It will probably not work here, because portions of the notebook assume
# timestamps have two dimensions.
# NOTE(review): s*Q+q orders by s first, i.e. it flattens the weight-stationary
# order below; the flattened output-stationary order would be q*S+s.
#os_schedule_alt_unbounded = isl.Map(f'{{ IterationSpace[q,s] -> Timestamp[t] : t = s*{Q}+q }}')

# This is an alternative (weight-stationary) dataflow which can be assigned to df_is2ts_unbounded
#ws_schedule_unbounded = isl.Map(f'{{ IterationSpace[q,s] -> Timestamp[t1,t0] : t1=s and t0=q }}')

# The dataflow analysed in the rest of the notebook.
df_is2ts_unbounded = os_schedule_unbounded

print(f"{ df_is2ts_unbounded = }")

### Bounded schedule

Because we only care about a schedule for the actual points in the iteration space, we bound the schedule to the actual iteration space.

In [None]:
# Restrict the schedule to the actual points of the iteration space.
df_is2ts = df_is2ts_unbounded.intersect_domain(ispace)

# The selected dataflow (df_is2ts_unbounded = os_schedule_unbounded, i.e.
# t1=q, t0=s) is output stationary; update this title if a different
# schedule is selected above.
printMap("Output Stationary Schedule", df_is2ts)
plotMap(df_is2ts)
#plotSet(df_is2ts.domain())

### Define the set of timesteps

We define a convenience variable holding the (bounded) set of timestamps for the problem.

In [None]:
# All timestamps actually used by the bounded schedule.
timestamps = df_is2ts.range()

printSet('Timestamps', timestamps)
plotSet(timestamps)

### Timestamp order traversal

Because we actually mostly care about mapping from a timestamp to a point in the iteration space, we reverse the map.

In [None]:
# Tensor accesses at each time coordinate.
df_ts2is = df_is2ts.reverse() # T-relation.

printMap('Timestamp to Iteration Space', df_ts2is)

## Code Generation

In [None]:
# Code-generation
build = isl.AstBuild.alloc(isl.DEFAULT_CONTEXT)
df_is2ts_tree = build.node_from_schedule_map(df_is2ts)

printer = isl.Printer.to_str(isl.DEFAULT_CONTEXT).set_output_format(isl.format.C)
printer.print_str("// Weight-stationary schedule\n")
printer.print_ast_node(df_is2ts_tree)
printer.print_str("\n")

print(printer.get_str())

## Composing Maps

Some background on composing maps

### This cell illustrates the process of composing two maps

Note: Again the plot does not show the arrows between the points

In [None]:
# A row of 6 points on the line y = 0.
blue = isl.Set("{ BLUE[x,y] : 0 <= x < 6 and y = 0 }")
printSet("blue", blue)

# Shift each blue point by (+1, +1) into the RED space.
blue2red = isl.Map("{ BLUE[x,y] -> RED[x1,y1] : x1=x+1 and y1=y+1 }")
blue2red_bounded = blue2red.intersect_domain(blue)
print()
printMap('blue2red', blue2red_bounded)

# Shift each red point by (-1, +1) into the GREEN space, bounded to the red
# points reached from blue.
red2green = isl.Map("{ RED[x,y] -> GREEN[x1,y1] : x1=x-1 and y1=y+1 }")
red2green_bounded = red2green.intersect_domain(blue2red_bounded.range())
print()
printMap('red2green', red2green_bounded)

# Composition: apply blue2red, then red2green, i.e. BLUE[x,y] -> GREEN[x, y+2].
blue2green = blue2red.apply_range(red2green)

blue2green_bounded = blue2green.intersect_domain(blue)
print()
printMap('blue2green', blue2green_bounded)

# Colors follow argument order: blue points, red points, then green points.
plotSet(blue2red_bounded.domain(), blue2red_bounded.range(), blue2green_bounded.range())


## Calculating Accesses at a Timestamp

### Weight Data Accesses

Create a map from timestamps to weight accesses.

In [None]:
# Compose timestamp -> iteration with iteration -> weight. df_ts2is is
# bounded, so the result is finite even though is2weight is not.
df_ts2weight_current = df_ts2is.apply_range(is2weight)

printMap("Timestamp -> Weight Access", df_ts2weight_current)

### Input Data Accesses

Create a map from timestamps to input accesses

In [None]:
# Compose timestamp -> iteration with iteration -> input.
df_ts2input_current = df_ts2is.apply_range(is2input)

printMap("Timestamp -> Input Access", df_ts2input_current)

### Output Data Accesses

Create a map from timestamps to output accesses.

In [None]:
# Compose timestamp -> iteration with iteration -> output.
df_ts2output_current = df_ts2is.apply_range(is2output)

printMap("Timestamp -> Output Access", df_ts2output_current)

## Examples of Probing the Schedule

### Probing the schedule (1)

In [None]:
# What is the last timestamp in my schedule?
# lexmax() gives the lexicographically largest timestamp, i.e. the last one executed.
print('last timestamp =', df_is2ts.range().lexmax())
print()

### Probing the schedule (2)

In [None]:
# Which weight point is touched at timestamp[1,2]?
# First form: the set of weights accessed at that timestamp.
print(isl.Set('{ Timestamp[1,2] }').apply(df_ts2weight_current))
# Second form: the timestamp -> weight map restricted to that timestamp.
print(df_ts2weight_current.intersect_domain(isl.Set('{ Timestamp[1,2] }')))

## Calculating Reuse

### Previous timestamp map

Create a map from a timestamp to the previous timestamp

Todo: Show plot with arrows...

In [None]:
# Number of t0 values per t1 step. It is derived from the last timestamp of
# the schedule rather than hard-coded to S, so it stays correct when a
# different dataflow is selected (e.g. weight-stationary, where t0 ranges
# over q).
# NOTE: assumes a rectangular timestamp space (every t1 has the same t0
# range), which holds for the schedules defined in this notebook.
last_timestamp = timestamps.lexmax().sample_point()
T0_MAX = int(last_timestamp.get_coordinate_val(isl.dim_type.set, 1).to_python()) + 1

# Map each timestamp to its predecessor:
#   - within a t1 row, t0 steps back by one;
#   - the first timestamp of a row (t0'=0) maps to the last timestamp of the
#     previous row (t0=T0_MAX-1).
timestamp_previous = isl.Map(f"{{ Timestamp[t1',t0'] -> Timestamp[t1,t0] : t1'=t1 and t0'=t0+1 or t1'=t1+1 and t0'=0 and t0={T0_MAX}-1 }}")

# Restrict both ends to timestamps that actually occur.
timestamp_previous_bounded = timestamp_previous.intersect_range(timestamps)
timestamp_previous_bounded = timestamp_previous_bounded.intersect_domain(timestamps)
printMap("Prevous timestamp map", timestamp_previous_bounded)
plotMap(timestamp_previous_bounded)

### Weight Use in Previous Timestep

In [None]:
# For each timestamp, the weights accessed at the previous timestamp
# (timestamp -> previous timestamp, composed with previous timestamp -> weight).
df_ts2weight_previous = timestamp_previous.apply_range(df_ts2weight_current)

printMap("Previous Weights", df_ts2weight_previous)
plotMap(df_ts2weight_previous)

### Weight Deltas

Timesteps using a **new** weight.

In [None]:
# (timestamp, weight) accesses that were not also accessed at the previous timestamp.
df_ts2weight_delta = df_ts2weight_current.subtract(df_ts2weight_previous)

printMap("Weight Deltas", df_ts2weight_delta)

### Input Use in Previous Cycle

In [None]:
# For each timestamp, the inputs accessed at the previous timestamp.
df_ts2input_previous = timestamp_previous.apply_range(df_ts2input_current)
printMap("Previos cycle inputs", df_ts2input_previous)

### Input Delta
Timesteps using a **new** input.

In [None]:
# (timestamp, input) accesses that were not also accessed at the previous timestamp.
df_ts2input_delta = df_ts2input_current.subtract(df_ts2input_previous)
printMap("Input Deltas", df_ts2input_delta)

### Outputs used in previous cycle

In [None]:
# For each timestamp, the outputs accessed at the previous timestamp.
df_ts2output_previous = timestamp_previous.apply_range(df_ts2output_current)
printMap('Output previous', df_ts2output_previous)

### Output Deltas

Timesteps using a **new** output

In [None]:
# (timestamp, output) accesses that were not also accessed at the previous timestamp.
df_ts2output_delta = df_ts2output_current.subtract(df_ts2output_previous)
printMap("Output Deltas", df_ts2output_delta)

## L1 Level Buffering

### Tiling Timesteps

Create a partitioning of the timestamp space where a set of timestamps, i.e., a **tile**, maps to an L1 timestamp.

In [None]:
# Tile the timestamps: every Timestamp[t1, t0] with the same t1 belongs to one
# tile, identified by L1Timestamp[t1].
timestamp2L1timestamp = isl.Map('{ Timestamp[t1,t0] -> L1Timestamp[t1] }')

timestamp2L1timestamp_bounded = timestamp2L1timestamp.intersect_domain(timestamps)
# The set of L1 timestamps (tiles) that actually occur.
L1timestamps = timestamp2L1timestamp_bounded.range()
printMap('ts -> L1ts', timestamp2L1timestamp_bounded)

### Create the L1timestamp to timestamp map

In [None]:
# Reverse the tiling: each L1 timestamp (tile) maps to the timestamps it contains.
L1timestamp2timestamp = timestamp2L1timestamp.reverse()
L1timestamp2timestamp_bounded = timestamp2L1timestamp_bounded.reverse()

printMap('L1ts -> ts', L1timestamp2timestamp_bounded)
plotMap(timestamp2L1timestamp_bounded)
#plotSet(timestamp2L1timestamp_bounded.domain())

### L1 Level Previous

Create a map from each L1 timestamp to the previous L1 timestamp

In [None]:
# Map each L1 timestamp t' to its predecessor t = t' - 1.
L1timestamp_previous = isl.Map(f"{{ L1Timestamp[t'] -> L1Timestamp[t] : t'=t+1 }}")

# Bounded on both sides, so the first L1 timestamp has no predecessor here.
L1timestamp_previous_bounded = L1timestamp_previous.intersect_domain(L1timestamps).intersect_range(L1timestamps)
printMap('L1timeestamp_previous', L1timestamp_previous_bounded)

## Calculate L1 Level Deltas

Using the mappings created before, calculate the data accesses in the current and previous L1 timesteps and compute the deltas.

### L1 Level Deltas - Weights

Calculate L1 weight accesses in the current and previous L1 timestamps and the delta.

In [None]:
# Weights accessed anywhere within each L1 tile (union over its timestamps).
df_L1ts2weight_current = L1timestamp2timestamp.apply_range(df_ts2weight_current)

print('df_L1ts2weight_current =', df_L1ts2weight_current)
group_print_map(df_L1ts2weight_current)

# Weights accessed in the previous L1 tile, keyed by the current L1 timestamp.
df_L1ts2weight_previous = L1timestamp_previous.apply_range(df_L1ts2weight_current)
print('df_prevous')
group_print_map(df_L1ts2weight_previous)

# Weights accessed in the current tile that were not accessed in the previous tile.
df_L1ts2weight_delta = df_L1ts2weight_current.subtract(df_L1ts2weight_previous)

print('L1 weight tile delta =', df_L1ts2weight_delta)
group_print_map(df_L1ts2weight_delta)

### L1 Level Deltas - Inputs

Calculate L1 input accesses in the current and previous L1 timestamps and the delta.

In [None]:
# Inputs accessed anywhere within each L1 tile (union over its timestamps).
df_L1ts2input_current = L1timestamp2timestamp.apply_range(df_ts2input_current)

print('df_L1ts2input_current =', df_L1ts2input_current)
group_print_map(df_L1ts2input_current)

# Inputs accessed in the previous L1 tile, keyed by the current L1 timestamp.
df_L1ts2input_previous = L1timestamp_previous.apply_range(df_L1ts2input_current)

print('df_L1input_tile_previous')
group_print_map(df_L1ts2input_previous)

# Inputs accessed in the current tile that were not accessed in the previous tile.
df_L1ts2input_delta = df_L1ts2input_current.subtract(df_L1ts2input_previous)

print('L1 input tile delta =', df_L1ts2input_delta)
group_print_map(df_L1ts2input_delta)

### L1 Level Deltas - Outputs

Calculate L1 output accesses in the current and previous L1 timestamps and the delta.

In [None]:
# Outputs accessed anywhere within each L1 tile (union over its timestamps).
df_L1ts2output_current = L1timestamp2timestamp.apply_range(df_ts2output_current)

print('df_L1ts2output_current =', df_L1ts2output_current)
group_print_map(df_L1ts2output_current)

# Outputs accessed in the previous L1 tile, keyed by the current L1 timestamp.
df_L1ts2output_previous = L1timestamp_previous.apply_range(df_L1ts2output_current)
print('L1output tile previous')
group_print_map(df_L1ts2output_previous)

# Outputs accessed in the current tile that were not accessed in the previous tile.
df_L1ts2output_delta = df_L1ts2output_current.subtract(df_L1ts2output_previous)

print('L1 output tile delta =', df_L1ts2output_delta)
group_print_map(df_L1ts2output_delta)

## Calculate L1 shrinks

## Calculate next L1 timestamp

In [None]:
# Successor map: reversing t' -> t'-1 gives t -> t+1.
L1timestamp_next = L1timestamp_previous.reverse()

L1timestamp_next_bounded = L1timestamp_next.intersect_domain(L1timestamps).intersect_range(L1timestamps)
printMap('L1timestep_next', L1timestamp_next_bounded)

# HACK
# Use the bounded successor map from here on. NOTE(review): presumably this is
# so that composing it with the L1 access maps below does not also produce an
# entry for the L1 timestamp one before the first real one.
L1timestamp_next = L1timestamp_next_bounded

## Shrink L1 Weights

Calculate L1 weight accesses in the current and next L1 timestamps and the shrink.

In [None]:
# Weights accessed in the next L1 tile, keyed by the current L1 timestamp.
df_L1ts2weight_next = L1timestamp_next.apply_range(df_L1ts2weight_current)

print('L1 weight tile next = ', df_L1ts2weight_next)
print()
group_print_map(df_L1ts2weight_next)

# Weights accessed in the current tile but not in the next one.
df_L1ts2weight_shrink = df_L1ts2weight_current.subtract(df_L1ts2weight_next)

print('L1 weight tile shrink =', df_L1ts2weight_shrink)
print()
group_print_map(df_L1ts2weight_shrink)

### Shrink L1 Inputs

Calculate L1 input accesses in the current and next L1 timestamps and the shrink.

In [None]:
# Inputs accessed in the next L1 tile, keyed by the current L1 timestamp.
df_L1ts2input_next = L1timestamp_next.apply_range(df_L1ts2input_current)

print('L1 input tile next =', df_L1ts2input_next)
print()
group_print_map(df_L1ts2input_next)

# Inputs accessed in the current tile but not in the next one.
df_L1ts2input_shrink = df_L1ts2input_current.subtract(df_L1ts2input_next)

print('L1 input tile shrink =', df_L1ts2input_shrink)
print()
group_print_map(df_L1ts2input_shrink)

## L1 Output Shrink

Calculate L1 output accesses in the current and next L1 timestamps and the shrink.

In [None]:
# Outputs accessed in the next L1 tile, keyed by the current L1 timestamp.
df_L1ts2output_next = L1timestamp_next.apply_range(df_L1ts2output_current)

print('L1 output tile next =', df_L1ts2output_next)
print()
group_print_map(df_L1ts2output_next)

# Outputs accessed in the current tile but not in the next one.
df_L1ts2output_shrink = df_L1ts2output_current.subtract(df_L1ts2output_next)

print('L1 output tile shrink =', df_L1ts2output_shrink)
print()
group_print_map(df_L1ts2output_shrink)