aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorPaolo Abeni <pabeni@redhat.com>2025-09-18 12:32:31 +0200
committerPaolo Abeni <pabeni@redhat.com>2025-09-18 12:38:34 +0200
commit64d2616972b77506731fa0122d3c48cb04dbe21b (patch)
tree53c6473a9fb1fa4135233f4c5c18eb5c99c77048 /net
parentMerge branch 'eth-fbnic-add-devlink-health-support-for-fw-crashes-and-otp-mem... (diff)
parentnet/mlx5e: Implement PSP key_rotate operation (diff)
downloadlinux-64d2616972b77506731fa0122d3c48cb04dbe21b.tar.gz
linux-64d2616972b77506731fa0122d3c48cb04dbe21b.zip
Merge branch 'add-basic-psp-encryption-for-tcp-connections'
Daniel Zahka says: ================== add basic PSP encryption for TCP connections This is v13 of the PSP RFC [1] posted by Jakub Kicinski one year ago. General developments since v1 include a fork of packetdrill [2] with support for PSP added, as well as some test cases, and an implementation of PSP key exchange and connection upgrade [3] integrated into the fbthrift RPC library. Both [2] and [3] have been tested on server platforms with PSP-capable CX7 NICs. Below is the cover letter from the original RFC: Add support for PSP encryption of TCP connections. PSP is a protocol out of Google: https://github.com/google/psp/blob/main/doc/PSP_Arch_Spec.pdf which shares some similarities with IPsec. I added some more info in the first patch so I'll keep it short here. The protocol can work in multiple modes including tunneling. But I'm mostly interested in using it as TLS replacement because of its superior offload characteristics. So this patch does three things: - it adds "core" PSP code PSP is offload-centric, and requires some additional care and feeding, so first chunk of the code exposes device info. This part can be reused by PSP implementations in xfrm, tunneling etc. - TCP integration TLS style Reuse some of the existing concepts from TLS offload, such as attaching crypto state to a socket, marking skbs as "decrypted", egress validation. PSP does not prescribe key exchange protocols. To use PSP as a more efficient TLS offload we intend to perform a TLS handshake ("inline" in the same TCP connection) and negotiate switching to PSP based on capabilities of both endpoints. This is also why I'm not including a software implementation. Nobody would use it in production, software TLS is faster, it has larger crypto records. - mlx5 implementation That's mostly other people's work, not 100% sure those folks consider it ready hence the RFC in the title. But it works :) Not posted, queued a branch [4] are follow up pieces: - standard stats - netdevsim implementation and tests [1] https://lore.kernel.org/netdev/20240510030435.120935-1-kuba@kernel.org/ [2] https://github.com/danieldzahka/packetdrill [3] https://github.com/danieldzahka/fbthrift/tree/dzahka/psp [4] https://github.com/kuba-moo/linux/tree/psp Comments we intend to defer to future series: - we prefer to keep the version field in the tx-assoc netlink request, because it makes parsing keys require less state early on, but we are willing to change in the next version of this series. - using a static branch to wrap psp_enqueue_set_decrypted() and other functions called from tcp. - using INDIRECT_CALL for tls/psp in sk_validate_xmit_skb(). We prefer to address this in a dedicated patch series, so that this series does not need to modify the way tls_validate_xmit_skb() is declared and stubbed out. v12: https://lore.kernel.org/netdev/20250916000559.1320151-1-kuba@kernel.org/ v11: https://lore.kernel.org/20250911014735.118695-1-daniel.zahka@gmail.com v10: https://lore.kernel.org/netdev/20250828162953.2707727-1-daniel.zahka@gmail.com/ v9: https://lore.kernel.org/netdev/20250827155340.2738246-1-daniel.zahka@gmail.com/ v8: https://lore.kernel.org/netdev/20250825200112.1750547-1-daniel.zahka@gmail.com/ v7: https://lore.kernel.org/netdev/20250820113120.992829-1-daniel.zahka@gmail.com/ v6: https://lore.kernel.org/netdev/20250812003009.2455540-1-daniel.zahka@gmail.com/ v5: https://lore.kernel.org/netdev/20250723203454.519540-1-daniel.zahka@gmail.com/ v4: https://lore.kernel.org/netdev/20250716144551.3646755-1-daniel.zahka@gmail.com/ v3: https://lore.kernel.org/netdev/20250702171326.3265825-1-daniel.zahka@gmail.com/ v2: https://lore.kernel.org/netdev/20250625135210.2975231-1-daniel.zahka@gmail.com/ v1: https://lore.kernel.org/netdev/20240510030435.120935-1-kuba@kernel.org/ ================== Links: https://patch.msgid.link/20250917000954.859376-1-daniel.zahka@gmail.com Signed-off-by: Paolo Abeni <pabeni@redhat.com> --- * add-basic-psp-encryption-for-tcp-connections: net/mlx5e: Implement PSP key_rotate operation net/mlx5e: Add Rx data path offload psp: provide decapsulation and receive helper for drivers net/mlx5e: Configure PSP Rx flow steering rules net/mlx5e: Add PSP steering in local NIC RX net/mlx5e: Implement PSP Tx data path psp: provide encapsulation helper for drivers net/mlx5e: Implement PSP operations .assoc_add and .assoc_del net/mlx5e: Support PSP offload functionality psp: track generations of device key net: psp: update the TCP MSS to reflect PSP packet overhead net: psp: add socket security association code net: tcp: allow tcp_timewait_sock to validate skbs before handing to device net: move sk_validate_xmit_skb() to net/core/dev.c psp: add op for rotation of device key tcp: add datapath logic for PSP with inline key exchange net: modify core data structures for PSP datapath support psp: base PSP device support psp: add documentation
Diffstat (limited to 'net')
-rw-r--r--net/Kconfig1
-rw-r--r--net/Makefile1
-rw-r--r--net/core/dev.c32
-rw-r--r--net/core/gro.c2
-rw-r--r--net/core/skbuff.c4
-rw-r--r--net/ipv4/af_inet.c2
-rw-r--r--net/ipv4/inet_timewait_sock.c5
-rw-r--r--net/ipv4/ip_output.c5
-rw-r--r--net/ipv4/tcp.c2
-rw-r--r--net/ipv4/tcp_ipv4.c18
-rw-r--r--net/ipv4/tcp_minisocks.c20
-rw-r--r--net/ipv4/tcp_output.c17
-rw-r--r--net/ipv6/ipv6_sockglue.c6
-rw-r--r--net/ipv6/tcp_ipv6.c17
-rw-r--r--net/psp/Kconfig15
-rw-r--r--net/psp/Makefile5
-rw-r--r--net/psp/psp-nl-gen.c119
-rw-r--r--net/psp/psp-nl-gen.h39
-rw-r--r--net/psp/psp.h54
-rw-r--r--net/psp/psp_main.c321
-rw-r--r--net/psp/psp_nl.c505
-rw-r--r--net/psp/psp_sock.c295
22 files changed, 1471 insertions, 14 deletions
diff --git a/net/Kconfig b/net/Kconfig
index d5865cf19799..4b563aea4c23 100644
--- a/net/Kconfig
+++ b/net/Kconfig
@@ -82,6 +82,7 @@ config NET_CRC32C
menu "Networking options"
source "net/packet/Kconfig"
+source "net/psp/Kconfig"
source "net/unix/Kconfig"
source "net/tls/Kconfig"
source "net/xfrm/Kconfig"
diff --git a/net/Makefile b/net/Makefile
index aac960c41db6..90e3d72bf58b 100644
--- a/net/Makefile
+++ b/net/Makefile
@@ -18,6 +18,7 @@ obj-$(CONFIG_INET) += ipv4/
obj-$(CONFIG_TLS) += tls/
obj-$(CONFIG_XFRM) += xfrm/
obj-$(CONFIG_UNIX) += unix/
+obj-$(CONFIG_INET_PSP) += psp/
obj-y += ipv6/
obj-$(CONFIG_PACKET) += packet/
obj-$(CONFIG_NET_KEY) += key/
diff --git a/net/core/dev.c b/net/core/dev.c
index 2522d9d8f0e4..5e22d062bac5 100644
--- a/net/core/dev.c
+++ b/net/core/dev.c
@@ -3907,6 +3907,38 @@ sw_checksum:
}
EXPORT_SYMBOL(skb_csum_hwoffload_help);
+/* Checks if this SKB belongs to an HW offloaded socket
+ * and whether any SW fallbacks are required based on dev.
+ * Check decrypted mark in case skb_orphan() cleared socket.
+ */
+static struct sk_buff *sk_validate_xmit_skb(struct sk_buff *skb,
+ struct net_device *dev)
+{
+#ifdef CONFIG_SOCK_VALIDATE_XMIT
+ struct sk_buff *(*sk_validate)(struct sock *sk, struct net_device *dev,
+ struct sk_buff *skb);
+ struct sock *sk = skb->sk;
+
+ sk_validate = NULL;
+ if (sk) {
+ if (sk_fullsock(sk))
+ sk_validate = sk->sk_validate_xmit_skb;
+ else if (sk_is_inet(sk) && sk->sk_state == TCP_TIME_WAIT)
+ sk_validate = inet_twsk(sk)->tw_validate_xmit_skb;
+ }
+
+ if (sk_validate) {
+ skb = sk_validate(sk, dev, skb);
+ } else if (unlikely(skb_is_decrypted(skb))) {
+ pr_warn_ratelimited("unencrypted skb with no associated socket - dropping\n");
+ kfree_skb(skb);
+ skb = NULL;
+ }
+#endif
+
+ return skb;
+}
+
static struct sk_buff *validate_xmit_unreadable_skb(struct sk_buff *skb,
struct net_device *dev)
{
diff --git a/net/core/gro.c b/net/core/gro.c
index b350e5b69549..5ba4504cfd28 100644
--- a/net/core/gro.c
+++ b/net/core/gro.c
@@ -1,4 +1,5 @@
// SPDX-License-Identifier: GPL-2.0-or-later
+#include <net/psp.h>
#include <net/gro.h>
#include <net/dst_metadata.h>
#include <net/busy_poll.h>
@@ -376,6 +377,7 @@ static void gro_list_prepare(const struct list_head *head,
diffs |= skb_get_nfct(p) ^ skb_get_nfct(skb);
diffs |= gro_list_prepare_tc_ext(skb, p, diffs);
+ diffs |= __psp_skb_coalesce_diff(skb, p, diffs);
}
NAPI_GRO_CB(p)->same_flow = !diffs;
diff --git a/net/core/skbuff.c b/net/core/skbuff.c
index 23b776cd9879..d331e607edfb 100644
--- a/net/core/skbuff.c
+++ b/net/core/skbuff.c
@@ -79,6 +79,7 @@
#include <net/mptcp.h>
#include <net/mctp.h>
#include <net/page_pool/helpers.h>
+#include <net/psp/types.h>
#include <net/dropreason.h>
#include <linux/uaccess.h>
@@ -5062,6 +5063,9 @@ static const u8 skb_ext_type_len[] = {
#if IS_ENABLED(CONFIG_MCTP_FLOWS)
[SKB_EXT_MCTP] = SKB_EXT_CHUNKSIZEOF(struct mctp_flow),
#endif
+#if IS_ENABLED(CONFIG_INET_PSP)
+ [SKB_EXT_PSP] = SKB_EXT_CHUNKSIZEOF(struct psp_skb_ext),
+#endif
};
static __always_inline unsigned int skb_ext_total_length(void)
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 76e38092cd8a..e298dacb4a06 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -102,6 +102,7 @@
#include <net/gro.h>
#include <net/gso.h>
#include <net/tcp.h>
+#include <net/psp.h>
#include <net/udp.h>
#include <net/udplite.h>
#include <net/ping.h>
@@ -158,6 +159,7 @@ void inet_sock_destruct(struct sock *sk)
kfree(rcu_dereference_protected(inet->inet_opt, 1));
dst_release(rcu_dereference_protected(sk->sk_dst_cache, 1));
dst_release(rcu_dereference_protected(sk->sk_rx_dst, 1));
+ psp_sk_assoc_free(sk);
}
EXPORT_SYMBOL(inet_sock_destruct);
diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c
index 5b5426b8ee92..2ca2912f61f4 100644
--- a/net/ipv4/inet_timewait_sock.c
+++ b/net/ipv4/inet_timewait_sock.c
@@ -16,6 +16,7 @@
#include <net/inet_timewait_sock.h>
#include <net/ip.h>
#include <net/tcp.h>
+#include <net/psp.h>
/**
* inet_twsk_bind_unhash - unhash a timewait socket from bind hash
@@ -211,6 +212,9 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk,
atomic64_set(&tw->tw_cookie, atomic64_read(&sk->sk_cookie));
twsk_net_set(tw, sock_net(sk));
timer_setup(&tw->tw_timer, tw_timer_handler, 0);
+#ifdef CONFIG_SOCK_VALIDATE_XMIT
+ tw->tw_validate_xmit_skb = NULL;
+#endif
/*
* Because we use RCU lookups, we should not set tw_refcnt
* to a non null value before everything is setup for this
@@ -219,6 +223,7 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk,
refcount_set(&tw->tw_refcnt, 0);
__module_get(tw->tw_prot->owner);
+ psp_twsk_init(tw, sk);
}
return tw;
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index 2b96651d719b..5ca97ede979c 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -84,6 +84,7 @@
#include <linux/netfilter_bridge.h>
#include <linux/netlink.h>
#include <linux/tcp.h>
+#include <net/psp.h>
static int
ip_fragment(struct net *net, struct sock *sk, struct sk_buff *skb,
@@ -1665,8 +1666,10 @@ void ip_send_unicast_reply(struct sock *sk, const struct sock *orig_sk,
arg->csumoffset) = csum_fold(csum_add(nskb->csum,
arg->csum));
nskb->ip_summed = CHECKSUM_NONE;
- if (orig_sk)
+ if (orig_sk) {
skb_set_owner_edemux(nskb, (struct sock *)orig_sk);
+ psp_reply_set_decrypted(nskb);
+ }
if (transmit_time)
nskb->tstamp_type = SKB_CLOCK_MONOTONIC;
if (txhash)
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 5b5c655ded1d..d6d0d970e014 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -277,6 +277,7 @@
#include <net/proto_memory.h>
#include <net/xfrm.h>
#include <net/ip.h>
+#include <net/psp.h>
#include <net/sock.h>
#include <net/rstreason.h>
@@ -705,6 +706,7 @@ void tcp_skb_entail(struct sock *sk, struct sk_buff *skb)
tcb->seq = tcb->end_seq = tp->write_seq;
tcb->tcp_flags = TCPHDR_ACK;
__skb_header_release(skb);
+ psp_enqueue_set_decrypted(sk, skb);
tcp_add_write_queue_tail(sk, skb);
sk_wmem_queued_add(sk, skb->truesize);
sk_mem_charge(sk, skb->truesize);
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 6a63be1f6461..b1fcf3e4e1ce 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -75,6 +75,7 @@
#include <net/secure_seq.h>
#include <net/busy_poll.h>
#include <net/rstreason.h>
+#include <net/psp.h>
#include <linux/inet.h>
#include <linux/ipv6.h>
@@ -293,9 +294,9 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
inet->inet_dport = usin->sin_port;
sk_daddr_set(sk, daddr);
- inet_csk(sk)->icsk_ext_hdr_len = 0;
+ inet_csk(sk)->icsk_ext_hdr_len = psp_sk_overhead(sk);
if (inet_opt)
- inet_csk(sk)->icsk_ext_hdr_len = inet_opt->opt.optlen;
+ inet_csk(sk)->icsk_ext_hdr_len += inet_opt->opt.optlen;
tp->rx_opt.mss_clamp = TCP_MSS_DEFAULT;
@@ -1907,6 +1908,10 @@ int tcp_v4_do_rcv(struct sock *sk, struct sk_buff *skb)
enum skb_drop_reason reason;
struct sock *rsk;
+ reason = psp_sk_rx_policy_check(sk, skb);
+ if (reason)
+ goto err_discard;
+
if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */
struct dst_entry *dst;
@@ -1968,6 +1973,7 @@ csum_err:
reason = SKB_DROP_REASON_TCP_CSUM;
trace_tcp_bad_csum(skb);
TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS);
+err_discard:
TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS);
goto discard;
}
@@ -2069,7 +2075,9 @@ bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb,
(TCPHDR_ECE | TCPHDR_CWR | TCPHDR_AE)) ||
!tcp_skb_can_collapse_rx(tail, skb) ||
thtail->doff != th->doff ||
- memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th)))
+ memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th)) ||
+ /* prior to PSP Rx policy check, retain exact PSP metadata */
+ psp_skb_coalesce_diff(tail, skb))
goto no_coalesce;
__skb_pull(skb, hdrlen);
@@ -2437,6 +2445,10 @@ do_time_wait:
__this_cpu_write(tcp_tw_isn, isn);
goto process;
}
+
+ drop_reason = psp_twsk_rx_policy_check(inet_twsk(sk), skb);
+ if (drop_reason)
+ break;
}
/* to ACK */
fallthrough;
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index 327095ef95ef..2ec8c6f1cdcc 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -24,6 +24,7 @@
#include <net/xfrm.h>
#include <net/busy_poll.h>
#include <net/rstreason.h>
+#include <net/psp.h>
static bool tcp_in_window(u32 seq, u32 end_seq, u32 s_win, u32 e_win)
{
@@ -104,9 +105,16 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb,
struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw);
u32 rcv_nxt = READ_ONCE(tcptw->tw_rcv_nxt);
struct tcp_options_received tmp_opt;
+ enum skb_drop_reason psp_drop;
bool paws_reject = false;
int ts_recent_stamp;
+ /* Instead of dropping immediately, wait to see what value is
+ * returned. We will accept a non psp-encapsulated syn in the
+ * case where TCP_TW_SYN is returned.
+ */
+ psp_drop = psp_twsk_rx_policy_check(tw, skb);
+
tmp_opt.saw_tstamp = 0;
ts_recent_stamp = READ_ONCE(tcptw->tw_ts_recent_stamp);
if (th->doff > (sizeof(*th) >> 2) && ts_recent_stamp) {
@@ -124,6 +132,9 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb,
if (READ_ONCE(tw->tw_substate) == TCP_FIN_WAIT2) {
/* Just repeat all the checks of tcp_rcv_state_process() */
+ if (psp_drop)
+ goto out_put;
+
/* Out of window, send ACK */
if (paws_reject ||
!tcp_in_window(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq,
@@ -194,6 +205,9 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb,
(TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq || th->rst))) {
/* In window segment, it may be only reset or bare ack. */
+ if (psp_drop)
+ goto out_put;
+
if (th->rst) {
/* This is TIME_WAIT assassination, in two flavors.
* Oh well... nobody has a sufficient solution to this
@@ -247,6 +261,9 @@ kill:
return TCP_TW_SYN;
}
+ if (psp_drop)
+ goto out_put;
+
if (paws_reject) {
*drop_reason = SKB_DROP_REASON_TCP_RFC7323_TW_PAWS;
__NET_INC_STATS(twsk_net(tw), LINUX_MIB_PAWS_TW_REJECTED);
@@ -265,6 +282,8 @@ kill:
return tcp_timewait_check_oow_rate_limit(
tw, skb, LINUX_MIB_TCPACKSKIPPEDTIMEWAIT);
}
+
+out_put:
inet_twsk_put(tw);
return TCP_TW_SUCCESS;
}
@@ -392,6 +411,7 @@ void tcp_twsk_destructor(struct sock *sk)
}
#endif
tcp_ao_destroy_sock(sk, true);
+ psp_twsk_assoc_free(inet_twsk(sk));
}
void tcp_twsk_purge(struct list_head *net_exit_list)
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 388c45859469..223d7feeb19d 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -41,6 +41,7 @@
#include <net/tcp_ecn.h>
#include <net/mptcp.h>
#include <net/proto_memory.h>
+#include <net/psp.h>
#include <linux/compiler.h>
#include <linux/gfp.h>
@@ -358,13 +359,15 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb,
/* Constructs common control bits of non-data skb. If SYN/FIN is present,
* auto increment end seqno.
*/
-static void tcp_init_nondata_skb(struct sk_buff *skb, u32 seq, u16 flags)
+static void tcp_init_nondata_skb(struct sk_buff *skb, struct sock *sk,
+ u32 seq, u16 flags)
{
skb->ip_summed = CHECKSUM_PARTIAL;
TCP_SKB_CB(skb)->tcp_flags = flags;
tcp_skb_pcount_set(skb, 1);
+ psp_enqueue_set_decrypted(sk, skb);
TCP_SKB_CB(skb)->seq = seq;
if (flags & (TCPHDR_SYN | TCPHDR_FIN))
@@ -1656,6 +1659,7 @@ static void tcp_queue_skb(struct sock *sk, struct sk_buff *skb)
/* Advance write_seq and place onto the write_queue. */
WRITE_ONCE(tp->write_seq, TCP_SKB_CB(skb)->end_seq);
__skb_header_release(skb);
+ psp_enqueue_set_decrypted(sk, skb);
tcp_add_write_queue_tail(sk, skb);
sk_wmem_queued_add(sk, skb->truesize);
sk_mem_charge(sk, skb->truesize);
@@ -3778,7 +3782,7 @@ void tcp_send_fin(struct sock *sk)
skb_reserve(skb, MAX_TCP_HEADER);
sk_forced_mem_schedule(sk, skb->truesize);
/* FIN eats a sequence byte, write_seq advanced by tcp_queue_skb(). */
- tcp_init_nondata_skb(skb, tp->write_seq,
+ tcp_init_nondata_skb(skb, sk, tp->write_seq,
TCPHDR_ACK | TCPHDR_FIN);
tcp_queue_skb(sk, skb);
}
@@ -3806,7 +3810,7 @@ void tcp_send_active_reset(struct sock *sk, gfp_t priority,
/* Reserve space for headers and prepare control bits. */
skb_reserve(skb, MAX_TCP_HEADER);
- tcp_init_nondata_skb(skb, tcp_acceptable_seq(sk),
+ tcp_init_nondata_skb(skb, sk, tcp_acceptable_seq(sk),
TCPHDR_ACK | TCPHDR_RST);
tcp_mstamp_refresh(tcp_sk(sk));
/* Send it off. */
@@ -4303,7 +4307,7 @@ int tcp_connect(struct sock *sk)
/* SYN eats a sequence byte, write_seq updated by
* tcp_connect_queue_skb().
*/
- tcp_init_nondata_skb(buff, tp->write_seq, TCPHDR_SYN);
+ tcp_init_nondata_skb(buff, sk, tp->write_seq, TCPHDR_SYN);
tcp_mstamp_refresh(tp);
tp->retrans_stamp = tcp_time_stamp_ts(tp);
tcp_connect_queue_skb(sk, buff);
@@ -4428,7 +4432,8 @@ void __tcp_send_ack(struct sock *sk, u32 rcv_nxt, u16 flags)
/* Reserve space for headers and prepare control bits. */
skb_reserve(buff, MAX_TCP_HEADER);
- tcp_init_nondata_skb(buff, tcp_acceptable_seq(sk), TCPHDR_ACK | flags);
+ tcp_init_nondata_skb(buff, sk,
+ tcp_acceptable_seq(sk), TCPHDR_ACK | flags);
/* We do not want pure acks influencing TCP Small Queues or fq/pacing
* too much.
@@ -4474,7 +4479,7 @@ static int tcp_xmit_probe_skb(struct sock *sk, int urgent, int mib)
* end to send an ack. Don't queue or clone SKB, just
* send it.
*/
- tcp_init_nondata_skb(skb, tp->snd_una - !urgent, TCPHDR_ACK);
+ tcp_init_nondata_skb(skb, sk, tp->snd_una - !urgent, TCPHDR_ACK);
NET_INC_STATS(sock_net(sk), mib);
return tcp_transmit_skb(sk, skb, 0, (__force gfp_t)0);
}
diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c
index e66ec623972e..a61e742794f9 100644
--- a/net/ipv6/ipv6_sockglue.c
+++ b/net/ipv6/ipv6_sockglue.c
@@ -49,6 +49,7 @@
#include <net/xfrm.h>
#include <net/compat.h>
#include <net/seg6.h>
+#include <net/psp.h>
#include <linux/uaccess.h>
@@ -107,7 +108,10 @@ struct ipv6_txoptions *ipv6_update_options(struct sock *sk,
!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
inet_sk(sk)->inet_daddr != LOOPBACK4_IPV6) {
struct inet_connection_sock *icsk = inet_csk(sk);
- icsk->icsk_ext_hdr_len = opt->opt_flen + opt->opt_nflen;
+
+ icsk->icsk_ext_hdr_len =
+ psp_sk_overhead(sk) +
+ opt->opt_flen + opt->opt_nflen;
icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);
}
}
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index c7271f6359db..d1e5b2a186fb 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -62,6 +62,7 @@
#include <net/hotdata.h>
#include <net/busy_poll.h>
#include <net/rstreason.h>
+#include <net/psp.h>
#include <linux/proc_fs.h>
#include <linux/seq_file.h>
@@ -301,10 +302,10 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
sk->sk_gso_type = SKB_GSO_TCPV6;
ip6_dst_store(sk, dst, false, false);
- icsk->icsk_ext_hdr_len = 0;
+ icsk->icsk_ext_hdr_len = psp_sk_overhead(sk);
if (opt)
- icsk->icsk_ext_hdr_len = opt->opt_flen +
- opt->opt_nflen;
+ icsk->icsk_ext_hdr_len += opt->opt_flen +
+ opt->opt_nflen;
tp->rx_opt.mss_clamp = IPV6_MIN_MTU - sizeof(struct tcphdr) - sizeof(struct ipv6hdr);
@@ -973,6 +974,7 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32
if (sk) {
/* unconstify the socket only to attach it to buff with care. */
skb_set_owner_edemux(buff, (struct sock *)sk);
+ psp_reply_set_decrypted(buff);
if (sk->sk_state == TCP_TIME_WAIT)
mark = inet_twsk(sk)->tw_mark;
@@ -1605,6 +1607,10 @@ int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb)
if (skb->protocol == htons(ETH_P_IP))
return tcp_v4_do_rcv(sk, skb);
+ reason = psp_sk_rx_policy_check(sk, skb);
+ if (reason)
+ goto err_discard;
+
/*
* socket locking is here for SMP purposes as backlog rcv
* is currently called with bh processing disabled.
@@ -1684,6 +1690,7 @@ csum_err:
reason = SKB_DROP_REASON_TCP_CSUM;
trace_tcp_bad_csum(skb);
TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS);
+err_discard:
TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS);
goto discard;
@@ -1988,6 +1995,10 @@ do_time_wait:
__this_cpu_write(tcp_tw_isn, isn);
goto process;
}
+
+ drop_reason = psp_twsk_rx_policy_check(inet_twsk(sk), skb);
+ if (drop_reason)
+ break;
}
/* to ACK */
fallthrough;
diff --git a/net/psp/Kconfig b/net/psp/Kconfig
new file mode 100644
index 000000000000..a7d24691a7e1
--- /dev/null
+++ b/net/psp/Kconfig
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# PSP configuration
+#
+config INET_PSP
+ bool "PSP Security Protocol support"
+ depends on INET
+ select SKB_DECRYPTED
+ select SOCK_VALIDATE_XMIT
+ help
+ Enable kernel support for the PSP protocol.
+ For more information see:
+ https://raw.githubusercontent.com/google/psp/main/doc/PSP_Arch_Spec.pdf
+
+ If unsure, say N.
diff --git a/net/psp/Makefile b/net/psp/Makefile
new file mode 100644
index 000000000000..eb5ff3c5bfb2
--- /dev/null
+++ b/net/psp/Makefile
@@ -0,0 +1,5 @@
+# SPDX-License-Identifier: GPL-2.0-only
+
+obj-$(CONFIG_INET_PSP) += psp.o
+
+psp-y := psp_main.o psp_nl.o psp_sock.o psp-nl-gen.o
diff --git a/net/psp/psp-nl-gen.c b/net/psp/psp-nl-gen.c
new file mode 100644
index 000000000000..9fdd6f831803
--- /dev/null
+++ b/net/psp/psp-nl-gen.c
@@ -0,0 +1,119 @@
+// SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/psp.yaml */
+/* YNL-GEN kernel source */
+
+#include <net/netlink.h>
+#include <net/genetlink.h>
+
+#include "psp-nl-gen.h"
+
+#include <uapi/linux/psp.h>
+
+/* Common nested types */
+const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1] = {
+ [PSP_A_KEYS_KEY] = { .type = NLA_BINARY, },
+ [PSP_A_KEYS_SPI] = { .type = NLA_U32, },
+};
+
+/* PSP_CMD_DEV_GET - do */
+static const struct nla_policy psp_dev_get_nl_policy[PSP_A_DEV_ID + 1] = {
+ [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+};
+
+/* PSP_CMD_DEV_SET - do */
+static const struct nla_policy psp_dev_set_nl_policy[PSP_A_DEV_PSP_VERSIONS_ENA + 1] = {
+ [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+ [PSP_A_DEV_PSP_VERSIONS_ENA] = NLA_POLICY_MASK(NLA_U32, 0xf),
+};
+
+/* PSP_CMD_KEY_ROTATE - do */
+static const struct nla_policy psp_key_rotate_nl_policy[PSP_A_DEV_ID + 1] = {
+ [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+};
+
+/* PSP_CMD_RX_ASSOC - do */
+static const struct nla_policy psp_rx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = {
+ [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+ [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3),
+ [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, },
+};
+
+/* PSP_CMD_TX_ASSOC - do */
+static const struct nla_policy psp_tx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = {
+ [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+ [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3),
+ [PSP_A_ASSOC_TX_KEY] = NLA_POLICY_NESTED(psp_keys_nl_policy),
+ [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, },
+};
+
+/* Ops table for psp */
+static const struct genl_split_ops psp_nl_ops[] = {
+ {
+ .cmd = PSP_CMD_DEV_GET,
+ .pre_doit = psp_device_get_locked,
+ .doit = psp_nl_dev_get_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_dev_get_nl_policy,
+ .maxattr = PSP_A_DEV_ID,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_DEV_GET,
+ .dumpit = psp_nl_dev_get_dumpit,
+ .flags = GENL_CMD_CAP_DUMP,
+ },
+ {
+ .cmd = PSP_CMD_DEV_SET,
+ .pre_doit = psp_device_get_locked,
+ .doit = psp_nl_dev_set_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_dev_set_nl_policy,
+ .maxattr = PSP_A_DEV_PSP_VERSIONS_ENA,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_KEY_ROTATE,
+ .pre_doit = psp_device_get_locked,
+ .doit = psp_nl_key_rotate_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_key_rotate_nl_policy,
+ .maxattr = PSP_A_DEV_ID,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_RX_ASSOC,
+ .pre_doit = psp_assoc_device_get_locked,
+ .doit = psp_nl_rx_assoc_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_rx_assoc_nl_policy,
+ .maxattr = PSP_A_ASSOC_SOCK_FD,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_TX_ASSOC,
+ .pre_doit = psp_assoc_device_get_locked,
+ .doit = psp_nl_tx_assoc_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_tx_assoc_nl_policy,
+ .maxattr = PSP_A_ASSOC_SOCK_FD,
+ .flags = GENL_CMD_CAP_DO,
+ },
+};
+
+static const struct genl_multicast_group psp_nl_mcgrps[] = {
+ [PSP_NLGRP_MGMT] = { "mgmt", },
+ [PSP_NLGRP_USE] = { "use", },
+};
+
+struct genl_family psp_nl_family __ro_after_init = {
+ .name = PSP_FAMILY_NAME,
+ .version = PSP_FAMILY_VERSION,
+ .netnsok = true,
+ .parallel_ops = true,
+ .module = THIS_MODULE,
+ .split_ops = psp_nl_ops,
+ .n_split_ops = ARRAY_SIZE(psp_nl_ops),
+ .mcgrps = psp_nl_mcgrps,
+ .n_mcgrps = ARRAY_SIZE(psp_nl_mcgrps),
+};
diff --git a/net/psp/psp-nl-gen.h b/net/psp/psp-nl-gen.h
new file mode 100644
index 000000000000..25268ed11fb5
--- /dev/null
+++ b/net/psp/psp-nl-gen.h
@@ -0,0 +1,39 @@
+/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/psp.yaml */
+/* YNL-GEN kernel header */
+
+#ifndef _LINUX_PSP_GEN_H
+#define _LINUX_PSP_GEN_H
+
+#include <net/netlink.h>
+#include <net/genetlink.h>
+
+#include <uapi/linux/psp.h>
+
+/* Common nested types */
+extern const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1];
+
+int psp_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info);
+int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info);
+void
+psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
+ struct genl_info *info);
+
+int psp_nl_dev_get_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_dev_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb);
+int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info);
+
+enum {
+ PSP_NLGRP_MGMT,
+ PSP_NLGRP_USE,
+};
+
+extern struct genl_family psp_nl_family;
+
+#endif /* _LINUX_PSP_GEN_H */
diff --git a/net/psp/psp.h b/net/psp/psp.h
new file mode 100644
index 000000000000..0f34e1a23fdd
--- /dev/null
+++ b/net/psp/psp.h
@@ -0,0 +1,54 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+
+#ifndef __PSP_PSP_H
+#define __PSP_PSP_H
+
+#include <linux/list.h>
+#include <linux/lockdep.h>
+#include <linux/mutex.h>
+#include <net/netns/generic.h>
+#include <net/psp.h>
+#include <net/sock.h>
+
+extern struct xarray psp_devs;
+extern struct mutex psp_devs_lock;
+
+void psp_dev_destroy(struct psp_dev *psd);
+int psp_dev_check_access(struct psp_dev *psd, struct net *net);
+
+void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd);
+
+struct psp_assoc *psp_assoc_create(struct psp_dev *psd);
+struct psp_dev *psp_dev_get_for_sock(struct sock *sk);
+void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas);
+int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
+ struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack);
+int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
+ u32 version, struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack);
+void psp_assocs_key_rotated(struct psp_dev *psd);
+
+static inline void psp_dev_get(struct psp_dev *psd)
+{
+ refcount_inc(&psd->refcnt);
+}
+
+static inline bool psp_dev_tryget(struct psp_dev *psd)
+{
+ return refcount_inc_not_zero(&psd->refcnt);
+}
+
+static inline void psp_dev_put(struct psp_dev *psd)
+{
+ if (refcount_dec_and_test(&psd->refcnt))
+ psp_dev_destroy(psd);
+}
+
+static inline bool psp_dev_is_registered(struct psp_dev *psd)
+{
+ lockdep_assert_held(&psd->lock);
+ return !!psd->ops;
+}
+
+#endif /* __PSP_PSP_H */
diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c
new file mode 100644
index 000000000000..b4b756f87382
--- /dev/null
+++ b/net/psp/psp_main.c
@@ -0,0 +1,321 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/bitfield.h>
+#include <linux/list.h>
+#include <linux/netdevice.h>
+#include <linux/xarray.h>
+#include <net/net_namespace.h>
+#include <net/psp.h>
+#include <net/udp.h>
+
+#include "psp.h"
+#include "psp-nl-gen.h"
+
+DEFINE_XARRAY_ALLOC1(psp_devs);
+struct mutex psp_devs_lock;
+
+/**
+ * DOC: PSP locking
+ *
+ * psp_devs_lock protects the psp_devs xarray.
+ * Ordering is take the psp_devs_lock and then the instance lock.
+ * Each instance is protected by RCU, and has a refcount.
+ * When driver unregisters the instance gets flushed, but struct sticks around.
+ */
+
+/**
+ * psp_dev_check_access() - check if user in a given net ns can access PSP dev
+ * @psd: PSP device structure user is trying to access
+ * @net: net namespace user is in
+ *
+ * Return: 0 if PSP device should be visible in @net, errno otherwise.
+ */
+int psp_dev_check_access(struct psp_dev *psd, struct net *net)
+{
+ if (dev_net(psd->main_netdev) == net)
+ return 0;
+ return -ENOENT;
+}
+
+/**
+ * psp_dev_create() - create and register PSP device
+ * @netdev: main netdevice
+ * @psd_ops: driver callbacks
+ * @psd_caps: device capabilities
+ * @priv_ptr: back-pointer to driver private data
+ *
+ * Return: pointer to allocated PSP device, or ERR_PTR.
+ */
+struct psp_dev *
+psp_dev_create(struct net_device *netdev,
+ struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps,
+ void *priv_ptr)
+{
+ struct psp_dev *psd;
+ static u32 last_id;
+ int err;
+
+ if (WARN_ON(!psd_caps->versions ||
+ !psd_ops->set_config ||
+ !psd_ops->key_rotate ||
+ !psd_ops->rx_spi_alloc ||
+ !psd_ops->tx_key_add ||
+ !psd_ops->tx_key_del))
+ return ERR_PTR(-EINVAL);
+
+ psd = kzalloc(sizeof(*psd), GFP_KERNEL);
+ if (!psd)
+ return ERR_PTR(-ENOMEM);
+
+ psd->main_netdev = netdev;
+ psd->ops = psd_ops;
+ psd->caps = psd_caps;
+ psd->drv_priv = priv_ptr;
+
+ mutex_init(&psd->lock);
+ INIT_LIST_HEAD(&psd->active_assocs);
+ INIT_LIST_HEAD(&psd->prev_assocs);
+ INIT_LIST_HEAD(&psd->stale_assocs);
+ refcount_set(&psd->refcnt, 1);
+
+ mutex_lock(&psp_devs_lock);
+ err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b,
+ &last_id, GFP_KERNEL);
+ if (err) {
+ mutex_unlock(&psp_devs_lock);
+ kfree(psd);
+ return ERR_PTR(err);
+ }
+ mutex_lock(&psd->lock);
+ mutex_unlock(&psp_devs_lock);
+
+ psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF);
+
+ rcu_assign_pointer(netdev->psp_dev, psd);
+
+ mutex_unlock(&psd->lock);
+
+ return psd;
+}
+EXPORT_SYMBOL(psp_dev_create);
+
+void psp_dev_destroy(struct psp_dev *psd)
+{
+ mutex_lock(&psp_devs_lock);
+ xa_erase(&psp_devs, psd->id);
+ mutex_unlock(&psp_devs_lock);
+
+ mutex_destroy(&psd->lock);
+ kfree_rcu(psd, rcu);
+}
+
+/**
+ * psp_dev_unregister() - unregister PSP device
+ * @psd: PSP device structure
+ */
+void psp_dev_unregister(struct psp_dev *psd)
+{
+ struct psp_assoc *pas, *next;
+
+ mutex_lock(&psp_devs_lock);
+ mutex_lock(&psd->lock);
+
+ psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF);
+
+ /* Wait until psp_dev_destroy() to call xa_erase() to prevent a
+ * different psd from being added to the xarray with this id, while
+ * there are still references to this psd being held.
+ */
+ xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
+ mutex_unlock(&psp_devs_lock);
+
+ list_splice_init(&psd->active_assocs, &psd->prev_assocs);
+ list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
+ list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list)
+ psp_dev_tx_key_del(psd, pas);
+
+ rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);
+
+ psd->ops = NULL;
+ psd->drv_priv = NULL;
+
+ mutex_unlock(&psd->lock);
+
+ psp_dev_put(psd);
+}
+EXPORT_SYMBOL(psp_dev_unregister);
+
+unsigned int psp_key_size(u32 version)
+{
+ switch (version) {
+ case PSP_VERSION_HDR0_AES_GCM_128:
+ case PSP_VERSION_HDR0_AES_GMAC_128:
+ return 16;
+ case PSP_VERSION_HDR0_AES_GCM_256:
+ case PSP_VERSION_HDR0_AES_GMAC_256:
+ return 32;
+ default:
+ return 0;
+ }
+}
+EXPORT_SYMBOL(psp_key_size);
+
+static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi,
+ u8 ver, unsigned int udp_len, __be16 sport)
+{
+ struct udphdr *uh = udp_hdr(skb);
+ struct psphdr *psph = (struct psphdr *)(uh + 1);
+
+ uh->dest = htons(PSP_DEFAULT_UDP_PORT);
+ uh->source = udp_flow_src_port(net, skb, 0, 0, false);
+ uh->check = 0;
+ uh->len = htons(udp_len);
+
+ psph->nexthdr = IPPROTO_TCP;
+ psph->hdrlen = PSP_HDRLEN_NOOPT;
+ psph->crypt_offset = 0;
+ psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) |
+ FIELD_PREP(PSPHDR_VERFL_ONE, 1);
+ psph->spi = spi;
+ memset(&psph->iv, 0, sizeof(psph->iv));
+}
+
+/* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling
+ * them in.
+ */
+bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
+ u8 ver, __be16 sport)
+{
+ u32 network_len = skb_network_header_len(skb);
+ u32 ethr_len = skb_mac_header_len(skb);
+ u32 bufflen = ethr_len + network_len;
+
+ if (skb_cow_head(skb, PSP_ENCAP_HLEN))
+ return false;
+
+ skb_push(skb, PSP_ENCAP_HLEN);
+ skb->mac_header -= PSP_ENCAP_HLEN;
+ skb->network_header -= PSP_ENCAP_HLEN;
+ skb->transport_header -= PSP_ENCAP_HLEN;
+ memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen);
+
+ if (skb->protocol == htons(ETH_P_IP)) {
+ ip_hdr(skb)->protocol = IPPROTO_UDP;
+ be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN);
+ ip_hdr(skb)->check = 0;
+ ip_hdr(skb)->check =
+ ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl);
+ } else if (skb->protocol == htons(ETH_P_IPV6)) {
+ ipv6_hdr(skb)->nexthdr = IPPROTO_UDP;
+ be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN);
+ } else {
+ return false;
+ }
+
+ skb_set_inner_ipproto(skb, IPPROTO_TCP);
+ skb_set_inner_transport_header(skb, skb_transport_offset(skb) +
+ PSP_ENCAP_HLEN);
+ skb->encapsulation = 1;
+ psp_write_headers(net, skb, spi, ver,
+ skb->len - skb_transport_offset(skb), sport);
+
+ return true;
+}
+EXPORT_SYMBOL(psp_dev_encapsulate);
+
+/* Receive handler for PSP packets.
+ *
+ * Presently it accepts only already-authenticated packets and does not
+ * support optional fields, such as virtualization cookies. The caller should
+ * ensure that skb->data is pointing to the mac header, and that skb->mac_len
+ * is set.
+ */
+int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
+{
+ int l2_hlen = 0, l3_hlen, encap;
+ struct psp_skb_ext *pse;
+ struct psphdr *psph;
+ struct ethhdr *eth;
+ struct udphdr *uh;
+ __be16 proto;
+ bool is_udp;
+
+ eth = (struct ethhdr *)skb->data;
+ proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
+ if (proto == htons(ETH_P_IP))
+ l3_hlen = sizeof(struct iphdr);
+ else if (proto == htons(ETH_P_IPV6))
+ l3_hlen = sizeof(struct ipv6hdr);
+ else
+ return -EINVAL;
+
+ if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
+ return -EINVAL;
+
+ if (proto == htons(ETH_P_IP)) {
+ struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
+
+ is_udp = iph->protocol == IPPROTO_UDP;
+ l3_hlen = iph->ihl * 4;
+ if (l3_hlen != sizeof(struct iphdr) &&
+ !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
+ return -EINVAL;
+ } else {
+ struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
+
+ is_udp = ipv6h->nexthdr == IPPROTO_UDP;
+ }
+
+ if (unlikely(!is_udp))
+ return -EINVAL;
+
+ uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
+ if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
+ return -EINVAL;
+
+ pse = skb_ext_add(skb, SKB_EXT_PSP);
+ if (!pse)
+ return -EINVAL;
+
+ psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
+ sizeof(struct udphdr));
+ pse->spi = psph->spi;
+ pse->dev_id = dev_id;
+ pse->generation = generation;
+ pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);
+
+ encap = PSP_ENCAP_HLEN;
+ encap += strip_icv ? PSP_TRL_SIZE : 0;
+
+ if (proto == htons(ETH_P_IP)) {
+ struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
+
+ iph->protocol = psph->nexthdr;
+ iph->tot_len = htons(ntohs(iph->tot_len) - encap);
+ iph->check = 0;
+ iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
+ } else {
+ struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
+
+ ipv6h->nexthdr = psph->nexthdr;
+ ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
+ }
+
+ memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen);
+ skb_pull(skb, PSP_ENCAP_HLEN);
+
+ if (strip_icv)
+ pskb_trim(skb, skb->len - PSP_TRL_SIZE);
+
+ return 0;
+}
+EXPORT_SYMBOL(psp_dev_rcv);
+
+static int __init psp_init(void)
+{
+ mutex_init(&psp_devs_lock);
+
+ return genl_register_family(&psp_nl_family);
+}
+
+subsys_initcall(psp_init);
diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c
new file mode 100644
index 000000000000..8aaca62744c3
--- /dev/null
+++ b/net/psp/psp_nl.c
@@ -0,0 +1,505 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/skbuff.h>
+#include <linux/xarray.h>
+#include <net/genetlink.h>
+#include <net/psp.h>
+#include <net/sock.h>
+
+#include "psp-nl-gen.h"
+#include "psp.h"
+
+/* Netlink helpers */
+
+static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
+{
+ struct sk_buff *rsp;
+ void *hdr;
+
+ rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!rsp)
+ return NULL;
+
+ hdr = genlmsg_iput(rsp, info);
+ if (!hdr) {
+ nlmsg_free(rsp);
+ return NULL;
+ }
+
+ return rsp;
+}
+
+static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
+{
+ /* Note that this *only* works with a single message per skb! */
+ nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
+
+ return genlmsg_reply(rsp, info);
+}
+
+/* Device stuff */
+
+static struct psp_dev *
+psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
+{
+ struct psp_dev *psd;
+ int err;
+
+ mutex_lock(&psp_devs_lock);
+ psd = xa_load(&psp_devs, nla_get_u32(dev_id));
+ if (!psd) {
+ mutex_unlock(&psp_devs_lock);
+ return ERR_PTR(-ENODEV);
+ }
+
+ mutex_lock(&psd->lock);
+ mutex_unlock(&psp_devs_lock);
+
+ err = psp_dev_check_access(psd, net);
+ if (err) {
+ mutex_unlock(&psd->lock);
+ return ERR_PTR(err);
+ }
+
+ return psd;
+}
+
+int psp_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info)
+{
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
+ return -EINVAL;
+
+ info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
+ info->attrs[PSP_A_DEV_ID]);
+ return PTR_ERR_OR_ZERO(info->user_ptr[0]);
+}
+
+void
+psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
+ struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+
+ mutex_unlock(&psd->lock);
+ if (socket)
+ sockfd_put(socket);
+}
+
+static int
+psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
+ const struct genl_info *info)
+{
+ void *hdr;
+
+ hdr = genlmsg_iput(rsp, info);
+ if (!hdr)
+ return -EMSGSIZE;
+
+ if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
+ nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
+ nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
+ nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
+ goto err_cancel_msg;
+
+ genlmsg_end(rsp, hdr);
+ return 0;
+
+err_cancel_msg:
+ genlmsg_cancel(rsp, hdr);
+ return -EMSGSIZE;
+}
+
+void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
+{
+ struct genl_info info;
+ struct sk_buff *ntf;
+
+ if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
+ PSP_NLGRP_MGMT))
+ return;
+
+ ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!ntf)
+ return;
+
+ genl_info_init_ntf(&info, &psp_nl_family, cmd);
+ if (psp_nl_dev_fill(psd, ntf, &info)) {
+ nlmsg_free(ntf);
+ return;
+ }
+
+ genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
+ 0, PSP_NLGRP_MGMT, GFP_KERNEL);
+}
+
+int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct sk_buff *rsp;
+ int err;
+
+ rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!rsp)
+ return -ENOMEM;
+
+ err = psp_nl_dev_fill(psd, rsp, info);
+ if (err)
+ goto err_free_msg;
+
+ return genlmsg_reply(rsp, info);
+
+err_free_msg:
+ nlmsg_free(rsp);
+ return err;
+}
+
+static int
+psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
+ struct psp_dev *psd)
+{
+ if (psp_dev_check_access(psd, sock_net(rsp->sk)))
+ return 0;
+
+ return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
+}
+
+int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
+{
+ struct psp_dev *psd;
+ int err = 0;
+
+ mutex_lock(&psp_devs_lock);
+ xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
+ mutex_lock(&psd->lock);
+ err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
+ mutex_unlock(&psd->lock);
+ if (err)
+ break;
+ }
+ mutex_unlock(&psp_devs_lock);
+
+ return err;
+}
+
+int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_dev_config new_config;
+ struct sk_buff *rsp;
+ int err;
+
+ memcpy(&new_config, &psd->config, sizeof(new_config));
+
+ if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
+ new_config.versions =
+ nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
+ if (new_config.versions & ~psd->caps->versions) {
+ NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
+ return -EINVAL;
+ }
+ } else {
+ NL_SET_ERR_MSG(info->extack, "No settings present");
+ return -EINVAL;
+ }
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
+ err = psd->ops->set_config(psd, &new_config, info->extack);
+ if (err)
+ goto err_free_rsp;
+
+ memcpy(&psd->config, &new_config, sizeof(new_config));
+ }
+
+ psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct genl_info ntf_info;
+ struct sk_buff *ntf, *rsp;
+ u8 prev_gen;
+ int err;
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
+ ntf = psp_nl_reply_new(&ntf_info);
+ if (!ntf) {
+ err = -ENOMEM;
+ goto err_free_rsp;
+ }
+
+ if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
+ nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
+ err = -EMSGSIZE;
+ goto err_free_ntf;
+ }
+
+ /* suggest the next gen number, driver can override */
+ prev_gen = psd->generation;
+ psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;
+
+ err = psd->ops->key_rotate(psd, info->extack);
+ if (err)
+ goto err_free_ntf;
+
+ WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
+ psd->generation & ~PSP_GEN_VALID_MASK);
+
+ psp_assocs_key_rotated(psd);
+
+ nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
+ genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
+ 0, PSP_NLGRP_USE, GFP_KERNEL);
+ return psp_nl_reply_send(rsp, info);
+
+err_free_ntf:
+ nlmsg_free(ntf);
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+/* Key etc. */
+
+int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket;
+ struct psp_dev *psd;
+ struct nlattr *id;
+ int fd, err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
+ return -EINVAL;
+
+ fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
+ socket = sockfd_lookup(fd, &err);
+ if (!socket)
+ return err;
+
+ if (!sk_is_tcp(socket->sk)) {
+ NL_SET_ERR_MSG_ATTR(info->extack,
+ info->attrs[PSP_A_ASSOC_SOCK_FD],
+ "Unsupported socket family and type");
+ err = -EOPNOTSUPP;
+ goto err_sock_put;
+ }
+
+ psd = psp_dev_get_for_sock(socket->sk);
+ if (psd) {
+ err = psp_dev_check_access(psd, genl_info_net(info));
+ if (err) {
+ psp_dev_put(psd);
+ psd = NULL;
+ }
+ }
+
+ if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
+ err = -EINVAL;
+ goto err_sock_put;
+ }
+
+ id = info->attrs[PSP_A_ASSOC_DEV_ID];
+ if (psd) {
+ mutex_lock(&psd->lock);
+ if (id && psd->id != nla_get_u32(id)) {
+ mutex_unlock(&psd->lock);
+ NL_SET_ERR_MSG_ATTR(info->extack, id,
+ "Device id vs socket mismatch");
+ err = -EINVAL;
+ goto err_psd_put;
+ }
+
+ psp_dev_put(psd);
+ } else {
+ psd = psp_device_get_and_lock(genl_info_net(info), id);
+ if (IS_ERR(psd)) {
+ err = PTR_ERR(psd);
+ goto err_sock_put;
+ }
+ }
+
+ info->user_ptr[0] = psd;
+ info->user_ptr[1] = socket;
+
+ return 0;
+
+err_psd_put:
+ psp_dev_put(psd);
+err_sock_put:
+ sockfd_put(socket);
+ return err;
+}
+
+static int
+psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
+ unsigned int key_sz)
+{
+ struct nlattr *nest = info->attrs[attr];
+ struct nlattr *tb[PSP_A_KEYS_SPI + 1];
+ u32 spi;
+ int err;
+
+ err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
+ psp_keys_nl_policy, info->extack);
+ if (err)
+ return err;
+
+ if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
+ NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
+ return -EINVAL;
+
+ if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
+ NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
+ "incorrect key length");
+ return -EINVAL;
+ }
+
+ spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
+ if (!(spi & PSP_SPI_KEY_ID)) {
+ NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
+ "invalid SPI: lower 31b must be non-zero");
+ return -EINVAL;
+ }
+
+ key->spi = cpu_to_be32(spi);
+ memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
+
+ return 0;
+}
+
+static int
+psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
+ struct psp_key_parsed *key)
+{
+ int key_sz = psp_key_size(version);
+ void *nest;
+
+ nest = nla_nest_start(skb, attr);
+
+ if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
+ nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
+ nla_nest_cancel(skb, nest);
+ return -EMSGSIZE;
+ }
+
+ nla_nest_end(skb, nest);
+
+ return 0;
+}
+
+int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_key_parsed key;
+ struct psp_assoc *pas;
+ struct sk_buff *rsp;
+ u32 version;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
+ return -EINVAL;
+
+ version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
+ if (!(psd->caps->versions & (1 << version))) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
+ return -EOPNOTSUPP;
+ }
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ pas = psp_assoc_create(psd);
+ if (!pas) {
+ err = -ENOMEM;
+ goto err_free_rsp;
+ }
+ pas->version = version;
+
+ err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
+ if (err)
+ goto err_free_pas;
+
+ if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
+ psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
+ err = -EMSGSIZE;
+ goto err_free_pas;
+ }
+
+ err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
+ if (err) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
+ goto err_free_pas;
+ }
+ psp_assoc_put(pas);
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_pas:
+ psp_assoc_put(pas);
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_key_parsed key;
+ struct sk_buff *rsp;
+ unsigned int key_sz;
+ u32 version;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
+ GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
+ return -EINVAL;
+
+ version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
+ if (!(psd->caps->versions & (1 << version))) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
+ return -EOPNOTSUPP;
+ }
+
+ key_sz = psp_key_size(version);
+ if (!key_sz)
+ return -EINVAL;
+
+ err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
+ if (err < 0)
+ return err;
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
+ info->extack);
+ if (err)
+ goto err_free_msg;
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_msg:
+ nlmsg_free(rsp);
+ return err;
+}
diff --git a/net/psp/psp_sock.c b/net/psp/psp_sock.c
new file mode 100644
index 000000000000..afa966c6b69d
--- /dev/null
+++ b/net/psp/psp_sock.c
@@ -0,0 +1,295 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/file.h>
+#include <linux/net.h>
+#include <linux/rcupdate.h>
+#include <linux/tcp.h>
+
+#include <net/ip.h>
+#include <net/psp.h>
+#include "psp.h"
+
+struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
+{
+ struct dst_entry *dst;
+ struct psp_dev *psd;
+
+ dst = sk_dst_get(sk);
+ if (!dst)
+ return NULL;
+
+ rcu_read_lock();
+ psd = rcu_dereference(dst->dev->psp_dev);
+ if (psd && !psp_dev_tryget(psd))
+ psd = NULL;
+ rcu_read_unlock();
+
+ dst_release(dst);
+
+ return psd;
+}
+
+static struct sk_buff *
+psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
+{
+ struct psp_assoc *pas;
+ bool good;
+
+ rcu_read_lock();
+ pas = psp_skb_get_assoc_rcu(skb);
+ good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
+ rcu_read_unlock();
+ if (!good) {
+ kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT);
+ return NULL;
+ }
+
+ return skb;
+}
+
+struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
+{
+ struct psp_assoc *pas;
+
+ lockdep_assert_held(&psd->lock);
+
+ pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc),
+ GFP_KERNEL_ACCOUNT);
+ if (!pas)
+ return NULL;
+
+ pas->psd = psd;
+ pas->dev_id = psd->id;
+ pas->generation = psd->generation;
+ psp_dev_get(psd);
+ refcount_set(&pas->refcnt, 1);
+
+ list_add_tail(&pas->assocs_list, &psd->active_assocs);
+
+ return pas;
+}
+
+static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
+{
+ struct psp_dev *psd = pas->psd;
+ size_t sz;
+
+ lockdep_assert_held(&psd->lock);
+
+ sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
+ return kmemdup(pas, sz, GFP_KERNEL);
+}
+
+static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
+ struct netlink_ext_ack *extack)
+{
+ return psd->ops->tx_key_add(psd, pas, extack);
+}
+
+void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
+{
+ if (pas->tx.spi)
+ psd->ops->tx_key_del(psd, pas);
+ list_del(&pas->assocs_list);
+}
+
+static void psp_assoc_free(struct work_struct *work)
+{
+ struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
+ struct psp_dev *psd = pas->psd;
+
+ mutex_lock(&psd->lock);
+ if (psd->ops)
+ psp_dev_tx_key_del(psd, pas);
+ mutex_unlock(&psd->lock);
+ psp_dev_put(psd);
+ kfree(pas);
+}
+
+static void psp_assoc_free_queue(struct rcu_head *head)
+{
+ struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);
+
+ INIT_WORK(&pas->work, psp_assoc_free);
+ schedule_work(&pas->work);
+}
+
+/**
+ * psp_assoc_put() - release a reference on a PSP association
+ * @pas: association to release
+ */
+void psp_assoc_put(struct psp_assoc *pas)
+{
+ if (pas && refcount_dec_and_test(&pas->refcnt))
+ call_rcu(&pas->rcu, psp_assoc_free_queue);
+}
+
+void psp_sk_assoc_free(struct sock *sk)
+{
+ struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);
+
+ rcu_assign_pointer(sk->psp_assoc, NULL);
+ psp_assoc_put(pas);
+}
+
+int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
+ struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack)
+{
+ int err;
+
+ memcpy(&pas->rx, key, sizeof(*key));
+
+ lock_sock(sk);
+
+ if (psp_sk_assoc(sk)) {
+ NL_SET_ERR_MSG(extack, "Socket already has PSP state");
+ err = -EBUSY;
+ goto exit_unlock;
+ }
+
+ refcount_inc(&pas->refcnt);
+ rcu_assign_pointer(sk->psp_assoc, pas);
+ err = 0;
+
+exit_unlock:
+ release_sock(sk);
+
+ return err;
+}
+
+static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
+{
+ struct psp_skb_ext *pse;
+ struct sk_buff *skb;
+
+ skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
+ pse = skb_ext_find(skb, SKB_EXT_PSP);
+ if (!psp_pse_matches_pas(pse, pas))
+ return -EBUSY;
+ }
+
+ skb_queue_walk(&sk->sk_receive_queue, skb) {
+ pse = skb_ext_find(skb, SKB_EXT_PSP);
+ if (!psp_pse_matches_pas(pse, pas))
+ return -EBUSY;
+ }
+ return 0;
+}
+
+int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
+ u32 version, struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack)
+{
+ struct inet_connection_sock *icsk;
+ struct psp_assoc *pas, *dummy;
+ int err;
+
+ lock_sock(sk);
+
+ pas = psp_sk_assoc(sk);
+ if (!pas) {
+ NL_SET_ERR_MSG(extack, "Socket has no Rx key");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->psd != psd) {
+ NL_SET_ERR_MSG(extack, "Rx key from different device");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->version != version) {
+ NL_SET_ERR_MSG(extack,
+ "PSP version mismatch with existing state");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->tx.spi) {
+ NL_SET_ERR_MSG(extack, "Tx key already set");
+ err = -EBUSY;
+ goto exit_unlock;
+ }
+
+ err = psp_sock_recv_queue_check(sk, pas);
+ if (err) {
+ NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
+ goto exit_unlock;
+ }
+
+ /* Pass a fake association to drivers to make sure they don't
+ * try to store pointers to it. For re-keying we'll need to
+ * re-allocate the assoc structures.
+ */
+ dummy = psp_assoc_dummy(pas);
+ if (!dummy) {
+ err = -ENOMEM;
+ goto exit_unlock;
+ }
+
+ memcpy(&dummy->tx, key, sizeof(*key));
+ err = psp_dev_tx_key_add(psd, dummy, extack);
+ if (err)
+ goto exit_free_dummy;
+
+ memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
+ memcpy(&pas->tx, key, sizeof(*key));
+
+ WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
+ tcp_write_collapse_fence(sk);
+ pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;
+
+ icsk = inet_csk(sk);
+ icsk->icsk_ext_hdr_len += psp_sk_overhead(sk);
+ icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);
+
+exit_free_dummy:
+ kfree(dummy);
+exit_unlock:
+ release_sock(sk);
+ return err;
+}
+
+void psp_assocs_key_rotated(struct psp_dev *psd)
+{
+ struct psp_assoc *pas, *next;
+
+ /* Mark the stale associations as invalid, they will no longer
+ * be able to Rx any traffic.
+ */
+ list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list)
+ pas->generation |= ~PSP_GEN_VALID_MASK;
+ list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
+ list_splice_init(&psd->active_assocs, &psd->prev_assocs);
+
+ /* TODO: we should inform the sockets that got shut down */
+}
+
+void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
+{
+ struct psp_assoc *pas = psp_sk_assoc(sk);
+
+ if (pas)
+ refcount_inc(&pas->refcnt);
+ rcu_assign_pointer(tw->psp_assoc, pas);
+ tw->tw_validate_xmit_skb = psp_validate_xmit;
+}
+
+void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
+{
+ struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);
+
+ rcu_assign_pointer(tw->psp_assoc, NULL);
+ psp_assoc_put(pas);
+}
+
+void psp_reply_set_decrypted(struct sk_buff *skb)
+{
+ struct psp_assoc *pas;
+
+ rcu_read_lock();
+ pas = psp_sk_get_assoc_rcu(skb->sk);
+ if (pas && pas->tx.spi)
+ skb->decrypted = 1;
+ rcu_read_unlock();
+}
+EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);