Incorrect calculations using a Cythonized class on Dask Distributed

I’ve been banging my head against this issue and would appreciate any help; I’m not sure where to go from here.

I’m using Dask to parallelize a least-cost-path calculation built on scikit-image’s MCP class.

The class is written in Cython, and since Dask expects intermediate results to be serializable, I have implemented a Wrapper that recreates the MCP instance during deserialization.

When I run the code without Dask, or with Dask’s single-threaded scheduler, it takes longer but the results come back fine.
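By “single-threaded scheduler” I mean Dask’s synchronous scheduler, i.e. something like this toy example, which runs everything in the calling process:

import dask

@dask.delayed
def add(a, b):
    return a + b

# scheduler="synchronous" executes the graph in-process, with no
# pickling or inter-worker communication involved; this is the mode
# that gives me correct results
print(dask.compute(add(1, 2), scheduler="synchronous"))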

However, when I switch to processes or threads (still on Dask Distributed), I get no errors, but a bunch of np.inf values show up in my results.

Furthermore, the results themselves are not consistent with what I get running on a single thread.
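To quantify the problem I use a small helper to count unreachable cells in the costs array that find_costs returns (diagnostic sketch only); this is how I noticed the extra infs in the distributed runs:

import numpy as np

def count_unreachable(costs: np.ndarray) -> int:
    # MCP.find_costs marks cells it never reached with np.inf
    return int(np.isinf(costs).sum())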

Adding relevant code snippets here:

# Create a client locally
from dask.distributed import Client, LocalCluster

if cluster_type == 'local':
    try:
        # Reuse an already-running scheduler if one is listening
        client = Client('127.0.0.1:8786')
    except OSError:
        cluster = LocalCluster(n_workers=8,
                               processes=True,
                               threads_per_worker=8,
                               scheduler_port=8786)
        client = Client(cluster)
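Once the client is up, a quick sanity check is to ask every worker process to build its own MCP (sketch only; client.run executes the function on each worker, and load_mcp is defined below):

# Each worker should print "...loading mcp..." and report 'MCP_Geometric'
print(client.run(lambda: type(load_mcp()).__name__))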
## Create wrapper for MCP
# Wraps the Cython MCP class so it can be pickled: only the factory
# function travels over the wire, and the MCP is rebuilt on unpickling.
class Wrapper(object):
    def __init__(self, get_mcp):
        self.get_mcp = get_mcp
        self.mcp = get_mcp()

    def __reduce__(self):
        # https://stackoverflow.com/questions/19855156/whats-the-exact-usage-of-reduce-in-pickler
        # When unpickled, the MCP instance is recreated via get_mcp()
        return (self.__class__, (self.get_mcp,))


import numpy as np
import rasterio
from skimage import graph

def load_mcp():
    print("...loading mcp...")
    inR = rasterio.open(friction_raster_path)
    inD = inR.read()[0, :, :]
    # Important to scale by the 30 m pixel size here in order to get correct measurements
    inD = np.array(inD, dtype=np.float128) * 30
    inD = np.array(inD, dtype=np.float32)
    inD = np.nan_to_num(inD)
    mcp = graph.MCP_Geometric(inD)
    return mcp
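To rule out bad input, the friction surface itself can be checked for NaNs/infs before it ever reaches MCP_Geometric (diagnostic sketch; same path and 30 m scaling as in load_mcp):

import numpy as np
import rasterio

with rasterio.open(friction_raster_path) as src:
    band = src.read(1).astype(np.float32) * 30
print("NaNs:", np.isnan(band).sum(), "infs:", np.isinf(band).sum())
print("range:", np.nanmin(band), "to", np.nanmax(band))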


# Init the wrapper for MCP
wrapper = Wrapper(load_mcp)
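A quick way to confirm that __reduce__ really rebuilds the MCP is a local pickle round-trip (sketch):

import pickle

# "...loading mcp..." should print a second time when the copy is rebuilt
wrapper_copy = pickle.loads(pickle.dumps(wrapper))
assert isinstance(wrapper_copy.mcp, graph.MCP_Geometric)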

# Only reload inR here to do the CRS check
inR = rasterio.open(friction_raster_path)
import pandas as pd

# Get costs from one origin to all destinations
def get_costs_for_origin(wrapper, origin_id: str, origin_coords: tuple, dests: pd.DataFrame):
    # TODO - dests should be a list of tuples only
    res = []
    origin_coords = [origin_coords]
    ends = dests.MCP_DESTS_COORDS.to_list()
    costs, traceback = wrapper.mcp.find_costs(starts=origin_coords, ends=ends)
    for idx, dest in enumerate(dests.to_dict(orient="records")):
        dest_coords = dest['MCP_DESTS_COORDS']
        tt = costs[dest_coords[0], dest_coords[1]]
        # find_costs returns np.inf for cells it never reached
        if tt > 9999999999:
            print(dest['id'])
            print(tt)
            raise ValueError("INF")
        res.append(
            {"d_id": dest['id'],
             "d_tt": tt}
        )

    return {"o_id": origin_id, "o_tfan": res}
import dask
from dask.distributed import wait

# Run on the distributed scheduler using processes
def run_async(wrapper: Wrapper, origins_d: pd.DataFrame, dests_d: pd.DataFrame):
    # Broadcast the wrapper to all workers
    wrapper = client.scatter(wrapper, broadcast=True)
    wait(wrapper)

    # Broadcast destinations to all workers
    dests_d = client.scatter(dests_d, broadcast=True)
    wait(dests_d)

    # https://docs.dask.org/en/latest/futures.html
    tasks = []
    for idx, origin in enumerate(origins_d.to_dict(orient="records")):
        print(f"Origin {idx} of {len(origins_d)}")
        task = dask.delayed(get_costs_for_origin)(
            wrapper=wrapper,
            origin_id=origin['id'],
            origin_coords=origin['MCP_DESTS_COORDS'],
            dests=dests_d)
        tasks.append(task)
    all_res_dsk = dask.compute(*tasks)
    return list(all_res_dsk)
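And this is roughly how I compare the two modes: the same delayed tasks computed once on the distributed client and once synchronously (sketch; assumes wrapper, origins_d, and dests_d exist as above):

import dask

tasks = [
    dask.delayed(get_costs_for_origin)(
        wrapper=wrapper,
        origin_id=o['id'],
        origin_coords=o['MCP_DESTS_COORDS'],
        dests=dests_d)
    for o in origins_d.to_dict(orient="records")
]
dist_res = dask.compute(*tasks)                           # uses the Client
sync_res = dask.compute(*tasks, scheduler='synchronous')  # in-process baseline
print(dist_res == sync_res)                               # False in my case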

I’m assuming it’s something with the MCP class, but I can’t figure out what is causing the infs.

Thanks in advance everyone!