Open3D (C++ API)  0.17.0
Loading...
Searching...
No Matches
InvertNeighborsListOpKernel.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9
10#include "../TensorFlowHelper.h"
11#include "tensorflow/core/framework/op.h"
12#include "tensorflow/core/framework/op_kernel.h"
13#include "tensorflow/core/lib/core/errors.h"
14
16// namespace for code that is common for all kernels
17namespace invert_neighbors_list_opkernel {
18
19// Base class with common code for the OpKernel implementations
20class InvertNeighborsListOpKernel : public tensorflow::OpKernel {
21public:
22 explicit InvertNeighborsListOpKernel(
23 tensorflow::OpKernelConstruction* construction)
24 : OpKernel(construction) {}
25
26 void Compute(tensorflow::OpKernelContext* context) override {
27 using namespace tensorflow;
28 static_assert(sizeof(int64) == sizeof(int64_t),
29 "int64 type is not compatible");
30
31 const Tensor& num_points_tensor = context->input(0);
32 OP_REQUIRES(context,
33 TensorShapeUtils::IsScalar(num_points_tensor.shape()),
34 errors::InvalidArgument(
35 "num_points must be scalar, got shape ",
36 num_points_tensor.shape().DebugString()));
37 const int64 num_points = num_points_tensor.scalar<int64>()();
38
39 const Tensor& inp_neighbors_index = context->input(1);
40
41 const Tensor& inp_neighbors_row_splits = context->input(2);
42
43 const Tensor& inp_neighbors_attributes = context->input(3);
44
45 // check input shapes
46 {
47 using namespace open3d::ml::op_util;
48 Dim num_neighbors("num_neighbors");
49
50 CHECK_SHAPE(context, inp_neighbors_index, num_neighbors);
51 CHECK_SHAPE_IGNORE_LAST_DIMS(context, inp_neighbors_attributes,
52 num_neighbors || 0);
53 CHECK_SHAPE(context, inp_neighbors_row_splits, Dim());
54 }
55
56 // compute the number of attributes for each neighbor
57 int num_attributes;
58 if (inp_neighbors_attributes.shape().dim_size(0) == 0) {
59 num_attributes = 0;
60 } else {
61 num_attributes = 1;
62 for (int i = 1; i < inp_neighbors_attributes.shape().dims(); ++i)
63 num_attributes *= inp_neighbors_attributes.shape().dim_size(i);
64 }
65
66 Tensor* neighbors_index = 0;
67 TensorShape neighbors_index_shape(inp_neighbors_index.shape());
68 OP_REQUIRES_OK(context,
69 context->allocate_output(0, neighbors_index_shape,
70 &neighbors_index));
71
72 Tensor* neighbors_row_splits = 0;
73 TensorShape neighbors_row_splits_shape({num_points + 1});
74 OP_REQUIRES_OK(context,
75 context->allocate_output(1, neighbors_row_splits_shape,
76 &neighbors_row_splits));
77
78 Tensor* neighbors_attributes = 0;
79 TensorShape neighbors_attributes_shape(
80 inp_neighbors_attributes.shape());
81 OP_REQUIRES_OK(context,
82 context->allocate_output(2, neighbors_attributes_shape,
83 &neighbors_attributes));
84
85 Kernel(context, inp_neighbors_index, inp_neighbors_row_splits,
86 inp_neighbors_attributes, num_attributes, *neighbors_index,
87 *neighbors_row_splits, *neighbors_attributes);
88 }
89
90 // Function with the device specific code
91 virtual void Kernel(tensorflow::OpKernelContext* context,
92 const tensorflow::Tensor& inp_neighbors_index,
93 const tensorflow::Tensor& inp_neighbors_row_splits,
94 const tensorflow::Tensor& inp_neighbors_attributes,
95 const int num_attributes,
96 tensorflow::Tensor& neighbors_index,
97 tensorflow::Tensor& neighbors_row_splits,
98 tensorflow::Tensor& neighbors_attributes) = 0;
99
100private:
101};
102
103} // namespace invert_neighbors_list_opkernel
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor,...)
Definition TorchHelper.h:225
#define CHECK_SHAPE(tensor,...)
Definition TorchHelper.h:186
ImGuiContext * context
Definition Window.cpp:76
Class for dimensions for which the value should be inferred.
Definition ShapeChecking.h:50
Definition ShapeChecking.h:16